1. transforms 运行机制
  2. 数据标准化—— transforms.normalize()

一、transforms 运行机制

在安装 PyTorch 的时候,我们安装了两个安装包,一个是 PyTorch ,一个是 torchvisiontorchvision 是计算机视觉工具包。在torchvision 当中有三个主要的模块。第一个模块是我们图像预处理模块;第二个主要的模块是 datasets :这里有常用的公开数据集的dataset;最后一个主要模块是 model :提供大量常用的预训练模型。图像预处理 transforms 子模块当中提供了很多的处理方法,例如数据中心化、标准化、缩放、裁剪、旋转、翻转、填充以及各种变化等等。

深度学习是有数据驱动的,数据的数量以及分布对于模型的优劣是起到决定性作用。所以我们需要对数据进行一定的预处理,以及数据增强,用来提升我们模型的泛化能力。如果我们做数据增强的时候。生成了一些与测试样本很相似的图片。那我们的模型的方法能力自然就会得到提高。这就是我们做数据增强的原因

  • torchvision :计算机视觉工具包
    • torchvision.transforms : 常用的图像预处理方法
    • torchvision.datasets : 常用数据集的dataset实现, MNISTCIFAR-10ImageNet
    • torchvision.model : 常用的模型预训练, AlexNetVGGResNetGoogLeNet
  • torchvision.transforms : 常用的图像预处理方法

    • 数据中心化、数据标准化
    • 缩放、裁剪、旋转、翻转、填充、噪声添加
    • 灰度变换、线性变换、仿射变换
    • 亮度、饱和度及对比度变换

      二、transforms.normalize()

      1. transforms.Normalize(
      2. mean,
      3. std,
      4. inplace=False)
  • 功能:逐 channel 的对图像进行标准化,将数据分布转化为方差为0, 标准差为1。

  • output = (input - mean) / std
    • mean :各通道的均值
    • std :各通道的标准差
    • inplace :是否原地操作

经过断点调试,该API的实现在 functional.pynormalize() 函数

  1. def normalize(tensor, mean, std, inplace=False):
  2. """Normalize a tensor image with mean and standard deviation.
  3. .. note::
  4. This transform acts out of place by default, i.e., it does not mutates the input tensor.
  5. See :class:`~torchvision.transforms.Normalize` for more details.
  6. Args:
  7. tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
  8. mean (sequence): Sequence of means for each channel.
  9. std (sequence): Sequence of standard deviations for each channel.
  10. inplace(bool,optional): Bool to make this operation inplace.
  11. Returns:
  12. Tensor: Normalized Tensor image.
  13. """
  14. if not torch.is_tensor(tensor):
  15. raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))
  16. if tensor.ndimension() != 3:
  17. raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = '
  18. '{}.'.format(tensor.size()))
  19. if not inplace:
  20. tensor = tensor.clone()
  21. dtype = tensor.dtype
  22. mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
  23. std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
  24. if (std == 0).any():
  25. raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
  26. if mean.ndim == 1:
  27. mean = mean[:, None, None]
  28. if std.ndim == 1:
  29. std = std[:, None, None]
  30. tensor.sub_(mean).div_(std)
  31. return tensor

如果我们的训练数据有良好的分布以及良好的初始化,可以加速我们模型的收敛