一、数据来源

1. 原始数据来源:Kaggle

2. python模块

  1. #模块
  2. import torch
  3. from torchvision import datasets, transforms
  4. #方法
  5. datasets.ImageFolder(filepath, transform=transform)

注意:自带父目录的标签(建立清晰的分类目录)PyTorch 专题 | 输入-数据加载 - 图1

二、数据整理 transforms()

  • transforms.Compose([ ])
    输入为transforms.操作()的列表:多个transform组合起来使用

1. 尺寸:比例缩放

  1. transforms.Resize()
  • 二维输入:(height, width)
  • 一维输入:较小的边匹配这个输入size,比如height>width时,尺度调整为 (size * height / width, size)
  1. torchvision.transforms.Scale()
    不推荐使用:可能有畸变

2. 尺寸:裁剪

  1. transforms.CenterCrop()
    基于中心裁剪:输出正方形(size, size)
  2. transforms.RandomResizedCrop()
    根据比例,裁剪为随机大小

3. 位移:旋转,翻转

  1. transforms.RandomRotation() 旋转
  2. transforms.RandomHorizontalFlip(p=0.5) 水平翻转 (概率为0.5)

4. 数据转换

  1. transforms.ToTensor()
    图像变为pytorch张量;适用情境:灰度值变为彩色图像
  2. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    RGB三个通道的归一化(均值0.5, 标准差0.5)

三、针对模型调整输入数据的格式

1. 明确操作对象

如图:PyTorch 专题 | 输入-数据加载 - 图2