transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transforms.ToTensor:原本输入的图像格式是 (H x W x C)在(0,255)之间,通过ToTensor把每张输入图像转换成(C x H x W)格式,在(0,1)之间。
transforms.Normalize:正则化
#定义正向传播过程
def forward(self, x):
x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)
x = self.pool1(x) # output(16, 14, 14)
x = F.relu(self.conv2(x)) # output(32, 10, 10)
x = self.pool2(x) # output(32, 5, 5)
x = x.view(-1, 32*5*5) # output(32*5*5),通过view把数据展平成一维向量(batch,节点个数)
x = F.relu(self.fc1(x)) # output(120)
x = F.relu(self.fc2(x)) # output(84)
x = self.fc3(x) # output(10)
return x
view:通过view把数据展平成一维向量(batch,节点个数)