15.nn_optim

https://pytorch.org/docs/stable/optim.html

The optimizer updates the network's parameters and is called after `backward()`. Note that the gradients must be zeroed before each backward pass.

Example:

```python
for input, target in dataset:
    optimizer.zero_grad()            # zero the gradients
    output = model(input)
    loss = loss_fn(output, target)   # compute the loss
    loss.backward()                  # backpropagate: compute the grad of loss w.r.t. each parameter
    optimizer.step()                 # the optimizer updates the parameters
```
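The loop above leaves out how the optimizer is constructed. A minimal self-contained sketch, using a hypothetical one-layer model and `MSELoss` purely for illustration:

```python
import torch
from torch import nn

# Hypothetical tiny model, just to show the optimizer's life cycle.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

x = torch.randn(8, 4)
target = torch.randn(8, 2)

optimizer.zero_grad()              # clear gradients left over from the last step
loss = loss_fn(model(x), target)   # forward pass + loss
loss.backward()                    # fill in .grad for every parameter
optimizer.step()                   # update the parameters in place
```

The optimizer is told which tensors to update via `model.parameters()`; `zero_grad()`, `backward()`, and `step()` then repeat every iteration in that order.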

We set three breakpoints for debugging: at gradient zeroing, backpropagation, and the parameter update.

In the debugger: demo/Protected Attributes/modules/model/Protected Attributes/modules/2 (inspect the parameters of the second convolutional layer)

Debug, gradient zeroing: since this is the first loop iteration, grad starts out at 0.

Debug, backward: the gradients are computed.

Debug, parameter update: the parameters in data start to change.
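The three debug observations can also be reproduced in code rather than in the debugger. A small sketch with a hypothetical `nn.Linear` layer standing in for the conv layer:

```python
import torch
from torch import nn

layer = nn.Linear(3, 3)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

# Breakpoint 1: before the first backward, grad has never been populated.
assert layer.weight.grad is None

# Breakpoint 2: after backward, grads are filled in.
out = layer(torch.ones(2, 3)).sum()
out.backward()
assert layer.weight.grad is not None

# Breakpoint 3: after step, the data of the parameters has changed.
before = layer.weight.data.clone()
opt.step()
assert not torch.equal(before, layer.weight.data)

# zero_grad resets the grads (to zeros, or to None in newer PyTorch versions).
opt.zero_grad()
```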

After one full pass of the for loop the parameters change only slightly, so training must run for multiple epochs.

```python
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Flatten
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10('./dataset', train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

class DEMO(nn.Module):
    def __init__(self):
        super(DEMO, self).__init__()
        self.model = nn.Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(in_features=1024, out_features=64),
            Linear(in_features=64, out_features=10),
        )

    def forward(self, x):
        x = self.model(x)
        return x

demo = DEMO()
loss_cross = nn.CrossEntropyLoss()
optim = torch.optim.SGD(demo.parameters(), lr=0.01)

# range(2) means two epochs: epoch takes the values 0 and 1, not 0 1 2
for epoch in range(2):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        output = demo(imgs)
        # print(output)  output is a set of raw scores from the network (not probabilities)
        # print(targets) targets are the ground-truth labels
        loss = loss_cross(output, targets)
        optim.zero_grad()
        loss.backward()
        optim.step()
        running_loss = running_loss + loss.item()  # .item() avoids keeping the autograd graph alive
    print(running_loss)  # total loss over this epoch
    print("epoch =", epoch + 1)
print('ok')
```
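What `optim.step()` actually does for plain SGD (no momentum, as configured above) can be checked by hand on a single tensor: each parameter becomes `param - lr * grad`. A minimal sketch, with values chosen so the arithmetic is easy to follow:

```python
import torch

# One parameter tensor, tracked by a plain SGD optimizer.
w = torch.tensor([1.0, 2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()    # d loss / d w = 2*w = [2., 4.]
loss.backward()
opt.step()               # w <- w - lr * grad = [1 - 0.1*2, 2 - 0.1*4] = [0.8, 1.6]
```

Each loop iteration in the training code above performs exactly this update on every parameter of the network.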