Pytorch Tensor的通道排序:[batch,channel,height,width]
import torch.nn as nn
import torch.nn.functional as F
搭建网络的流程:
1、定义一个类,继承于父类nn.Module
2、类中定义两个方法,第一个是初始化函数,在初始化函数中会实现搭建网络过程中需要的网络层结构;第二个是forward函数,定义正向传播过程
实例化类之后,讲参数传入实例中,就会进行正向传播
使用super函数,定义类的过程中继承了nn.Module这个类,而super函数解决在多层继承中调用父类方法中可能出现的一系列问题,
class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(3,16,5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(16,32,5)
self.pool2 = nn.MaxPool2d(2,2)
self.fc1 = nn.Linear(32*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
def forward(self,x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = x.view(-1,32*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x