参考来源:
CSDN:【pytorch】model.eval() 和 model.train()
CSDN:torch.nn.Module中的training 属性详情,与 Module.train() 和 Module.eval() 的关系
CSDN:torch.nn.Module.train(mode=True)
1. training 属性
Module类的构造函数:
def __init__(self):"""Initializes internal Module state, shared by both nn.Module and ScriptModule."""torch._C._log_api_usage_once("python.nn_module")self.training = Trueself._parameters = OrderedDict()self._buffers = OrderedDict()self._backward_hooks = OrderedDict()self._forward_hooks = OrderedDict()self._forward_pre_hooks = OrderedDict()self._state_dict_hooks = OrderedDict()self._load_state_dict_pre_hooks = OrderedDict()self._modules = OrderedDict()
其中 **training** 属性表示 **BatchNorm** 与 **Dropout** 层在训练阶段和测试阶段中采取的策略不同,通过判断 **training** 值来决定前向传播策略。
对于一些含有 BatchNorm、Dropout 等层的模型,在训练和验证时使用的 forward 在计算上不太一样。在前向训练的过程中指定当前模型是在训练还是在验证。使用 **module.train()** 和 **module.eval()** 进行使用,其中这两个方法的实现均有 training 属性实现。
关于这两个方法的定义源码如下:
train():
def train(self, mode=True):r"""Sets the module in training mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.Args:mode (bool): whether to set training mode (``True``) or evaluationmode (``False``). Default: ``True``.Returns:Module: self"""self.training = modefor module in self.children():module.train(mode)return self
eval():
def eval(self):r"""Sets the module in evaluation mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.Returns:Module: self"""return self.train(False)
从源码中可以看出,train() 和 eval() 方法将本层及子层的 training 属性同时设为 True 或 False 。具体如下:net.train():将本层及子层的 training 设定为 True 。net.eval():将本层及子层的 training 设定为 False 。net.training = True:注意,对 module 的设置仅仅影响本层,子 module 不受影响。
net.training, net.submodel1.training
2. train()
**model.train()** 让 model 变成训练模式
3. eval()
**self.eval()** 和 **self.train(False)** 等价**eval()** 在测试之前加,否则有输入数据即使不训练,它也会改变权值。
pytorch 会自己把 BatchNormalization 和 DropOut 固定住,不会取平均,而是用训练好的值。
