参考来源:
torch.utils.data 官方手册(1.9.0)手册
torch.utils.data 官方手册中文(1.4)翻译
CSDN:pytorch 实现自由的数据读取-torch.utils.data 的学习
CSDN:pytorch 技巧 五: 自定义数据集 torch.utils.data.DataLoader 及 Dataset 的使用
CSDN:pytorch 学习笔记六:torch.utils.data 下的 TensorDataset 和 DataLoader 的使用

1. torch.utils.data

torch.utils.data 主要包括以下三个类:

  1. class torch.utils.data.Dataset

作用:创建数据集,有 __getitem__(self, index) 函数来根据索引序号获取图片和标签,有 __len__(self) 函数来获取数据集的长度。
其他的数据集类必须是 torch.utils.data.Dataset 的子类,比如说 torchvision.ImageFolder

  1. class torch.utils.data.sampler.Sampler(data_source)

作用:创建一个采样器,class torch.utils.data.sampler.Sampler 是所有的 Sampler 的基类;其中,iter(self) 函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self) 方法返回迭代器中包含元素的长度。

  1. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

2. 数据传递机制

在 pytorch 中数据传递按以下顺序:

  1. 创建 datasets ,也就是所需要读取的数据集。
  2. datasets 传入 DataLoader
  3. DataLoader 迭代产生训练数据提供给模型。

3. torch.utils.data.Dataset

Pytorch 提供两种数据集: Map(映射)式数据集和 Iterable(迭代)式数据集。
其中 Map 式数据集继承 torch.utils.data.DatasetIterable 式数据集继承 torch.utils.data.IterableDataset
本文只介绍 Map 式数据集。一个 Map 式的数据集必须要重写 __getitem__(self, index)__len__(self) 两个方法,用来表示从索引到样本的映射(Map)。 __getitem__(self, index) 按索引映射到对应的数据, __len__(self) 则会返回这个数据集的长度。
基本格式如下:

  1. import torch.utils.data as data
  2. class VOCDetection(data.Dataset):
  3. '''
  4. 必须继承data.Dataset类
  5. '''
  6. def __init__(self):
  7. '''
  8. 在这里进行初始化,一般是初始化文件路径或文件列表
  9. '''
  10. pass
  11. def __getitem__(self, index):
  12. '''
  13. 1. 按照index,读取文件中对应的数据 (读取一个数据!!!!我们常读取的数据是图片,一般我们送入模型的数据成批的,但在这里只是读取一张图片,成批后面会说到)
  14. 2. 对读取到的数据进行数据增强 (数据增强是深度学习中经常用到的,可以提高模型的泛化能力)
  15. 3. 返回数据对 (一般我们要返回 图片,对应的标签) 在这里因为我没有写完整的代码,返回值用 0 代替
  16. '''
  17. return 0
  18. def __len__(self):
  19. '''
  20. 返回数据集的长度
  21. '''
  22. return 0

可直接运行的例子:

  1. import torch.utils.data as data
  2. import numpy as np
  3. x = np.array(range(80)).reshape(8, 10) # 模拟输入, 8个样本,每个样本长度为10
  4. y = np.array(range(8)) # 模拟对应样本的标签, 8个标签
  5. class Mydataset(data.Dataset):
  6. def __init__(self, x, y):
  7. self.x = x
  8. self.y = y
  9. self.idx = list()
  10. for item in x:
  11. self.idx.append(item)
  12. pass
  13. def __getitem__(self, index):
  14. input_data = self.idx[index] #可继续进行数据增强,这里没有进行数据增强操作
  15. target = self.y[index]
  16. return input_data, target
  17. def __len__(self):
  18. return len(self.idx)
  19. datasets = Mydataset(x, y) # 初始化
  20. print(datasets.__len__()) # 调用 __len__() 返回数据的长度
  21. for i in range(len(y)):
  22. input_data, target = datasets.__getitem__(i) # 调用 __getitem__(index) 返回读取的数据对
  23. print('input_data%d =' % i, input_data)
  24. print('target%d = ' % i, target)

结果如下:

  1. 8
  2. input_data0 = [0 1 2 3 4 5 6 7 8 9]
  3. target0 = 0
  4. input_data1 = [10 11 12 13 14 15 16 17 18 19]
  5. target1 = 1
  6. input_data2 = [20 21 22 23 24 25 26 27 28 29]
  7. target2 = 2
  8. input_data3 = [30 31 32 33 34 35 36 37 38 39]
  9. target3 = 3
  10. input_data4 = [40 41 42 43 44 45 46 47 48 49]
  11. target4 = 4
  12. input_data5 = [50 51 52 53 54 55 56 57 58 59]
  13. target5 = 5
  14. input_data6 = [60 61 62 63 64 65 66 67 68 69]
  15. target6 = 6
  16. input_data7 = [70 71 72 73 74 75 76 77 78 79]
  17. target7 = 7
  18. Process finished with exit code 0

4. torch.utils.data.DataLoader

