完整的模型训练

  • train.py
  # Train the Lcy network on CIFAR-10 and log the losses to TensorBoard.
  #
  # Per epoch: one pass over the training set (SGD update per batch), then one
  # pass over the test set with gradients disabled to accumulate the test loss.
  import torch
  import torchvision
  from torch import nn
  from torch.utils.data import DataLoader
  from torch.utils.tensorboard import SummaryWriter

  # Import the model class explicitly; torch and nn are imported above rather
  # than being picked up as a side effect of a star import from model.py.
  from model import Lcy

  # CIFAR-10 train/test splits, downloaded into "download_data" when missing
  # and converted to tensors.
  train_data = torchvision.datasets.CIFAR10(
      root="download_data",
      train=True,
      transform=torchvision.transforms.ToTensor(),
      download=True,
  )
  test_data = torchvision.datasets.CIFAR10(
      root="download_data",
      train=False,
      transform=torchvision.transforms.ToTensor(),
      download=True,
  )

  data_size = len(train_data)
  test_data_size = len(test_data)
  print(data_size)
  print(test_data_size)
  # Same number again, this time through str.format.
  print("数据集长度:{}".format(data_size))

  data_loader = DataLoader(train_data, batch_size=64)
  test_loader = DataLoader(test_data, batch_size=64)

  # Network under training.
  net = Lcy()

  # Loss function.
  loss_function = nn.CrossEntropyLoss()

  # Optimizer. 1e-2 is 1 x 10^(-2) = 0.01; both spellings are the same number.
  learning_rate = 1e-2
  optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

  # Number of optimizer steps taken so far (x-axis of "train_loss").
  total_train_step = 0
  # Number of completed evaluation passes (x-axis of "test_loss").
  total_test_step = 0
  # Number of epochs to train for.
  epoch = 10

  writer = SummaryWriter("logs")
  try:
      for i in range(epoch):
          print("-------第{}轮训练-------".format(i + 1))

          # Training pass.
          net.train()
          for img, targets in data_loader:
              output = net(img)
              loss = loss_function(output, targets)

              optimizer.zero_grad()
              loss.backward()
              optimizer.step()

              # Count training steps in the training counter; the test counter
              # is reserved for evaluation passes below.
              total_train_step += 1
              if total_train_step % 100 == 0:
                  # .item() yields a plain float, so the log line shows the
                  # number instead of a tensor repr.
                  train_loss = loss.item()
                  print("训练次数{},loss:{}".format(total_train_step, train_loss))
                  writer.add_scalar("train_loss", train_loss, total_train_step)

          # Evaluation pass: sum the per-batch losses over the whole test set.
          net.eval()
          total_test_loss = 0
          with torch.no_grad():
              for img, targets in test_loader:
                  output = net(img)
                  loss = loss_function(output, targets)
                  total_test_loss += loss.item()
          print("整体测试上的损失:{}".format(total_test_loss))
          writer.add_scalar("test_loss", total_test_loss, total_test_step)
          total_test_step += 1
  finally:
      # Close the event file even if training is interrupted or raises.
      writer.close()
  • model.py
  import torch
  from torch import nn


  class Lcy(nn.Module):
      """Convolutional classifier producing 10 output scores per sample.

      Three Conv2d(kernel 5, stride 1, padding 2) + MaxPool2d(2) stages are
      followed by Flatten and two Linear layers. The first Linear layer expects
      64 * 4 * 4 features, which is what a 3x32x32 input (the shape used by the
      smoke check at the bottom of this listing) is reduced to.
      """

      def __init__(self):
          super().__init__()

          # (in_channels, out_channels) of each convolution stage, in order.
          stages = []
          for in_ch, out_ch in ((3, 32), (32, 32), (32, 64)):
              stages.append(
                  nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=1, padding=2)
              )
              stages.append(nn.MaxPool2d(kernel_size=2))

          head = [
              nn.Flatten(),
              nn.Linear(64 * 4 * 4, 64),
              nn.Linear(64, 10),
          ]

          # Kept as one flat Sequential so the layer indices stay the same.
          self.model = nn.Sequential(*stages, *head)

      def forward(self, input):
          """Run ``input`` through the sequential stack and return the result."""
          return self.model(input)


  if __name__ == '__main__':
      # Smoke check: push a batch of ones through the network and print the
      # shape of what comes out.
      net = Lcy()
      dummy_batch = torch.ones((64, 3, 32, 32))
      print(net(dummy_batch).shape)