不要这样去做模式切换:
正确做法是:
""" train和eval的切换 """
from torch import nn
m = nn.Module()
# 训练只有一种写法,默认参数True可以不写
m.train(True)
# 推断有两种等价写法,eval本质上也是调用train(False)实现的
m.train(False)
m.eval()
之所以要用m.train()的方式,是因为模式的切换,并不只是一个标记变量的问题,而是要遍历所有嵌套module进行处理的
是要执行一个函数,遍历检查,统一切换
只是执行这个函数的时候,也会把training的标记变量改了
所以可以通过training来知道模型所处模式
train的源码:
注意eval常跟with torch.nograd()配合使用