track_running_stats,这个翻译过来,就是“追踪运行中的统计值”,这里stats统计值包括均值、方差。
1、track_running_stats=False,不追踪,那就是每个mini-batch过来独立计算,跟历史数据无关。
2、track_running_stats=True,如果追踪,分情况讨论
① train阶段:会学习历史数据的stat,用于每次mini-batch,而且存下权重,eval阶段要用
② eval阶段:不再更新bn层的权重值,也不使用mini-batch的stat,而是bn层在train学到的stat
""" bn层,track机制实验,存储在runing_mean、runing_var的历史值"""
import torch
from torch import nn
batch1 = torch.tensor([[[[1.0, 2.0]]]])
batch2 = torch.tensor([[[[3.0, 4.0, 5.0]]]])
bn = nn.BatchNorm2d(1, track_running_stats=True)
print(bn.running_mean, bn.running_var)
# (tensor([0.]), tensor([1.])),均值和方差初始值是0和1
bn(batch1)
print(bn.running_mean, bn.running_var)
# (tensor([0.1500]), tensor([0.9500]))
# 前传batch1后,记录了均值和方差,因为momentum=0.1数值差了一点
# 更新的mean=0.9*旧值0 + 0.1*新值1.5=0.15
# new_var = 0.9*1 + 0.1*0.5 = 0.95
print(batch1.mean(), batch1.var())
# (tensor(1.5000), tensor(0.5000))
bn(batch2)
print(bn.running_mean, bn.running_var)
# tensor([0.5350]) tensor([0.9550])
print(batch2.mean(), batch2.var())
# tensor(4.) tensor(1.)