比如现在有一个特别大的TXT文件,每行都是训练数据和标签。
在以前的训练中,大多数情况下都是直接读取到内存,保存在列表里面,按照 batch_size 放入模型训练。
大概的读取代码如下:

  1. file = './datas/THUCNews/train.txt'
  2. with open(file, encoding="utf-8") as f:
  3. sentences_and_labels = [line for line in f.readlines()]
  4. f.close()
  5. # 前几句
  6. sentences_and_labels[0:10]

如果这个文件特别大,一次性读不下怎么办呢?这时候我们就要考虑一次读取一批数据,不要太占用硬盘了。

data.IterableDataset

Pytorch 为我们提供了 torch.utils.data.IterableDataset 方法。一般来说,当我们不知道文件有多大(比如特别大的时候),而且这个文件是可分割可迭代的,那么就可以使用这个方法来读取文件。

数据集说明

比如著名的新闻标题分类数据集,其格式参考如下:
image.png
有十种分类,每一行都是一条训练、测试数据。符合我们的要求。

构建数据集(单进程)

这里我们先尝试单进程读取数据集。
首先构建数据集:

  1. import torch
  2. class THUnewsDataset(torch.utils.data.IterableDataset):
  3. def __init__(self, file_path):
  4. super(THUnewsDataset).__init__()
  5. self.file_path = file_path
  6. def _tar_line_iterator(self, file_path: str) -> str:
  7. """
  8. 读取指定数据集。每次读取一行,返回迭代器。
  9. Args:
  10. file_path (str): 文件路径
  11. Returns:
  12. str: 每行数据。比如:"中国人民公安大学2012年硕士研究生目录及书目 3"
  13. Yields:
  14. Iterator[str]: 每行数据的迭代器。
  15. """
  16. with open(file_path, encoding="utf-8") as f:
  17. while True:
  18. line = f.readline()
  19. if not line:
  20. break
  21. yield f.readline()
  22. f.close()
  23. def __iter__(self):
  24. yield from self._tar_line_iterator(self.file_path)

简单来说:

  • 首先定义了数据集类 THUnewsDataset
    • 继承了 torch.utils.data.IterableDataset 方法,因为我们的新闻数据集是可迭代的数据集
  • 重写了 __iter__ 方法,返回了迭代器

    简单测试

    ```python

    import sys

print(sys.getsizeof(dataloader)) 64

print(sys.getsizeof(THUnewsDataset(file))) 64

print(THUnewsDataset(file).iter())

print(iter(THUnewsDataset(file)))

print(next(iter(THUnewsDataset(file)))) 中国人民公安大学2012年硕士研究生目录及书目 3 ```

Dataloader

```python file = ‘./datas/THUCNews/test.txt’ dataloader = torch.utils.data.DataLoader(THUnewsDataset(file), batch_size=128)

all_data = []

for i in dataloader: all_data.append(i) ``` 可以看看结果

遗留问题

虽然内存是不爆炸了,但是显存会不会爆炸呢?
我们不知道变成 Tensor 进入显存以后,是不是训练一批清空一批的,所以要注意这点。
不过按道理会自动清空的。