PyTorch 源码解读之 BN & SyncBN:BN 与 多卡同步 BN 详解;【DL】数据规范化:你确定了解我吗?
BatchNormNd类
- 包括
BatchNorm1d
,BatchNorm2d
,BatchNorm3d
。区别只是检查了输入的合法性。 - BatchNorm1d接受 2D 或 3D 的输入,BatchNorm2d接受 4D 的输入,BatchNorm3d接受 5D 的输入。
这里简单贴一下BatchNorm2d
的实现:
class BatchNorm2d(_BatchNorm):
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
- BatchNorm1d:输入输出 shape 为(N, C)或则(N, C, L);
- BatchNorm2d:输入输出 shape 为(N, C,H, W);
BatchNorm3d:输入输出 shape 为(N, C,D, H, W)。
多卡同步BN
BN 的性能和 batch size 有很大的关系。batch size 越大,BN 的统计量也会越准。
- 然而像检测这样的任务,占用显存较高,一张显卡往往只能拿较少的图片(比如 2 张)来训练,这就导致 BN 的表现变差。
- 一个解决方式是 SyncBN:所有卡共享同一个 BN,得到全局的统计量。
复习一下方差的计算方式:
单卡上的 BN 会计算该卡对应输入的均值、方差,然后做 Normalize;SyncBN 则需要得到全局的统计量,也就是“所有卡上的输入”对应的均值、方差。一个简单的想法是分两个步骤:
- 每张卡单独计算其均值,然后做一次同步,得到全局均值
- 用全局均值去算每张卡对应的方差,然后做一次同步,得到全局方差
但两次同步会消耗更多时间,事实上一次同步就可以实现 和 的计算:
只需要在同步时算好 和 即可。这里用一张图来描述这一过程。