13. Sequential
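The figure below shows the network architecture for CIFAR-10.
A 32×32 RGB (three-channel) image passes through several convolution, pooling, flatten, and linear layers to produce the output.
This section only demonstrates Sequential and how the model's layers are assembled; no nonlinear (activation) layers have been added for classification.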
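Convolution
Take conv1 as an example: the input is 3 channels at 32×32 and the output is 32 channels at 32×32, so setting in_channels = 3 and out_channels = 32 is enough. PyTorch automatically creates 32 convolution kernels, each with 3 channels, to perform the convolution. The padding also has to be worked out here: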
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
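With Hout = Wout = 32, dilation = 1 (the usual choice), kernel_size = 5, and stride = 1 (the default), the output-size formula on that page becomes 32 = (32 + 2 × padding - 1 × (5 - 1) - 1) / 1 + 1, which gives padding = 2.
[Viewer comment]: with same padding, p = (f - 1) / 2, where f is the filter size, i.e. the kernel size. Here f = 5, so p = (5 - 1) / 2 = 2.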
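The padding calculation above can be checked with a few lines of plain Python. This is a minimal sketch, not part of the original notes: the helper names conv2d_output_size and same_padding are made up for illustration, and the formula is the per-dimension output-size formula from the torch.nn.Conv2d documentation linked above. The same-padding shortcut assumes stride 1 and an odd kernel size.

```python
def conv2d_output_size(size_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output height (or width) of a 2D convolution along one dimension.

    Hypothetical helper; implements the formula from the Conv2d docs:
    floor((size_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    """
    numerator = size_in + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


def same_padding(kernel_size, dilation=1):
    """Padding that keeps the size unchanged (assumes stride 1, odd kernel_size)."""
    return dilation * (kernel_size - 1) // 2


padding = same_padding(kernel_size=5)
size_out = conv2d_output_size(32, kernel_size=5, padding=padding)
print(padding, size_out)  # 2 32
```

Without padding the same call returns 28, which is why conv1 needs padding = 2 to keep the 32×32 size.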
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear


class DEMO(nn.Module):
    def __init__(self):
        super(DEMO, self).__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
        self.maxpool = MaxPool2d(kernel_size=2)
        self.conv2 = Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.conv3 = Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.flatten = torch.nn.Flatten()
        self.linear1 = Linear(in_features=1024, out_features=64)
        self.linear2 = Linear(in_features=64, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.conv3(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


demo = DEMO()
print(demo)

# We can use torch.ones() to create an all-ones dummy input with a shape
# of our choosing, in order to test the network structure.
# Here it is set to the input shape of the CIFAR-10 network.
input = torch.ones((64, 3, 32, 32))
output = demo(input)
print(output.shape)  # torch.Size([64, 10])
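The value in_features=1024 for linear1 follows from the shapes flowing through the network. As a rough check, the sketch below traces the (channels, height, width) shape through the three convolution/pooling stages in plain Python. The helper names conv_out, pool_out, and trace_shapes are hypothetical, and the code assumes the settings used above: kernel_size=5 with padding=2 and stride 1 for each convolution, and MaxPool2d(kernel_size=2), whose stride defaults to the kernel size.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial size after a convolution (dilation assumed to be 1)."""
    return (size + 2 * padding - (kernel_size - 1) - 1) // stride + 1


def pool_out(size, kernel_size):
    """Spatial size after max pooling with stride == kernel_size and no padding."""
    return (size - kernel_size) // kernel_size + 1


def trace_shapes(size=32, in_channels=3):
    """Return the (channels, height, width) shape after each layer."""
    shapes = [(in_channels, size, size)]
    for out_channels in (32, 32, 64):  # conv1, conv2, conv3
        size = conv_out(size, kernel_size=5, padding=2)
        shapes.append((out_channels, size, size))
        size = pool_out(size, kernel_size=2)
        shapes.append((out_channels, size, size))
    return shapes


shapes = trace_shapes()
channels, height, width = shapes[-1]
flat_features = channels * height * width
print(shapes[-1], flat_features)  # (64, 4, 4) 1024
```

Each pooling step halves the spatial size (32, 16, 8, 4), so Flatten produces 64 × 4 × 4 = 1024 values per image.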
Sequential
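Sequential is a sequential container. Modules are added to it in the order in which they are passed to the constructor, and the input is run through them in that same order, which makes the structure easier to read.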
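To make the idea concrete, here is a toy sketch of a sequential container in plain Python. MiniSequential, double, and add_one are hypothetical names invented for this illustration; this is not how torch.nn.Sequential is implemented, only the same "apply each step in order" idea.

```python
class MiniSequential:
    """Toy stand-in for a sequential container (illustration only)."""

    def __init__(self, *layers):
        # Keep the layers in the order they were passed in.
        self.layers = list(layers)

    def __call__(self, x):
        # Feed the output of each layer into the next one.
        for layer in self.layers:
            x = layer(x)
        return x


def double(x):
    return x * 2


def add_one(x):
    return x + 1


pipeline = MiniSequential(double, add_one)
reversed_pipeline = MiniSequential(add_one, double)
print(pipeline(3), reversed_pipeline(3))  # 7 8
```

The order of the arguments decides the result, which is why the layers must be listed in the same order as in the network diagram. In PyTorch, torch.nn.Sequential plays this role for nn.Module layers.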
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear
from torch.utils.tensorboard import SummaryWriter


class DEMO(nn.Module):
    def __init__(self):
        super(DEMO, self).__init__()
        self.model = torch.nn.Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            torch.nn.Flatten(),
            Linear(in_features=1024, out_features=64),
            Linear(in_features=64, out_features=10),
        )

    def forward(self, x):
        x = self.model(x)
        return x


demo = DEMO()
print(demo)

# We can use torch.ones() to create an all-ones dummy input with a shape
# of our choosing, in order to test the network structure.
# Here it is set to the input shape of the CIFAR-10 network.
input = torch.ones((64, 3, 32, 32))
output = demo(input)
print(output.shape)  # torch.Size([64, 10])

# Write the model graph to the log directory for TensorBoard.
writer = SummaryWriter('./logs_seq')
writer.add_graph(demo, input)
writer.close()
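This lets us inspect the structure in TensorBoard, for example by running tensorboard --logdir=logs_seq and opening the graph view.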
