不要这样去做模式切换:
    image.png

    正确做法是:

    1. """ train和eval的切换 """
    2. from torch import nn
    3. m = nn.Module()
    4. # 训练只有一种写法,默认参数True可以不写
    5. m.train(True)
    6. # 推断有两种等价写法,eval本质上也是调用train(False)实现的
    7. m.train(False)
    8. m.eval()

    之所以要用m.train()的方式,是因为模式的切换,并不只是一个标记变量的问题,而是要遍历所有嵌套module进行处理的

    是要执行一个函数,遍历检查,统一切换
    只是执行这个函数的时候,也会把training的标记变量改了
    所以可以通过training来知道模型所处模式

    train的源码:
    image.png


    注意eval常跟with torch.nograd()配合使用