为什么使用Dataset

DatasetPyTorch自定义的一种数据集格式,主要是为Dataloader服务,Dataloader是一种可以批量(batch)读取数据和数据label的迭代器,非常方便我们提取数据进行深度学习,要生成Dataloader迭代器,必须需要Dataset类型的数据。

PyTorch自带的可下载的数据集全部继承于Dataset,也就是可以直接在Dataloader中使用,如果我们想使用自己的数据集,且想使用Dataloader封装我们的数据集的话,我们就需要自定义一个继承Dataset类型的数据集。

自定义Dataset类型数据

基本框架

  • 自定义Dataset必须继承torch.utils.data.Dataset类,并且必须实现 __getitem____len__两个函数。

    • __getitem__接收传入的索引index,并且返回index所对应的数据(data)和标签(label)
    • __len__返回数据集的长度

      1. class 类名(torch.utils.data.Dataset):
      2. def __getitem__(self, index):
      3. # 返回对应index的数据和标签
      4. return data,target
      5. def __len__(self):
      6. return lenth

      实现案例

      假设我们本地有一个训练集,有两个类别antsbees,层级如下图:
      image.png
      其中,ants文件夹内都是ants类别的图片,bees文件夹内都是bees类别的图片。
      数据集链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq

现在我们想把这个训练集做成Dataset类型的数据,传入Dataloader进行训练,我们只需要想两个问题,如何根据索引取出相应的图片和标签?如何确定数据集的长度?其实就是如何实现__getitem____len__函数?

这个其实就是Python的问题。我对具体实现代码进行了详细的注释。
代码阅读顺序:直接阅读__init__即可。

  1. # 首先我们要导入Dataset类进行继承,os类处理文件路径,Image类读取图片。
  2. from torch.utils.data import Dataset
  3. from PIL import Image
  4. import os
  5. class MyData(Dataset):
  6. def __init__(self,path) -> None:
  7. # path是trian的文件路径,在我电脑就是'/Users/bruce/Downloads/数据集/hymenoptera_data/train/'
  8. self.path=path
  9. # os.listdir(path)以列表形式返回path下的文件名/文件夹名
  10. self.label_list = os.listdir(path)
  11. self.data_list = [] # 用来保存数据
  12. # 下面的循环就是,以(data,label)的形式,将数据存入data_list
  13. for label in self.label_list:
  14. # 这个if语句用来过滤mac系统下的.DS_Store的隐藏文件,没有隐藏文件的可不加
  15. # os官网说listdir不显示隐藏文件,然而不知道为什么还是有
  16. if label.startswith('.'):
  17. continue
  18. # 下面代码可以直接写成列表推导式
  19. for data in os.listdir(path+label):
  20. # os.path.join(a,b,c)就是连接成一个路径/a/b/c
  21. data = Image.open(os.path.join(path,label,data))
  22. self.data_list.append((data,label))
  23. def __getitem__(self,index):
  24. # 这里我们data_list保存着所有数据,且格式就是(data,label),直接返回即可
  25. return self.data_list[index]
  26. def __len__(self):
  27. return len(self.data_list)

实现完成之后,我们就可以实例化,通过索引获取数据,如图:image.png
image.png

问题

表面上看,我们的工作好像完成了,已经可以通过索引获取数据了,但是当我们把上面的train_dataset传入Dataloader后,使用Dataloader取数据时会报错,解决完一个错之后,又报另一个,具体报错信息不放出来了,有兴趣的可以自己试一下,总得来说一共三个报错的点:

  1. 传入的图片大小不一致。现实中的图片就是这样,几乎不会有所有图片大小都一致的情况
  2. 当我们处理完图片大小后,又报错:传入的图片通道不一致。
    1. 大部分图像都是三通道,里面夹杂着有一个单通道图像,所以报错
  3. Dataloader要求传入的数据是tensor类型,而不是PIL.image的图片类型

下面是修改后的代码:
阅读建议:直接从__init__阅读即可,比较大的修改是利用了torchvision中的transforms,先把数据转换成统一大小,再转换成tensor向量

  1. from torchvision import transforms
  2. from torch.utils.data import Dataset
  3. from PIL import Image
  4. import os
  5. class MyData(Dataset):
  6. def __init__(self,path) -> None:
  7. self.path=path
  8. self.label_list = os.listdir(path)
  9. self.data_list = []
  10. for label in self.label_list:
  11. if label.startswith('.'):
  12. continue
  13. # 下面代码可以直接写成列表推导式
  14. for data in os.listdir(os.path.join(path,label)):
  15. data = Image.open(os.path.join(path,label,data))
  16. # 这里我们只要3通道的图像
  17. if len(data.split())==3:
  18. self.data_list.append((self.transfrom(data),label))
  19. def __getitem__(self,index):
  20. return self.data_list[index]
  21. def __len__(self):
  22. return len(self.data_list)
  23. def transfrom(self,data):
  24. trans = transforms.Compose([
  25. transforms.Resize((300,300)),
  26. transforms.ToTensor(),
  27. ])
  28. data = trans(data)
  29. '''
  30. 如果不想实例化之后再调用,可以直接
  31. data = transforms.Compose([
  32. transforms.Resize((300,300)),
  33. transforms.ToTensor(),
  34. ])(data)
  35. '''
  36. return data

Dataset的可加性

PyTorch允许两个Dataset的实例进行相加,从而实现数据集的整合。
假设下面path1path2是两个不同的路径,都是我们的训练集,我们就可以直接把两个训练集加起来:
image.png

ImageFolder

其实对于上面这种形式的数据集,PyTorch有一个专门的类ImageFolder来处理,非常的方便。
详见ImageFolder