图像识别实现流程图

加载图像数据集

  # Required libraries
  import torchvision
  import torchvision.transforms as transforms
  from torch.utils.data import DataLoader

  # Loader settings
  train_batch_size = 4
  test_batch_size = 4
  num_workers = 0  # forwarded to DataLoader(num_workers=...)

  # Preprocessing applied to every image: convert it to a tensor, then
  # normalise each of the three channels with mean 0.5 and std 0.5.
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
  ])

  # CIFAR-10 train/test splits stored under ./dataset (download=True is passed
  # for both splits).
  dataset_root = './dataset'
  train_dataset = torchvision.datasets.CIFAR10(
      dataset_root, train=True, transform=transform, download=True)
  test_dataset = torchvision.datasets.CIFAR10(
      dataset_root, train=False, transform=transform, download=True)

  # Only the training loader shuffles its samples.
  train_loader = DataLoader(
      train_dataset,
      batch_size=train_batch_size,
      shuffle=True,
      num_workers=num_workers,
  )
  test_loader = DataLoader(
      test_dataset,
      batch_size=test_batch_size,
      shuffle=False,
      num_workers=num_workers,
  )

完整代码

  import torch
  import torchvision
  import torchvision.transforms as transforms
  from torch.utils.data import DataLoader

  # Loader settings
  train_batch_size = 4
  test_batch_size = 4
  num_workers = 0  # forwarded to DataLoader(num_workers=...)

  # Class names, indexed by the integer label (0-9).
  classes = (
      'plane', 'car', 'bird', 'cat', 'deer',
      'dog', 'frog', 'horse', 'ship', 'truck',
  )

  # Optimiser hyperparameters
  lr = 0.001
  momentum = 0.9

  # Preprocessing applied to every image: convert it to a tensor, then
  # normalise each of the three channels with mean 0.5 and std 0.5.
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
  ])

  # CIFAR-10 train/test splits stored under ./dataset (download=True is passed
  # for both splits).
  dataset_root = './dataset'
  train_dataset = torchvision.datasets.CIFAR10(
      dataset_root, train=True, transform=transform, download=True)
  test_dataset = torchvision.datasets.CIFAR10(
      dataset_root, train=False, transform=transform, download=True)

  # Only the training loader shuffles its samples.
  train_loader = DataLoader(
      train_dataset,
      batch_size=train_batch_size,
      shuffle=True,
      num_workers=num_workers,
  )
  test_loader = DataLoader(
      test_dataset,
      batch_size=test_batch_size,
      shuffle=False,
      num_workers=num_workers,
  )
  # Data visualisation
  import matplotlib.pyplot as plt
  import numpy as np

  plt.figure()


  def imshow(img):
      """Display a normalised image tensor with matplotlib.

      The tensor is first mapped back with ``img / 2 + 0.5`` (the inverse of
      a mean-0.5 / std-0.5 normalisation), converted to a NumPy array, and
      its axes are permuted with ``(1, 2, 0)`` before plotting.

      Args:
          img: image tensor; assumed to be laid out as
              (channels, height, width) -- confirm against the caller.
      """
      restored = img / 2 + 0.5
      pixels = restored.numpy()
      reordered = np.transpose(pixels, (1, 2, 0))
      plt.imshow(reordered)
      plt.show()
  # Pull one batch from the training loader for inspection.
  examples = enumerate(train_loader)
  # examples_target holds the labels; values 0-9 denote the different classes.
  idx, (examples_data, examples_target) = next(examples)

  # Show the whole batch as a single image grid.
  grid = torchvision.utils.make_grid(examples_data)
  imshow(grid)

  # Print a few details of the sampled batch.
  print('--------------测试examples------------')
  for template, value in (
      ('examples_target.shape:{}', examples_target.shape),
      ('examples_target[0]:{}', examples_target[0]),
      ('examples_data.shape:{}', examples_data.shape),
  ):
      print(template.format(value))
  # Build the network
  import torch.nn as nn
  import torch.nn.functional as F

  # Use the first CUDA device when one is available, otherwise the CPU.
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


  class CNNNet(nn.Module):
      """Two conv/max-pool stages followed by two fully connected layers.

      ``forward`` returns one raw score (logit) per class for every input
      image. No activation is applied after the last layer, because
      ``nn.CrossEntropyLoss`` expects unnormalised logits; clamping them
      with ReLU would discard every negative score.

      ``fc1`` takes 1296 = 36 * 6 * 6 input features, which is the flattened
      size produced by 3 x 32 x 32 inputs:
      32 -> conv 5x5 -> 28 -> pool -> 14 -> conv 3x3 -> 12 -> pool -> 6.

      The original author also left a disabled alternative head in comments
      (``nn.AdaptiveAvgPool2d(1)`` followed by ``nn.Linear(36, 10)``).
      """

      def __init__(self):
          super().__init__()
          self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
          self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
          self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
          self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
          self.fc1 = nn.Linear(1296, 128)
          self.fc2 = nn.Linear(128, 10)

      def forward(self, x):
          """Return class logits of shape ``(batch, 10)`` for image batch ``x``."""
          x = self.pool1(F.relu(self.conv1(x)))
          x = self.pool2(F.relu(self.conv2(x)))
          # Flatten per sample. Pinning the batch dimension makes a wrongly
          # sized input fail in fc1 instead of silently changing the number
          # of rows, which ``view(-1, 1296)`` could do.
          x = x.view(x.size(0), -1)
          x = F.relu(self.fc1(x))
          return self.fc2(x)


  model = CNNNet()
  model = model.to(device)
  print('--------------查看网络结构-----------')
  print(model)
  # -- Train the model --
  print('-----训练优化器-------')
  import torch.optim as optim

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  print("----------正式训练模型---------")
  num_epochs = 10
  log_interval = 2000  # print the running loss every this many batches

  # Per-epoch history. The two eval_* lists are only declared here.
  losses = []
  acces = []
  eval_losses = []
  eval_acces = []

  for epoch in range(num_epochs):
      train_loss = 0.0  # loss summed since the last progress print
      epoch_loss = 0.0  # loss summed over the whole epoch
      num_correct = 0
      num_seen = 0
      model.train()
      for i, (img, label) in enumerate(train_loader):
          img, label = img.to(device), label.to(device)

          # Reset the gradients accumulated by the previous step.
          optimizer.zero_grad()

          # Forward pass, backward pass, parameter update.
          out = model(img)
          loss = criterion(out, label)
          loss.backward()
          optimizer.step()

          batch_loss = loss.item()
          train_loss += batch_loss
          epoch_loss += batch_loss

          # Count correct predictions as plain Python ints so the history
          # lists hold numbers (plottable) rather than tensors.
          _, pred = out.max(1)
          num_correct += (pred == label).sum().item()
          num_seen += label.size(0)

          if i % log_interval == log_interval - 1:
              print('[%d,%5d] loss : %.3f' % (epoch + 1, i + 1, train_loss / log_interval))
              train_loss = 0.0

      # Divide by the number of samples actually seen, so a partial last
      # batch does not skew the accuracy.
      train_acc = num_correct / num_seen if num_seen else 0.0
      acces.append(train_acc)
      losses.append(epoch_loss / len(train_loader) if len(train_loader) else 0.0)

  # Plot the training accuracy per epoch.
  plt.title('Train Acc')
  plt.plot(np.arange(len(acces)), acces)
  plt.legend(['Train Acc'], loc='upper right')
  plt.show()
  # Evaluate the model on the test set
  eval_loss = 0.0
  eval_acc = 0.0
  num_correct = 0  # reset so training counts do not leak into evaluation
  total = 0
  class_correct = [0.0] * len(classes)
  class_total = [0.0] * len(classes)

  model.eval()
  with torch.no_grad():
      for img, label in test_loader:
          img, label = img.to(device), label.to(device)
          out = model(img)

          # Sum of per-batch losses; averaged over batches below.
          eval_loss += criterion(out, label).item()

          _, pred = out.max(1)
          hits = pred == label
          num_correct += hits.sum().item()
          total += label.size(0)

          # Per-class tallies. Iterating over the batch itself works for
          # any batch size, including a smaller final batch.
          for target, hit in zip(label.tolist(), hits.tolist()):
              class_correct[target] += hit
              class_total[target] += 1

  num_batches = len(test_loader)
  eval_acc = num_correct / total if total else 0.0
  eval_losses.append(eval_loss / num_batches if num_batches else 0.0)
  eval_acces.append(eval_acc)

  print("total:{}".format(total))
  print("len(test_loader):{}".format(num_batches))
  for name, correct_count, seen in zip(classes, class_correct, class_total):
      if seen:
          print("accuracy of {}:{}%".format(name, 100 * correct_count / seen))
      else:
          # No test sample of this class: avoid dividing by zero.
          print("accuracy of {}:n/a".format(name))
  print("----------------")
  print("Accuracy of the network on the {} test images:{:.2f}%".format(total, 100 * eval_acc))

输出结果

image.png

部分输出结果解释

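[epoch, i](如 [10, 8000])中,10 表示当前进行到第 10 个 epoch(训练轮数),8000 表示 train_loader 已经迭代到第 8000 个 batch(每 2000 个 batch 打印一次这一段的平均损失)。由于训练集共有 50000 张图像、batch_size 为 4,train_loader 一共有 12500 个 batch,因此每个 epoch 最多只会打印到第 12000 个 batch,不会打印到第 14000 个;但实际训练仍会进行到第 12500 个 batch。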