参考来源:
CSDN:【pytorch】model.eval() 和 model.train()
CSDN:torch.nn.Module中的training 属性详情,与 Module.train() 和 Module.eval() 的关系
CSDN:torch.nn.Module.train(mode=True)
1. training 属性
Module类的构造函数:
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
其中 **training**
属性表示 **BatchNorm**
与 **Dropout**
层在训练阶段和测试阶段中采取的策略不同,通过判断 **training**
值来决定前向传播策略。
对于一些含有 BatchNorm
、Dropout
等层的模型,在训练和验证时使用的 forward
在计算上不太一样。在前向训练的过程中指定当前模型是在训练还是在验证。使用 **module.train()**
和 **module.eval()**
进行使用,其中这两个方法的实现均有 training
属性实现。
关于这两个方法的定义源码如下:
train():
def train(self, mode=True):
r"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
self.training = mode
for module in self.children():
module.train(mode)
return self
eval():
def eval(self):
r"""Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
Returns:
Module: self
"""
return self.train(False)
从源码中可以看出,train()
和 eval()
方法将本层及子层的 training
属性同时设为 True
或 False
。具体如下:net.train()
:将本层及子层的 training
设定为 True
。net.eval()
:将本层及子层的 training
设定为 False
。net.training = True
:注意,对 module
的设置仅仅影响本层,子 module
不受影响。
net.training, net.submodel1.training
2. train()
**model.train()**
让 model
变成训练模式
3. eval()
**self.eval()**
和 **self.train(False)**
等价**eval()**
在测试之前加,否则有输入数据即使不训练,它也会改变权值。
pytorch 会自己把 BatchNormalization
和 DropOut
固定住,不会取平均,而是用训练好的值。