13. Sequential

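The figure below shows the network architecture used for CIFAR-10.

An RGB image with 3 channels and a size of 32×32 passes through several convolution and pooling layers, a flatten step, and two linear layers to produce the output.

This section only demonstrates Sequential and how each layer of the model is built; no nonlinear (activation) layers are added for classification.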

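Convolution

Take conv1 as an example: the input is 3 channels at 32×32 and the output is 32 channels at 32×32, so setting in_channels=3 and out_channels=32 is enough. PyTorch automatically creates 32 convolution kernels, each with 3 channels, to perform the convolution. Keeping the spatial size at 32×32 also requires computing the padding: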

https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d

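With Hout = Wout = 32, dilation = 1 (the usual value), kernel_size = 5, and stride = 1 (the default), the output-size formula in the documentation becomes 32 = (32 + 2×padding - 4 - 1)/1 + 1, which gives padding = 2.

[Viewer comment]: with "same" padding, p = (f - 1)/2, where f is the filter size, i.e. the kernel size. Here f = 5, so p = (5 - 1)/2 = 2.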
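The padding value can be checked against the output-size formula from the Conv2d documentation. Below is a minimal sketch in plain Python; conv2d_out_size is a hypothetical helper name used only for illustration and is not part of PyTorch.

```python
import math


def conv2d_out_size(size_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width following the formula in the torch.nn.Conv2d docs.

    Hypothetical helper for illustration; not a PyTorch function.
    """
    return math.floor(
        (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    )


# conv1: 32x32 input, 5x5 kernel, stride 1, dilation 1.
# Try a few padding values and see which one keeps the size at 32.
for p in range(4):
    print(p, conv2d_out_size(32, kernel_size=5, padding=p))

# Shortcut from the viewer comment (valid for stride 1 and an odd kernel size).
same_padding = (5 - 1) // 2
print("same padding:", same_padding)
```

Only padding=2 leaves the 32×32 size unchanged, which matches the shortcut p = (f - 1)/2.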

    import torch
    from torch import nn
    from torch.nn import Conv2d, MaxPool2d, Linear


    class DEMO(nn.Module):
        def __init__(self):
            super(DEMO, self).__init__()
            # Each conv uses kernel_size=5 with padding=2, so height/width are preserved
            self.conv1 = Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
            # MaxPool2d has no parameters, so one instance is reused after every conv
            self.maxpool = MaxPool2d(kernel_size=2)
            self.conv2 = Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
            self.conv3 = Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
            self.flatten = torch.nn.Flatten()
            # 64 channels * 4 * 4 = 1024 features after three rounds of pooling
            self.linear1 = Linear(in_features=1024, out_features=64)
            self.linear2 = Linear(in_features=64, out_features=10)

        def forward(self, x):
            x = self.conv1(x)
            x = self.maxpool(x)
            x = self.conv2(x)
            x = self.maxpool(x)
            x = self.conv3(x)
            x = self.maxpool(x)
            x = self.flatten(x)
            x = self.linear1(x)
            x = self.linear2(x)
            return x


    demo = DEMO()
    print(demo)

    # torch.ones() creates an all-ones tensor of any shape we choose,
    # which serves as a dummy input for testing the network structure.
    # Here it has the shape of a CIFAR-10 input batch: (batch, channels, height, width).
    input = torch.ones((64, 3, 32, 32))
    output = demo(input)
    print(output.shape)  # torch.Size([64, 10])
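The value in_features=1024 for the first linear layer follows from the shapes produced by the preceding layers. The sketch below traces those shapes in plain Python, without PyTorch. The helpers conv_out and pool_out are hypothetical names, and the pooling formula assumes MaxPool2d's default behavior (stride equal to kernel_size, no padding).

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    # Conv2d output size, assuming dilation = 1
    return (size + 2 * padding - (kernel_size - 1) - 1) // stride + 1


def pool_out(size, kernel_size):
    # MaxPool2d output size, assuming stride = kernel_size and no padding
    return (size - kernel_size) // kernel_size + 1


channels, size = 3, 32  # one CIFAR-10 image: 3 x 32 x 32
trace = []
for out_channels in (32, 32, 64):  # conv1, conv2, conv3
    size = conv_out(size, kernel_size=5, padding=2)
    channels = out_channels
    trace.append(("conv", channels, size))
    size = pool_out(size, kernel_size=2)
    trace.append(("pool", channels, size))

# Flatten merges channels, height and width into one dimension
flat_features = channels * size * size

for name, c, s in trace:
    print(f"{name}: {c} x {s} x {s}")
print("flattened features:", flat_features)
```

Each convolution keeps the spatial size and each pooling step halves it (32, 16, 8, 4), so the flattened vector has 64 × 4 × 4 = 1024 entries per image.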

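Sequential

Sequential is an ordered container. Modules are added to it in the order they are passed to its constructor, and the input flows through them in that same order, which makes the structure easier to read.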

https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html?highlight=sequential#torch.nn.Sequential
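To make the ordering idea concrete, here is a toy container in plain Python. TinySequential is a hypothetical class written only for illustration; it is not how torch.nn.Sequential is implemented, and it only mimics the "apply each layer in the order given" behavior.

```python
class TinySequential:
    """Toy stand-in for an ordered container (hypothetical, not PyTorch code)."""

    def __init__(self, *layers):
        # Layers are stored in exactly the order they were passed in
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            # The output of one layer becomes the input of the next
            x = layer(x)
        return x


def double(x):
    return x * 2


def add_one(x):
    return x + 1


print(TinySequential(double, add_one)(3))  # (3 * 2) + 1
print(TinySequential(add_one, double)(3))  # (3 + 1) * 2
```

Swapping the two layers changes the result, which is why the order of the arguments passed to a sequential container matters.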

    import torch
    from torch import nn
    from torch.nn import Conv2d, MaxPool2d, Linear
    from torch.utils.tensorboard import SummaryWriter


    class DEMO(nn.Module):
        def __init__(self):
            super(DEMO, self).__init__()
            # The same layers as before, listed in the order they are applied
            self.model = torch.nn.Sequential(
                Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                torch.nn.Flatten(),
                Linear(in_features=1024, out_features=64),
                Linear(in_features=64, out_features=10),
            )

        def forward(self, x):
            # A single call runs every layer in the container in sequence
            x = self.model(x)
            return x


    demo = DEMO()
    print(demo)

    # torch.ones() creates an all-ones tensor of any shape we choose,
    # which serves as a dummy input for testing the network structure.
    # Here it has the shape of a CIFAR-10 input batch: (batch, channels, height, width).
    input = torch.ones((64, 3, 32, 32))
    output = demo(input)
    print(output.shape)  # torch.Size([64, 10])

    # Write the model graph to ./logs_seq so it can be viewed in TensorBoard
    writer = SummaryWriter('./logs_seq')
    writer.add_graph(demo, input)
    writer.close()

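The structure can then be inspected in TensorBoard, for example by running tensorboard --logdir=logs_seq and opening the graph view.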