image.png
    Pytorch Tensor的通道排序:[batch,channel,height,width]

    1. import torch.nn as nn
    2. import torch.nn.functional as F

    搭建网络的流程:
    1、定义一个类,继承于父类nn.Module
    2、类中定义两个方法,第一个是初始化函数,在初始化函数中会实现搭建网络过程中需要的网络层结构;第二个是forward函数,定义正向传播过程

    实例化类之后,讲参数传入实例中,就会进行正向传播

    使用super函数,定义类的过程中继承了nn.Module这个类,而super函数解决在多层继承中调用父类方法中可能出现的一系列问题,

    1. class LeNet(nn.Module):
    2. def __init__(self):
    3. super(LeNet,self).__init__()
    4. self.conv1 = nn.Conv2d(3,16,5)
    5. self.pool1 = nn.MaxPool2d(2,2)
    6. self.conv2 = nn.Conv2d(16,32,5)
    7. self.pool2 = nn.MaxPool2d(2,2)
    8. self.fc1 = nn.Linear(32*5*5,120)
    9. self.fc2 = nn.Linear(120,84)
    10. self.fc3 = nn.Linear(84,10)
    11. def forward(self,x):
    12. x = F.relu(self.conv1(x))
    13. x = self.pool1(x)
    14. x = F.relu(self.conv2(x))
    15. x = self.pool2(x)
    16. x = x.view(-1,32*5*5)
    17. x = F.relu(self.fc1(x))
    18. x = F.relu(self.fc2(x))
    19. x = self.fc3(x)
    20. return x