数据下载和预处理一直都是机器学习,深度学习实际项目中最耗时又最重要的任务之一,往往占据了项目的大部分时间。好在Pytorch提供了专门的数据下载,数据处理包,学会使用它们,能极大的提高我们的开发效率和数据质量。

概述

  1. Dataset类
    任何自定义的数据集类都必须继承自PyTorch的数据集类。自定义的类必须实现两个函数:__len__(self),__getitem__任何和Dataset类表现类似的自定义类都应和下面的代码类似
  1. class FirstDataset(data.Dataset):#需要继承data.Dataset
  2. def __init__(self,root_dir,size=(16,16)):
  3. # TODO
  4. # 1. 初始化文件路径或文件名列表。
  5. # 2. 设置一些基本参数
  6. #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
  7. self.files = os.listdir(root_dir)
  8. self.size = size
  9. def __getitem__(self, index):
  10. #1。从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
  11. #2。预处理数据(例如torchvision.Transform)。
  12. #3。返回数据对(例如图像和标签)。
  13. #这里需要注意的是,使用index索引
  14. img = self.files[index][0]
  15. label = self.files[index][1]
  16. return img,label
  17. def __len__(self):
  18. # 将0更改为数据集的总大小。
  19. return len(self.files)

定义了数据集类之后就可以创建对象并在其上进行迭代,例如:

  1. datasets = FirstDataset('../data/')
  2. for image,label in datasets:
  3. pass
  1. Dataloader
    Dataset类一般用于调用单个数据实例,现代的GPU都对批数据的执行进行了性能优化,DataLoader类通过多进程迭代器,为我们提供批量图片。
  1. train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True num_workers=4)
  2. test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=Falsenum_workers=4)

batch_size:类似将数据打包成小份,设置每一小份的大小

shuffle=True:是否对数据进行洗牌操作(shuffling),是否打乱数据集内数据分布的顺序

num_workers=2:可以并行加载数据(利用多核处理器加快载入数据的效率)

  1. torchvision
  • dataset 一些基本的,常用的数据集
  • models 图像分类,图像分割,图像检测,视频分类的一些常用网络模型都有官方代码
  • transforms 对图片进行基础处理,切割,转换通道,归一化等。详细的’torchvision.transforms’的全部细节操作可以看这里
  • io/utils/ops 提供一些处理视频或一些特殊操作的接口,用到了在详细查询即可。

    基本流程

  1. 先将图片分成三个文件夹,分别是训练验证测试
  2. 然后将对应文件夹的图片和标签的路径读出来,写入txt,保证读取顺序
  3. 读取txt路径,创建DATASET类,用DataLoader读取

这是图片的读取方式,一些小细节要注意,图片的读取方式,一般为RGB,如果不是要转换一下。如果是调用现成的网络结构最好根据网络输入transform里切割或者resize一下,减少调整shape的工作量。

实例

要根据自己的数据格式来具体调整导入数据的方式,如果原始数据不是图片,只需要把数据导入成图片格式的矩阵维度即可,如果是语义分割任务,label也是一张图片,这里要注意一些细节,label的切割,transform会把类别变成小数。

  1. import os
  2. from torch.utils.data import Dataset, DataLoader
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from torch import nn
  6. class CustomDataset(Dataset):
  7. def __init__(self,data_root,NUM_CLASSES):
  8. self.train_data = np.load(os.path.join(data_root,'trainAVISO-SSH_2000-2010.npy'))
  9. self.train_label = np.load(os.path.join(data_root,'trainSegmentation_2000-2010.npy'))
  10. self.data_transform = transforms.Compose([
  11. transforms.ToPILImage(), \
  12. transforms.CenterCrop(10), \
  13. transforms.ToTensor(), \
  14. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  15. std=[0.229, 0.224, 0.225])
  16. ])
  17. def __len__(self):
  18. return self.train_data.shape[0]
  19. def __getitem__(self, index):
  20. images = Image.fromarray(self.train_data[index,:,:])
  21. if images.mode != 'RGB':
  22. images = images.convert('RGB')
  23. image = self.data_transform(images)
  24. # ----------label--------------
  25. labels = Image.fromarray(self.train_label[index,:,:])
  26. self.train_labels = self.data_transform_label(labels)
  27. mask_each_classes = torch.zeros(NUM_CLASSES, image.shape[1], image.shape[2])
  28. for i in range(NUM_CLASSES):
  29. class_value = np.unique(self.train_labels.cpu().numpy()) # 类别经过归一化不再是 0,1,2
  30. mask_each_classes[i][self.train_labels[0,:,:] == class_value[i]] = 1
  31. batch = {'input': image, 'target': mask_each_classes}
  32. return batch
  1. DATA_ROOT = 'data/data_origin/'
  2. train_dataset = CustomDataset(DATA_ROOT,NUM_CLASSES = 3)
  3. train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)