PyTorch 源码解读之 BN & SyncBN:BN 与 多卡同步 BN 详解【DL】数据规范化:你确定了解我吗?

BatchNormNd类

  • 包括BatchNorm1dBatchNorm2dBatchNorm3d。区别只是检查了输入的合法性。
  • BatchNorm1d接受 2D 或 3D 的输入,BatchNorm2d接受 4D 的输入,BatchNorm3d接受 5D 的输入。

这里简单贴一下BatchNorm2d的实现:

  1. class BatchNorm2d(_BatchNorm):
  2. def _check_input_dim(self, input):
  3. if input.dim() != 4:
  4. raise ValueError('expected 4D input (got {}D input)'
  5. .format(input.dim()))
  • BatchNorm1d:输入输出 shape 为(N, C)或则(N, C, L);
  • BatchNorm2d:输入输出 shape 为(N, C,H, W);
  • BatchNorm3d:输入输出 shape 为(N, C,D, H, W)。

    多卡同步BN

  • BN 的性能和 batch size 有很大的关系。batch size 越大,BN 的统计量也会越准。

  • 然而像检测这样的任务,占用显存较高,一张显卡往往只能拿较少的图片(比如 2 张)来训练,这就导致 BN 的表现变差。
  • 一个解决方式是 SyncBN:所有卡共享同一个 BN,得到全局的统计量。

复习一下方差的计算方式: pytorch 源码解读 BN & SyncBN - 图1
单卡上的 BN 会计算该卡对应输入的均值、方差,然后做 Normalize;SyncBN 则需要得到全局的统计量,也就是“所有卡上的输入”对应的均值、方差。一个简单的想法是分两个步骤:

  1. 每张卡单独计算其均值,然后做一次同步,得到全局均值
  2. 用全局均值去算每张卡对应的方差,然后做一次同步,得到全局方差

但两次同步会消耗更多时间,事实上一次同步就可以实现 pytorch 源码解读 BN & SyncBN - 图2pytorch 源码解读 BN & SyncBN - 图3 的计算
pytorch 源码解读 BN & SyncBN - 图4
只需要在同步时算好 pytorch 源码解读 BN & SyncBN - 图5pytorch 源码解读 BN & SyncBN - 图6 即可。这里用一张图来描述这一过程。
pytorch 源码解读 BN & SyncBN - 图7