参考

  • https://www.cnblogs.com/shuimuqingyang/p/14007260.html?ivk_sa=1024320u

    affine

    初始化时修改

    affine设为True时,BatchNorm层才会学习参数gamma和beta,否则不包含这两个变量,变量名是weight和bias。

    .train()

  • 如果affine==True,则对归一化后的batch进行仿射变换,即乘以模块内部的weight(初值是[1., 1., 1., 1.])然后加上模块内部的bias(初值是[0., 0., 0., 0.]),这两个变量会在反向传播时得到更新。

  • 如果affine==False,则BatchNorm中不含有weight和bias两个变量,什么都都不做。

    .eval()

  • 如果affine==True,则对归一化后的batch进行放射变换,即乘以模块内部的weight然后加上模块内部的bias,这两个变量都是网络训练时学习到的。

  • 如果affine==False,则BatchNorm中不含有weight和bias两个变量,什么都不做。

    修改实例属性

    无影响,仍按照初始化时的设定。

    track_running_stats

    由于BN的前向传播中涉及到了该属性,所以实例属性的修改会影响最终的计算过程。 ```python class NormBase(Module): “””Common base of InstanceNorm and _BatchNorm””” _version = 2 __constants = [‘track_running_stats’, ‘momentum’, ‘eps’,

    1. 'num_features', 'affine']

    num_features: int eps: float momentum: float affine: bool track_running_stats: bool

    WARNING: weight and bias purposely not defined here.

    See https://github.com/pytorch/pytorch/issues/39670

    def init(

    1. self,
    2. num_features: int,
    3. eps: float = 1e-5,
    4. momentum: float = 0.1,
    5. affine: bool = True,
    6. track_running_stats: bool = True

    ) -> None:

    1. super(_NormBase, self).__init__()
    2. self.num_features = num_features
    3. self.eps = eps
    4. self.momentum = momentum
    5. self.affine = affine
    6. self.track_running_stats = track_running_stats
    7. if self.affine:
    8. self.weight = Parameter(torch.Tensor(num_features))
    9. self.bias = Parameter(torch.Tensor(num_features))
    10. else:
    11. self.register_parameter('weight', None)
    12. self.register_parameter('bias', None)
    13. if self.track_running_stats:
    14. self.register_buffer('running_mean', torch.zeros(num_features))
    15. self.register_buffer('running_var', torch.ones(num_features))
    16. self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
    17. else:
    18. self.register_parameter('running_mean', None)
    19. self.register_parameter('running_var', None)
    20. self.register_parameter('num_batches_tracked', None)
    21. self.reset_parameters()

class _BatchNorm(_NormBase): …

  1. def forward(self, input: Tensor) -> Tensor:
  2. self._check_input_dim(input)
  3. if self.momentum is None:
  4. exponential_average_factor = 0.0
  5. else:
  6. exponential_average_factor = self.momentum
  7. if self.training and self.track_running_stats:
  8. if self.num_batches_tracked is not None: # type: ignore
  9. self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore
  10. if self.momentum is None: # use cumulative moving average
  11. exponential_average_factor = 1.0 / float(self.num_batches_tracked)
  12. else: # use exponential moving average
  13. exponential_average_factor = self.momentum
  14. r"""
  15. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  16. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  17. 可以看到这里的bn_training控制的是,数据运算使用当前batch计算得到的统计量(True)
  18. """
  19. if self.training:
  20. bn_training = True
  21. else:
  22. bn_training = (self.running_mean is None) and (self.running_var is None)
  23. r"""
  24. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  25. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  26. used for normalization (i.e. in eval mode when buffers are not None).
  27. 这里强调的是统计量buffer的使用条件(self.running_mean, self.running_var)
  28. - training==True and track_running_stats==False, 这些属性被传入F.batch_norm中时,均替换为None
  29. - training==True and track_running_stats==True, 会使用这些属性中存放的内容
  30. - training==False and track_running_stats==True, 会使用这些属性中存放的内容
  31. - training==False and track_running_stats==False, 会使用这些属性中存放的内容
  32. """
  33. assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
  34. assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
  35. return F.batch_norm(
  36. input,
  37. # If buffers are not to be tracked, ensure that they won't be updated
  38. self.running_mean if not self.training or self.track_running_stats else None,
  39. self.running_var if not self.training or self.track_running_stats else None,
  40. self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
  1. <a name="RIEdR"></a>
  2. ### `.train()`
  3. 注意代码中的注释:Buffers are only updated if they are to be tracked and we are in training mode. 即仅当处于训练模式,且`track_running_stats==True`时会更新这些统计量buffer。<br />另外,此时`self.training==True`。`bn_training=True`。
  4. <a name="Ta9FG"></a>
  5. #### `track_running_stats==True`
  6. BatchNorm层会统计全局均值running_mean和方差running_var,而对batch归一化时,仅使用当前batch的统计量。
  7. ```python
  8. self.register_buffer('running_mean', torch.zeros(num_features))
  9. self.register_buffer('running_var', torch.ones(num_features))
  10. self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

