# 优化器 (optimizer)
import torch
from torch import nn
import torchvision
from torch.nn import Conv2d, Sequential, MaxPool2d, Flatten, Linear, CrossEntropyLoss
from torch.utils.data import DataLoader
# CIFAR-10 images converted to float tensors by ToTensor(); fetched into
# ./download_data the first time the script runs (download=True).
# NOTE(review): train=False selects the CIFAR-10 *test* split, yet this
# loader is what the SGD training loop iterates over — confirm intended.
dataset = torchvision.datasets.CIFAR10(
    root="./download_data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# One image per batch, served in a fixed order (DataLoader's default is
# shuffle=False).
dataloader = DataLoader(dataset, batch_size=1)
class Liucy(nn.Module):
    """Small convolutional classifier producing 10 class scores.

    Three stages of 5x5 convolution (padding=2 keeps height/width) followed
    by 2x2 max-pooling, then two fully connected layers. For an input of
    shape (N, 3, 32, 32) the spatial size shrinks 32 -> 16 -> 8 -> 4, so
    the flattened feature vector has 64 * 4 * 4 = 1024 entries, which is
    what the first Linear layer requires. The output has shape (N, 10);
    no softmax is applied, so these are raw scores (logits).
    """

    def __init__(self):
        super().__init__()
        stages = [
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            # NOTE(review): there is no activation between the two Linear
            # layers (or after the convs, besides max-pooling), so these two
            # layers compose to a single affine map — confirm intended.
            Linear(1024, 64),
            Linear(64, 10),
        ]
        # "modle" (sic) is kept on purpose: it prefixes every state_dict key
        # (e.g. "modle.0.weight"), so renaming it would break checkpoints.
        self.modle = Sequential(*stages)

    def forward(self, x):
        """Run the layer stack on image batch ``x`` and return the logits."""
        return self.modle(x)
liu = Liucy()
cross = CrossEntropyLoss()
# Set up the optimizer: plain SGD over every parameter of the model.
optim = torch.optim.SGD(liu.parameters(), lr=0.01)  # lr is the learning rate
for epoch in range(10):
    # Sum of the per-batch losses over this epoch, kept as a plain Python float.
    lost_value = 0.0
    for data in dataloader:
        img, label = data
        output = liu(img)  # one score per class (10 classes)
        result = cross(output, label)
        optim.zero_grad()  # clear gradients left over from the previous step
        result.backward()
        optim.step()
        # .item() extracts the scalar without its autograd history. Adding the
        # loss tensor itself would chain every batch's graph into lost_value,
        # so memory would grow for the whole epoch.
        lost_value += result.item()
    print(f"epoch {epoch}: total loss {lost_value:.4f}")