比如现在有一个特别大的TXT文件,每行都是训练数据和标签。
在以前的训练中,大多数情况下都是直接读取到内存,保存在列表里面,按照 batch_size 放入模型训练。
大概的读取代码如下:
file = './datas/THUCNews/train.txt'with open(file, encoding="utf-8") as f:sentences_and_labels = [line for line in f.readlines()]f.close()# 前几句sentences_and_labels[0:10]
如果这个文件特别大,一次性读不下怎么办呢?这时候我们就要考虑一次读取一批数据,不要太占用硬盘了。
data.IterableDataset
Pytorch 为我们提供了 torch.utils.data.IterableDataset 方法。一般来说,当我们不知道文件有多大(比如特别大的时候),而且这个文件是可分割可迭代的,那么就可以使用这个方法来读取文件。
数据集说明
比如著名的新闻标题分类数据集,其格式参考如下:
有十种分类,每一行都是一条训练、测试数据。符合我们的要求。
构建数据集(单进程)
这里我们先尝试单进程读取数据集。
首先构建数据集:
import torchclass THUnewsDataset(torch.utils.data.IterableDataset):def __init__(self, file_path):super(THUnewsDataset).__init__()self.file_path = file_pathdef _tar_line_iterator(self, file_path: str) -> str:"""读取指定数据集。每次读取一行,返回迭代器。Args:file_path (str): 文件路径Returns:str: 每行数据。比如:"中国人民公安大学2012年硕士研究生目录及书目 3"Yields:Iterator[str]: 每行数据的迭代器。"""with open(file_path, encoding="utf-8") as f:while True:line = f.readline()if not line:breakyield f.readline()f.close()def __iter__(self):yield from self._tar_line_iterator(self.file_path)
简单来说:
- 首先定义了数据集类
THUnewsDataset- 继承了
torch.utils.data.IterableDataset方法,因为我们的新闻数据集是可迭代的数据集
- 继承了
- 重写了
__iter__方法,返回了迭代器简单测试
```pythonimport sys
print(sys.getsizeof(dataloader)) 64
print(sys.getsizeof(THUnewsDataset(file))) 64
print(THUnewsDataset(file).iter())
print(iter(THUnewsDataset(file)))
print(next(iter(THUnewsDataset(file)))) 中国人民公安大学2012年硕士研究生目录及书目 3 ```
Dataloader
```python file = ‘./datas/THUCNews/test.txt’ dataloader = torch.utils.data.DataLoader(THUnewsDataset(file), batch_size=128)
all_data = []
for i in dataloader: all_data.append(i) ``` 可以看看结果
遗留问题
虽然内存是不爆炸了,但是显存会不会爆炸呢?
我们不知道变成 Tensor 进入显存以后,是不是训练一批清空一批的,所以要注意这点。
不过按道理会自动清空的。
