今天在看别人的网络训练框架的时候,看到了一句with torch.no_grad()
,看了很多的训练方案,只有这个代码中有这句话,而且是在训练过程完成,测试和验证过程开始之前,猜测这句话可能和model.eval()
拥有类似的功能用来切换状态。
这句话本身有什么作用呢?
model.eval()
将会切换网络中的batchnorm和dropout层的状态,原来在训练时候被激活的状态在测试过程中就不会被激活了torch.no_grad()
影响的是自动求导的引擎并且使其不能够在求导,能够减少内存使用加速计算速度,在使用了这句话之后整个网络的梯度都会停止在进入验证集之前。
另外,还有一个函数是torch.set_grad_enabled
,这个函数的功能是手动地开启/关闭梯度的计算,相比于torch.no_grad()
,它更注重于手动操作。
参考来源:‘model.eval()’ vs ‘with torch.no_grad()’ - PyTorch Forums