track_running_stats,这个翻译过来,就是“追踪运行中的统计值”,这里stats统计值包括均值、方差。
    1、track_running_stats=False,不追踪,那就是每个mini-batch过来独立计算,跟历史数据无关。
    2、track_running_stats=True,如果追踪,分情况讨论
    ① train阶段:会学习历史数据的stat,用于每次mini-batch,而且存下权重,eval阶段要用
    ② eval阶段:不再更新bn层的权重值,也不使用mini-batch的stat,而是bn层在train学到的stat

    1. """ bn层,track机制实验,存储在runing_mean、runing_var的历史值"""
    2. import torch
    3. from torch import nn
    4. batch1 = torch.tensor([[[[1.0, 2.0]]]])
    5. batch2 = torch.tensor([[[[3.0, 4.0, 5.0]]]])
    6. bn = nn.BatchNorm2d(1, track_running_stats=True)
    7. print(bn.running_mean, bn.running_var)
    8. # (tensor([0.]), tensor([1.])),均值和方差初始值是0和1
    9. bn(batch1)
    10. print(bn.running_mean, bn.running_var)
    11. # (tensor([0.1500]), tensor([0.9500]))
    12. # 前传batch1后,记录了均值和方差,因为momentum=0.1数值差了一点
    13. # 更新的mean=0.9*旧值0 + 0.1*新值1.5=0.15
    14. # new_var = 0.9*1 + 0.1*0.5 = 0.95
    15. print(batch1.mean(), batch1.var())
    16. # (tensor(1.5000), tensor(0.5000))
    17. bn(batch2)
    18. print(bn.running_mean, bn.running_var)
    19. # tensor([0.5350]) tensor([0.9550])
    20. print(batch2.mean(), batch2.var())
    21. # tensor(4.) tensor(1.)