在PyTorch中,torch.nntorch.nn.Functional中有一些功能类似的函数,但是它们之间存在一些不同,其中的nn.DropoutFunctional.Dropout的是线上就有一些不同。

例子

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Model1(nn.Module):
  5. # Model1使用的是functioal.dropout
  6. def __init__(self, p=0.0):
  7. super().__init__()
  8. self.p = p
  9. def forward(self, inputs):
  10. return F.dropout(inputs, p=self.p, training=True)
  11. class Model2(nn.Module):
  12. # Model2使用的是nn.Dropout
  13. def __init__(self, p=0.0):
  14. super().__init__()
  15. self.drop_layer = nn.Dropout(p=p)
  16. def forward(self, inputs):
  17. return self.drop_layer(inputs)
  18. model1 = Model1(p=0.5) # torch.nn.functional.dropout
  19. model2 = Model2(p=0.5) # torch.nn.Dropout
  20. # 输入
  21. inputs = torch.rand(10)
  22. print('torch.nn.functional.dropout的输出{0}\n'.format(model1(inputs)))
  23. print('torch.nn.Dropout的输出{0}\n'.format(model2(inputs)))
  24. model1.eval()
  25. model2.eval()
  26. # 通过eval函数关掉dropout
  27. print('----------------查看eval函数的映像-------------\n')
  28. print('torch.nn.functional.dropout的输出{0}\n'.format(model1(inputs)))
  29. print('torch.nn.Dropout的输出{0}\n'.format(model2(inputs)))
  30. # 打印出模型
  31. print(model1)
  32. print(model2)

torch.nn.functional.dropout的输出tensor([0.0000, 0.2708, 0.0000, 0.7128, 0.0000, 1.7333, 0.0000, 0.3477, 0.0000, 0.0000])

torch.nn.Dropout的输出tensor([0.0000, 0.2708, 0.0000, 0.7128, 0.0000, 1.7333, 0.0000, 0.0000, 0.3211,

  1. 0.7905])

————————查看eval函数的映像——————-

torch.nn.functional.dropout的输出tensor([0.0000, 0.0000, 1.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.3477, 0.3211,

  1. 0.7905])

torch.nn.Dropout的输出tensor([0.8236, 0.1354, 0.8133, 0.3564, 0.9222, 0.8666, 0.9603, 0.1739, 0.1606,

  1. 0.3952])

Model1()

Model2(

(drop_layer): Dropout(p=0.5, inplace=False)

)

结论

在训练过程中,这两个函数是没有什么区别的。但是,在测试的时候,我们通常需要将参数设置为dropout(p=0.0)我们常见的model.eval()能够停止掉nn.Dropout的使用。对于torch.nn.functional.dropout就需要手动设置了,所以,一般测试过程中,通常使用nn.Dropout而不是用nn.functional.dropout

model.eval()还有另外一个作用,是关于Batch Normalization的,这里就不展开了。