PyTorch
PyTorch 中有一些基础概念在构建网络的时候很重要,比如 **nn.Module**, **nn.ModuleList**, **nn.Sequential**,这些类称之为容器 (containers),因为可以添加模块 (module) 到它们之中。
这些容器之间很容易混淆,本文中主要学习一下 nn.ModuleListnn.Sequential,并判断在什么时候用哪一个比较合适
这里的例子使用的是 PyTorch 1.0 版本。

nn.ModuleList

首先说说 nn.ModuleList 这个类,可以把任意 **nn.Module** 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extendappend 等操作。
不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。描述看起来很枯燥,来看几个例子。
第一个网络,先来看看使用 **nn.ModuleList** 来构建一个小型网络,包括3个全连接层:

  1. class net1(nn.Module):
  2. def __init__(self):
  3. super(net1, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)])
  5. def forward(self, x):
  6. for m in self.linears:
  7. x = m(x)
  8. return x
  9. net = net1()
  10. print(net)
  11. # net1(
  12. # (modules): ModuleList(
  13. # (0): Linear(in_features=10, out_features=10, bias=True)
  14. # (1): Linear(in_features=10, out_features=10, bias=True)
  15. # )
  16. # )
  17. for param in net.parameters():
  18. print(type(param.data), param.size())
  19. # <class 'torch.Tensor'> torch.Size([10, 10])
  20. # <class 'torch.Tensor'> torch.Size([10])
  21. # <class 'torch.Tensor'> torch.Size([10, 10])
  22. # <class 'torch.Tensor'> torch.Size([10])

可以看到,这个网络包含两个全连接层,他们的权重 (weithgs) 和偏置 (bias) 都在这个网络之内
接下来看看第二个网络,它使用 Python 自带的 list

  1. class net2(nn.Module):
  2. def __init__(self):
  3. super(net2, self).__init__()
  4. self.linears = [nn.Linear(10,10) for i in range(2)]
  5. def forward(self, x):
  6. for m in self.linears:
  7. x = m(x)
  8. return x
  9. net = net2()
  10. print(net)
  11. # net2()
  12. print(list(net.parameters()))
  13. # []

显然,使用 Python 的 list 添加的全连接层和它们的 parameters 并没有自动注册到网络中。当然,还是可以使用 forward 来计算输出结果。
但是如果用 net2 实例化的网络进行训练的时候,因为这些层的 parameters 不在整个网络之中,所以其网络参数也不会被更新,也就是无法训练
好,看到这里,大致明白了 **nn.ModuleList** 是干什么的了:它是一个储存不同 module并自动将每个 module 的 parameters 添加到网络之中的容器
但是,需要注意到,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间没有什么先后顺序可言,比如:

  1. class net3(nn.Module):
  2. def __init__(self):
  3. super(net3, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
  5. def forward(self, x):
  6. x = self.linears[2](x)
  7. x = self.linears[0](x)
  8. x = self.linears[1](x)
  9. return x
  10. net = net3()
  11. print(net)
  12. # net3(
  13. # (linears): ModuleList(
  14. # (0): Linear(in_features=10, out_features=20, bias=True)
  15. # (1): Linear(in_features=20, out_features=30, bias=True)
  16. # (2): Linear(in_features=5, out_features=10, bias=True)
  17. # )
  18. # )
  19. input = torch.randn(32, 5)
  20. print(net(input).shape)
  21. # torch.Size([32, 30])

根据 net3 的结果,可以看出来这个 ModuleList 里面的顺序并不能决定什么,网络的执行顺序是根据 forward 函数来决定的。
如果非要 ModuleList 和 forward 中的顺序不一样, PyTorch 表示它无所谓,但以后 review 代码的人可能会意见比较大。
再考虑另外一种情况,既然这个 ModuleList 可以根据序号来调用,那么一个模块是否可以在 forward 函数中被调用多次呢?
答案当然是可以的,但是,被调用多次的模块,是使用同一组 parameters 的,也就是它们的参数是共享的,无论之后怎么更新。例子如下,虽然在 forward 中用了 **nn.Linear(10,10)** 两次,但是它们只有一组参数。这么做有什么用处呢,目前没有想到…

  1. class net4(nn.Module):
  2. def __init__(self):
  3. super(net4, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(5, 10), nn.Linear(10, 10)])
  5. def forward(self, x):
  6. x = self.linears[0](x)
  7. x = self.linears[1](x)
  8. x = self.linears[1](x)
  9. return x
  10. net = net4()
  11. print(net)
  12. # net4(
  13. # (linears): ModuleList(
  14. # (0): Linear(in_features=5, out_features=10, bias=True)
  15. # (1): Linear(in_features=10, out_features=10, bias=True)
  16. # )
  17. # )
  18. for name, param in net.named_parameters():
  19. print(name, param.size())
  20. # linears.0.weight torch.Size([10, 5])
  21. # linears.0.bias torch.Size([10])
  22. # linears.1.weight torch.Size([10, 10])
  23. # linears.1.bias torch.Size([10])

nn.Sequential

现在来研究一下 **nn.Sequential**不同于 **nn.ModuleList**它已经实现了内部的 forward 函数,而且里面的模块必须是按照顺序进行排列的,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的,如下面的例子所示:

  1. class net5(nn.Module):
  2. def __init__(self):
  3. super(net5, self).__init__()
  4. self.block = nn.Sequential(nn.Conv2d(1,20,5),
  5. nn.ReLU(),
  6. nn.Conv2d(20,64,5),
  7. nn.ReLU())
  8. def forward(self, x):
  9. x = self.block(x)
  10. return x
  11. net = net5()
  12. print(net)
  13. # net5(
  14. # (block): Sequential(
  15. # (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  16. # (1): ReLU()
  17. # (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  18. # (3): ReLU()
  19. # )
  20. # )

