1 摘要

一般可以直接print(model),但这样只能看到模型都按什么顺序定义了哪些层,并不能看出真实的forward情况。
还可以跟踪看源码,这样比较容易看出封装了哪些重复使用的子组件。
不过从可视化来说,最推荐的还是使用tensorboard · 语雀

2 print(model)存在的问题

  1. import torch
  2. from torch import nn
  3. class NumNet(nn.Module):
  4. def __init__(self):
  5. super().__init__()
  6. self.line1 = nn.Linear(5, 10)
  7. self.relu1 = nn.ReLU()
  8. self.classifier = nn.Sequential(
  9. nn.Linear(5, 10),
  10. nn.ReLU(),
  11. nn.Linear(10, 5),
  12. nn.ReLU(),
  13. nn.Linear(5, 2),
  14. nn.ReLU(),
  15. )
  16. self.line3 = nn.Linear(5, 2)
  17. self.relu3 = nn.ReLU()
  18. self.line2 = nn.Linear(10, 5)
  19. self.relu2 = nn.ReLU()
  20. def forward(self, batched_inputs):
  21. print('print(model)不会执行forward,无法知道真实的module顺序')
  22. x = batched_inputs
  23. # logits = self.classifier(x)
  24. x = self.line1(x)
  25. x = self.relu1(x)
  26. x = self.line2(x)
  27. x = self.relu2(x)
  28. x = self.line3(x)
  29. x = self.relu3(x)
  30. logits = x
  31. return logits
  32. model = NumNet()
  33. print(model)
  34. # NumNet(
  35. # (line1): Linear(in_features=5, out_features=10, bias=True)
  36. # (relu1): ReLU()
  37. # (classifier): Sequential(
  38. # (0): Linear(in_features=5, out_features=10, bias=True)
  39. # (1): ReLU()
  40. # (2): Linear(in_features=10, out_features=5, bias=True)
  41. # (3): ReLU()
  42. # (4): Linear(in_features=5, out_features=2, bias=True)
  43. # (5): ReLU()
  44. # )
  45. # (line3): Linear(in_features=5, out_features=2, bias=True)
  46. # (relu3): ReLU()
  47. # (line2): Linear(in_features=10, out_features=5, bias=True)
  48. # (relu2): ReLU()
  49. # )

可以看到第28行并没有输出,没有执行forward。

再看输出顺序,完全是按照在init定义顺序展示的,也不管这个module是否在forward有用到。

所以用print分析模型结构的时候,谨慎些,最好还是用tensorboard查看。