1 摘要
一般可以直接print(model),但这样只能看到模型都按什么顺序定义了哪些层,并不能看出真实的forward情况。
还可以跟踪看源码,这样比较容易看出封装了哪些重复使用的子组件。
不过从可视化来说,最推荐的还是使用tensorboard · 语雀。
2 print(model)存在的问题
import torchfrom torch import nnclass NumNet(nn.Module):def __init__(self):super().__init__()self.line1 = nn.Linear(5, 10)self.relu1 = nn.ReLU()self.classifier = nn.Sequential(nn.Linear(5, 10),nn.ReLU(),nn.Linear(10, 5),nn.ReLU(),nn.Linear(5, 2),nn.ReLU(),)self.line3 = nn.Linear(5, 2)self.relu3 = nn.ReLU()self.line2 = nn.Linear(10, 5)self.relu2 = nn.ReLU()def forward(self, batched_inputs):print('print(model)不会执行forward,无法知道真实的module顺序')x = batched_inputs# logits = self.classifier(x)x = self.line1(x)x = self.relu1(x)x = self.line2(x)x = self.relu2(x)x = self.line3(x)x = self.relu3(x)logits = xreturn logitsmodel = NumNet()print(model)# NumNet(# (line1): Linear(in_features=5, out_features=10, bias=True)# (relu1): ReLU()# (classifier): Sequential(# (0): Linear(in_features=5, out_features=10, bias=True)# (1): ReLU()# (2): Linear(in_features=10, out_features=5, bias=True)# (3): ReLU()# (4): Linear(in_features=5, out_features=2, bias=True)# (5): ReLU()# )# (line3): Linear(in_features=5, out_features=2, bias=True)# (relu3): ReLU()# (line2): Linear(in_features=10, out_features=5, bias=True)# (relu2): ReLU()# )
可以看到第28行并没有输出,没有执行forward。
再看输出顺序,完全是按照在init定义顺序展示的,也不管这个module是否在forward有用到。
所以用print分析模型结构的时候,谨慎些,最好还是用tensorboard查看。