使用momentum更新模块内部的running_mean。

  • 如果momentum是None,那么就是用累计移动平均(这里会使用属性self.num_batches_tracked来统计已经经过的batch数量),否则就使用指数移动平均(使用momentum作为系数)。二者的更新公式基本框架是一样的:BN参数详解 - 图1,只是具体的BN参数详解 - 图2有所不同。其中BN参数详解 - 图3代表更新后的runningmean和running_var;![](https://cdn.nlark.com/yuque/__latex/c6c3709e7635ece77599f982eb585a13.svg#card=math&code=x%7Bcur%7D&id=fe1xX)表示更新前的runningmean和running_var;![](https://cdn.nlark.com/yuque/__latex/4b9f30b4a23d10e0e64432949cd6e0a4.svg#card=math&code=x%7Bbatch%7D&id=hH2Qr)表示当前batch的均值和无偏样本方差。
  • 累计移动平均的更新中BN参数详解 - 图4
  • 指数移动平均的更新公式是BN参数详解 - 图5
    修改实例属性
    如果设置.track_running_stats==False,此时self.num_batches_tracked不会更新,而且exponential_average_factor也不会被重新调整。
    而由于:
    1
    2
              self.running_mean if not self.training or self.track_running_stats else None,
              self.running_var if not self.training or self.track_running_stats else None,
    
    且此时self.training==True,并且self.track_running_stats==False,所以送入F.batch_normself.running_mean&self.running_var两个参数都是None。
    也就是说,此时和直接在初始化中设置**track_running_stats==False**是一样的效果。
    但是要小心这里的~~exponential_average_factor~~的变化。不过由于通常我们初始化BN时,仅仅会送入~~num_features~~,所以默认会使用~~exponential_average_factor = self.momentum~~来构造指数移动平均更新运行时统计量。(此时exponential_average_factor不会发挥作用)

    track_running_stats==False

    则BatchNorm中不含有running_mean和running_var两个变量,也就是仅仅使用当前batch的统计量来归一化batch。
    1
    2
    3
              self.register_parameter('running_mean', None)
              self.register_parameter('running_var', None)
              self.register_parameter('num_batches_tracked', None)
    
    修改实例属性
    如果设置.track_running_stats==True,此时self.num_batches_tracked仍然不会更新,因为其初始值是None。
    整体来看,这样的修改并没有实际影响。

    .eval()

    此时self.training==False
    1
    2
              self.running_mean if not self.training or self.track_running_stats else None,
              self.running_var if not self.training or self.track_running_stats else None,
    
    此时送入F.batch_norm的两个统计量buffer和初始化时的结果是一致的。

    track_running_stats==True

    1
    2
    3
              self.register_buffer('running_mean', torch.zeros(num_features))
              self.register_buffer('running_var', torch.ones(num_features))
              self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
    
    此时bn_training = (self.running_mean is None) and (self.running_var is None) = False。所以使用全局的统计量。
    对batch进行归一化,公式为BN参数详解 - 图6,注意这里的均值和方差是running_mean和running_var,在网络训练时统计出来的全局均值和无偏样本方差
    修改实例属性
    如果设置.track_running_stats==False,此时bn_training不变,仍未False,所以仍然使用全局的统计量。也就是self.running_mean, self.running_var中存放的内容。
    整体而言,此时修改属性没有影响。

    track_running_stats==False

    1
    2
    3
              self.register_parameter('running_mean', None)
              self.register_parameter('running_var', None)
              self.register_parameter('num_batches_tracked', None)
    
    此时bn_training = (self.running_mean is None) and (self.running_var is None) = True。所以使用当前batch的统计量。
    对batch进行归一化,公式为BN参数详解 - 图7,注意这里的均值和方差是batch自己的mean和var,此时BatchNorm里不含有running_mean和running_var。
    注意此时使用的是无偏样本方差(和训练时不同),因此如果batch_size=1,会使分母为0,就报错了。
    修改实例属性
    如果设置.track_running_stats==True,此时bn_training不变,仍为True,所以仍然使用当前batch的统计量。也就是忽略self.running_mean, self.running_var中存放的内容。
    此时的行为和未修改时一致。

    汇总

    | track_running_stats | | | | | | —- | —- | —- | —- | —- | | 实例化设置 | True(默认值) | | False | | | | Buffer('running_mean', torch.zeros(num_features))
    Buffer('running_var', torch.ones(num_features))
    Buffer('num_batches_tracked', torch.tensor(0)) | | Parameter('running_mean', None)
    Parameter('running_var', None)
    Parameter('num_batches_tracked', None) | | | 修改属性值 | True(原始值) | False | True | False(原始值) | | training=True
    (.train()) | bn_training=True | | | | | |
    - 移动平均更新running_mean和running_var
    - 使用mini-batch的统计量归一化输入
    |
    - 不更新running_mean和running_var
    - 使用mini-batch的统计量归一化输入
    |
    - running_mean和running_var均为None
    - 使用mini-batch的统计量归一化输入
    | | | training=False
    (.eval()) | bn_training=False | | bn_training=True | | | |
    - 不更新running_mean和running_var
    - 使用全局的统计量归一化输入
    | |
    - running_mean和running_var均为None
    - 使用mini-batch的统计量归一化输入
    | |