:::warning 最近发现自己的基础有点偏弱,再回头补补基础知识,再把学到的知识写成文档记录下来。 -2021.12.16晚 :::

摘要

今天讲的是BatchNorm,主要包括

BatchNorm的优点

BatchNorm已经作为常用的手段应用在深度学习中,效果显著,加快了训练速度,保证了梯度的流动,防止过拟合,降低网络对初始化权重的敏感程度,减少对调参的要求。

1. 加速训练

输出分布向着激活函数的上下限偏移,带来的问题就是梯度的降低,(比如说激活函数是sigmoid),通过normalization,数据在一一个合适的分布空间,经过激活函数,仍然得到不错的梯度。梯度好了自然加速训练。

2. 降低参数初始化敏感

以往模型需要设置一个不错的初始化才适合训练,加了BN就不用管这些了,现在初始化方法中随便选择一个用,训练得到的模型就能收敛。

BatchNorm的灵感来源

在讲解BN之前,我们需要了解BN是怎么被提出来的。在机器学习领域,数据分布是很重要的概念,如果训练集和测试集的分布很不相同,那么在训练集上训练好的模型,在测试集上应该不奏效(比如用ImageNet训练的分类网络去在灰度医学图像上finetue在测试,效果应该不好)。对于神经网络来说,如果每一层的数据分布都不一样,后一层的网络则需要去学习适应前一层的数据分布,这相当于去做了domain的adaptation.无疑增加了训练难度,尤其是在网络越来越深的情况下。
实际上,确实如此,不同层的输出的分布是有差异的。在BN的那篇论文中指出,不同层的数据分布会往激活函数的上限或者下限偏移。论文中称这种偏移为internal Covariate Shift, internal指的是网络内部。
BN就是为了解决偏移的,解决的方式也很简单,就是让每一层的分布都normalize到标准高斯分布。(这里的每一层并不准确,BN是根据划分数据集的集合去做Normalization,不同的划分方式也就出现了不同的Normalization, 如GN, LN, IN等)

BatchNorm的计算过程(理论知识)

我们假设网络中间经过某些卷积操作之后输出的feature map的尺寸为4322
其中,batch=4, channel=3, 2
2为feature map的长宽
整个BN层的运算过程如下图
batch size=4, 每个batch的feature map的size是3*2*2
对于所有的batch中的同一个channel的元素进行求均值和方差,比如上图,对于所有的batch, 都拿出来最后一个chanel, 一共有(22) 4=16个元素
然后求这16个元素的均值与方差(上图只求了mean, 没有求方差)
求完了均值和方差后,对于这16个元素中的每个元素减去求得的均值和方差,然后乘以BatchNorm2D() - 图2加上BatchNorm2D() - 图3,公式如下(1)
BatchNorm2D() - 图4
所以对于一个batch normalization层而言,求取的均值和方差是对于所有的batch中的同一个channel进行求取,batch normalization层能够学习到的参数,对于一个特定的channel而言, 实际上是两个参数,BatchNorm2D() - 图5BatchNorm2D() - 图6,对于total的channel而言,实际上是channel数目的两倍。

使用PyTorch验证理论知识

  1. # -*-coding:utf-8-*-
  2. from torch import nn
  3. import torch
  4. m = nn.BatchNorm2d(3) # bn设置的参数实际上是channel的参数
  5. input = torch.randn(4, 3, 2, 2)
  6. output = m(input)
  7. # print(output)
  8. a = (input[0, 0, :, :]+input[1, 0, :, :]+input[2, 0, :, :]+input[3, 0, :, :]).sum()/16
  9. b = (input[0, 1, :, :]+input[1, 1, :, :]+input[2, 1, :, :]+input[3, 1, :, :]).sum()/16
  10. c = (input[0, 2, :, :]+input[1, 2, :, :]+input[2, 2, :, :]+input[3, 2, :, :]).sum()/16
  11. print('计算出的第1个channel的均值:%f' % a.data)
  12. print('计算出的第2个channel的均值:%f' % b.data)
  13. print('计算出的第3个channel的均值:%f' % c.data)
  14. print('nn.BatchNorm的官方代码算的均值:%f, %f, %f' % (m.running_mean.data[0],m.running_mean.data[1],m.running_mean.data[2]))
  15. print()
  16. print('计算出的方差:没算')
  17. print('nn.BatchNorm的官方代码算的方差,{}, {}, {}'.format(m.running_var.data[0], m.running_var.data[1], m.running_var.data[2]))
  18. print(m)

计算出的第1个channel的均值:-0.044225 计算出的第2个channel的均值:-0.427377 计算出的第3个channel的均值:-0.269129 nn.BatchNorm的官方代码算的均值:-0.004422, -0.042738, -0.026913

计算出的方差:没算 nn.BatchNorm的官方代码算的方差,0.9888013601303101, 1.0757482051849365, 1.0104997158050537 BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

:::info 由运行的结果可知,理论和实践是相符合的,没毛病! :::

参考链接

  1. https://blog.csdn.net/qq_34914551/article/details/102736271
  2. https://www.cnblogs.com/yongjieShi/p/9332655.html