PyTorch 中数据读取的一个重要接口是 torch.utils.data.DataLoader
该接口主要用来将自定义的数据读取接口的输出或者 PyTorch 已有的数据读取接口的输入按照 batch_size 封装成 Tensor ,后续只需要再包装成 Variable 即可作为模型的输入。

  1. data_iter=torch.utils.data.DataLoader(dataset, batch_size=1,
  2. shuffle=False, sampler=None,
  3. batch_sampler=None, num_workers=0,
  4. collate_fn=None, pin_memory=False,
  5. drop_last=False, timeout=0,
  6. worker_init_fn=None,
  7. multiprocessing_context=None)

torch.utils.data.DataLoader() 的可用参数如下:
dataset(Dataset):数据读取接口,该输入是 torch.utils.data.Dataset 类的对象(或者继承自该类的自定义类的对象)。
batch_size(int, 可选):批训练数据量的大小,根据具体情况设置即可。一般为 2 的 N 次方(默认:1)。
shuffle(bool, 可选):是否打乱数据,一般在训练数据中会采用(默认:False)。
sampler(Sampler, 可选):从数据集中提取样本的策略。如果指定,“shuffle”必须为 False 。我没有用过,不太了解。
batch_sampler(Sampler, 可选):和 batch_size、shuffle 等参数互斥,一般用默认。
num_workers(int,可选):这个参数必须大于等于 0 ,为 0 时默认使用主线程读取数据,其他大于 0 的数表示通过多个进程来读取数据,可以加快数据读取速度,一般设置为 2 的 N 次方,且小于 batch_size(默认:0)。
collate_fn(callable, 可选):合并样本清单以形成小批量。用来处理不同情况下的输入 dataset 的封装。
pin_memory (bool, 可选):如果设置为 True,那么 dataloader 将会在返回它们之前,将 tensors 拷贝到 CUDA 中的固定内存中.
drop_last (bool, 可选):如果数据集大小不能被批大小整除,则设置为“True”以除去最后一个未完成的批。如果“False”那么最后一批将较小(默认:False)。
timeout(numeric, 可选):设置数据读取时间限制,超过这个时间还没读取到数据的话就会报错。(默认:0)
worker_init_fn (callable, 可选):每个 worker 初始化函数(默认:None)

可直接运行的例子:

  1. import torch.utils.data as data
  2. import numpy as np
  3. x = np.array(range(80)).reshape(8, 10) # 模拟输入, 8个样本,每个样本长度为10
  4. y = np.array(range(8)) # 模拟对应样本的标签, 8个标签
  5. class Mydataset(data.Dataset):
  6. def __init__(self, x, y):
  7. self.x = x
  8. self.y = y
  9. self.idx = list()
  10. for item in x:
  11. self.idx.append(item)
  12. pass
  13. def __getitem__(self, index):
  14. input_data = self.idx[index]
  15. target = self.y[index]
  16. return input_data, target
  17. def __len__(self):
  18. return len(self.idx)
  19. if __name__ ==('__main__'):
  20. datasets = Mydataset(x, y) # 初始化
  21. dataloader = data.DataLoader(datasets, batch_size=4, num_workers=2)
  22. for i, (input_data, target) in enumerate(dataloader):
  23. print('input_data%d' % i, input_data)
  24. print('target%d' % i, target)

结果如下:(注意看类别,DataLoader 把数据封装为 Tensor)

  1. input_data0 tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  2. [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
  3. [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
  4. [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]], dtype=torch.int32)
  5. target0 tensor([0, 1, 2, 3], dtype=torch.int32)
  6. input_data1 tensor([[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
  7. [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
  8. [60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
  9. [70, 71, 72, 73, 74, 75, 76, 77, 78, 79]], dtype=torch.int32)
  10. target1 tensor([4, 5, 6, 7], dtype=torch.int32)
  11. Process finished with exit code 0

5. sampler

上面对一些重要常用的参数做了说明,其中有一个参数是 sampler ,下面我们对它有哪些具体取值再做一下说明。只列出几个常用的取值:

  • torch.utils.data.sampler.SequentialSampler(dataset):样本元素按顺序采样,始终以相同的顺序。
  • torch.utils.data.sampler.RandomSampler(dataset):样本元素随机采样,没有替换。
  • torch.utils.data.sampler.SubsetRandomSampler(indices):样本元素从指定的索引列表中随机抽取,没有替换。

6. TensorDataset

对给定的 tensor 数据(样本和标签),将它们包装成 dataset
注意:如果是 numpyarray,或者 PandasDataFrame 需要先转换成 Tensor

  1. '''
  2. data_tensor (Tensor) - 样本数据
  3. target_tensor (Tensor) - 样本目标(标签)
  4. '''
  5. dataset=torch.utils.data.TensorDataset(data_tensor, target_tensor)

下面举个例子:

  • 我们先定义一下样本数据和标签数据,一共有 1000 个样本 ```python import torch import numpy as np

num_inputs = 2 num_examples = 1000 true_w = [2, -3.4] true_b = 4.2

features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)

labels = true_w[0] features[:, 0] + \ true_w[1] features[:, 1] + true_b

labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

print(features.shape) print(labels.shape)

‘’’ 输出:torch.Size([1000, 2]) torch.Size([1000]) ‘’’

  1. - 然后我们使用 `TensorDataset` 来生成数据集
  2. ```python
  3. import torch.utils.data as Data
  4. # 将训练数据的特征和标签组合
  5. dataset = Data.TensorDataset(features, labels)