下面给出了两个 nn.Sequential 初始化的例子,来自于官网教程。在第二个初始化中用到了 OrderedDict 来指定每个 module 的名字,而不是采用默认的命名方式 (按序号 0,1,2,3…) 。

  1. # https://pytorch.org/docs/stable/nn.html#sequential
  2. # Example of using Sequential
  3. model1 = nn.Sequential(
  4. nn.Conv2d(1,20,5),
  5. nn.ReLU(),
  6. nn.Conv2d(20,64,5),
  7. nn.ReLU()
  8. )
  9. print(model1)
  10. # Sequential(
  11. # (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  12. # (1): ReLU()
  13. # (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  14. # (3): ReLU()
  15. # )
  16. # Example of using Sequential with OrderedDict
  17. import collections
  18. model2 = nn.Sequential(collections.OrderedDict([
  19. ('conv1', nn.Conv2d(1,20,5)),
  20. ('relu1', nn.ReLU()),
  21. ('conv2', nn.Conv2d(20,64,5)),
  22. ('relu2', nn.ReLU())
  23. ]))
  24. print(model2)
  25. # Sequential(
  26. # (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  27. # (relu1): ReLU()
  28. # (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  29. # (relu2): ReLU()
  30. # )

有同学可能发现了,这个 model1 和 从类 net5 实例化来的 net 有什么区别吗?是没有的。这两个网络是相同的,因为 nn.Sequential 就是一个 nn.Module 的子类,也就是 nn.Module 所有的方法 (method) 它都有。并且直接使用 **nn.Sequential** 不用写 forward 函数,因为它内部已经写好了。
这时候有同学该说了,既然 nn.Sequential 这么好,以后都直接用它了。如果确定 nn.Sequential 里面的顺序是想要的,而且不需要再添加一些其他处理的函数,那么完全可以直接用 nn.Sequential。这么做的代价就是失去了部分灵活性,毕竟不能自己去定制 forward 函数里面的内容了。
一般情况下 nn.Sequential用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。

nn.ModuleListnn.Sequential:到底该用哪个

前边已经简单介绍了这两个类,现在来讨论一下在两个不同的场景中,选择哪一个比较合适
场景一,有的时候网络中有很多相似或者重复的层,一般会考虑用 for 循环来创建它们,而不是一行一行地写,比如:

  1. layers = [nn.Linear(10, 10) for i in range(5)]

这个时候,很自然而然地,会想到使用 ModuleList,像这样:

  1. class net6(nn.Module):
  2. def __init__(self):
  3. super(net6, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
  5. def forward(self, x):
  6. for layer in self.linears:
  7. x = layer(x)
  8. return x
  9. net = net6()
  10. print(net)
  11. # net6(
  12. # (linears): ModuleList(
  13. # (0): Linear(in_features=10, out_features=10, bias=True)
  14. # (1): Linear(in_features=10, out_features=10, bias=True)
  15. # (2): Linear(in_features=10, out_features=10, bias=True)
  16. # )
  17. # )

这个是比较一般的方法,但如果不想这么麻烦,也可以用 Sequential 来实现,如 net7 所示!注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素
但是,请注意这个 list 里面的模块必须是按照想要的顺序来进行排列的。在 场景一 中,个人觉得使用 net7 这种方法比较方便和整洁。

  1. class net7(nn.Module):
  2. def __init__(self):
  3. super(net7, self).__init__()
  4. self.linear_list = [nn.Linear(10, 10) for i in range(3)]
  5. self.linears = nn.Sequential(*self.linears_list)
  6. def forward(self, x):
  7. self.x = self.linears(x)
  8. return x
  9. net = net7()
  10. print(net)
  11. # net7(
  12. # (linears): Sequential(
  13. # (0): Linear(in_features=10, out_features=10, bias=True)
  14. # (1): Linear(in_features=10, out_features=10, bias=True)
  15. # (2): Linear(in_features=10, out_features=10, bias=True)
  16. # )
  17. # )

下面考虑 场景二,当需要之前层的信息的时候,比如 ResNets 中的 shortcut 结构,或者是像 FCN 中用到的 skip architecture 之类的,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList 比较方便,一个非常简单的例子如下:

  1. class net8(nn.Module):
  2. def __init__(self):
  3. super(net8, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 30), nn.Linear(30, 50)])
  5. self.trace = []
  6. def forward(self, x):
  7. for layer in self.linears:
  8. x = layer(x)
  9. self.trace.append(x)
  10. return x
  11. net = net8()
  12. input = torch.randn(32, 10) # input batch size: 32
  13. output = net(input)
  14. for each in net.trace:
  15. print(each.shape)
  16. # torch.Size([32, 20])
  17. # torch.Size([32, 30])
  18. # torch.Size([32, 50])

使用了一个 trace 的列表来储存网络每层的输出结果,这样如果以后的层要用的话,就可以很方便地调用了。

总结

通过一些实例学习了 ModuleList 和 Sequential 这两种 nn containers,ModuleList 就是一个储存各种模块的 list,这些模块之间没有联系没有实现 forward 功能,但相比于普通的 Python list,ModuleList 可以把添加到其中的模块和参数自动注册到网络上
Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配内部 forward 功能已经实现,可以使代码更加整洁。在不同场景中,如果二者都适用,那就看个人偏好了。非常推荐大家看一下 PyTorch 官方的 TorchVision 下面模型实现的代码,能学到很多构建网络的技巧。