参考来源:
CSDN:【pytorch】model.eval() 和 model.train()
CSDN:torch.nn.Module中的training 属性详情,与 Module.train() 和 Module.eval() 的关系
CSDN:torch.nn.Module.train(mode=True)

1. training 属性

Module类的构造函数:

  1. def __init__(self):
  2. """
  3. Initializes internal Module state, shared by both nn.Module and ScriptModule.
  4. """
  5. torch._C._log_api_usage_once("python.nn_module")
  6. self.training = True
  7. self._parameters = OrderedDict()
  8. self._buffers = OrderedDict()
  9. self._backward_hooks = OrderedDict()
  10. self._forward_hooks = OrderedDict()
  11. self._forward_pre_hooks = OrderedDict()
  12. self._state_dict_hooks = OrderedDict()
  13. self._load_state_dict_pre_hooks = OrderedDict()
  14. self._modules = OrderedDict()

其中 **training** 属性表示 **BatchNorm****Dropout** 层在训练阶段和测试阶段中采取的策略不同,通过判断 **training** 值来决定前向传播策略。
对于一些含有 BatchNormDropout 等层的模型,在训练和验证时使用的 forward 在计算上不太一样。在前向训练的过程中指定当前模型是在训练还是在验证。使用 **module.train()****module.eval()** 进行使用,其中这两个方法的实现均有 training 属性实现。
关于这两个方法的定义源码如下:
train():

  1. def train(self, mode=True):
  2. r"""Sets the module in training mode.
  3. This has any effect only on certain modules. See documentations of
  4. particular modules for details of their behaviors in training/evaluation
  5. mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
  6. etc.
  7. Args:
  8. mode (bool): whether to set training mode (``True``) or evaluation
  9. mode (``False``). Default: ``True``.
  10. Returns:
  11. Module: self
  12. """
  13. self.training = mode
  14. for module in self.children():
  15. module.train(mode)
  16. return self

eval():

  1. def eval(self):
  2. r"""Sets the module in evaluation mode.
  3. This has any effect only on certain modules. See documentations of
  4. particular modules for details of their behaviors in training/evaluation
  5. mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
  6. etc.
  7. This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
  8. Returns:
  9. Module: self
  10. """
  11. return self.train(False)

从源码中可以看出,train()eval() 方法将本层及子层的 training 属性同时设为 TrueFalse 。具体如下:
net.train():将本层及子层的 training 设定为 True
net.eval():将本层及子层的 training 设定为 False
net.training = True:注意,对 module 的设置仅仅影响本层,子 module 不受影响。
net.training, net.submodel1.training

2. train()

**model.train()**model 变成训练模式
image.png

3. eval()

**self.eval()****self.train(False)** 等价
**eval()** 在测试之前加,否则有输入数据即使不训练,它也会改变权值。
pytorch 会自己把 BatchNormalizationDropOut 固定住,不会取平均,而是用训练好的值。
image.png