1 摘要
一般可以直接print(model),但这样只能看到模型都按什么顺序定义了哪些层,并不能看出真实的forward情况。
还可以跟踪看源码,这样比较容易看出封装了哪些重复使用的子组件。
不过从可视化来说,最推荐的还是使用tensorboard · 语雀。
2 print(model)存在的问题
import torch
from torch import nn
class NumNet(nn.Module):
def __init__(self):
super().__init__()
self.line1 = nn.Linear(5, 10)
self.relu1 = nn.ReLU()
self.classifier = nn.Sequential(
nn.Linear(5, 10),
nn.ReLU(),
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2),
nn.ReLU(),
)
self.line3 = nn.Linear(5, 2)
self.relu3 = nn.ReLU()
self.line2 = nn.Linear(10, 5)
self.relu2 = nn.ReLU()
def forward(self, batched_inputs):
print('print(model)不会执行forward,无法知道真实的module顺序')
x = batched_inputs
# logits = self.classifier(x)
x = self.line1(x)
x = self.relu1(x)
x = self.line2(x)
x = self.relu2(x)
x = self.line3(x)
x = self.relu3(x)
logits = x
return logits
model = NumNet()
print(model)
# NumNet(
# (line1): Linear(in_features=5, out_features=10, bias=True)
# (relu1): ReLU()
# (classifier): Sequential(
# (0): Linear(in_features=5, out_features=10, bias=True)
# (1): ReLU()
# (2): Linear(in_features=10, out_features=5, bias=True)
# (3): ReLU()
# (4): Linear(in_features=5, out_features=2, bias=True)
# (5): ReLU()
# )
# (line3): Linear(in_features=5, out_features=2, bias=True)
# (relu3): ReLU()
# (line2): Linear(in_features=10, out_features=5, bias=True)
# (relu2): ReLU()
# )
可以看到第28行并没有输出,没有执行forward。
再看输出顺序,完全是按照在init定义顺序展示的,也不管这个module是否在forward有用到。
所以用print分析模型结构的时候,谨慎些,最好还是用tensorboard查看。