图像识别实现流程图
加载图像数据集
# Libraries needed for loading the data
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_batch_size = 4
test_batch_size = 4
num_workers = 0  # number of DataLoader worker subprocesses (0: load in the main process)

# Preprocessing pipeline (this is not data augmentation): convert each image
# to a tensor, then normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# CIFAR-10 training split first, then the test split; both are downloaded
# into ./dataset when they are not already present.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        './dataset',
        train=is_train,
        transform=transform,
        download=True,
    )
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=num_workers,
)
完整代码
"""Train and evaluate a small CNN on CIFAR-10 with PyTorch.

Run as a script to download the data, preview one batch, train for
``num_epochs`` epochs, plot the per-epoch training accuracy and print the
overall and per-class test accuracy.

torchvision and matplotlib are imported inside the functions that need them
so that ``CNNNet`` and ``evaluate`` can be imported (for inference or unit
tests) on machines that only have torch and numpy installed.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

train_batch_size = 4
test_batch_size = 4
num_workers = 0  # number of DataLoader worker subprocesses
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
lr = 0.001
momentum = 0.9
num_epochs = 10
log_interval = 2000  # print the running loss every this many batches

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def load_data():
    """Build the CIFAR-10 loaders.

    Returns:
        (train_loader, test_loader). Only the training loader shuffles.
    """
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    # Preprocessing only: tensor conversion + per-channel normalisation
    # with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        './dataset', train=True, transform=transform, download=True)
    test_dataset = torchvision.datasets.CIFAR10(
        './dataset', train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size,
                             shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


def imshow(img):
    """Display a normalised (C, H, W) image tensor with matplotlib."""
    import matplotlib.pyplot as plt

    img = img / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    npimg = img.numpy()
    # matplotlib wants channels last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def show_examples(loader):
    """Show the first batch of ``loader`` as an image grid and print its shapes."""
    import torchvision

    # examples_target holds the integer class labels (0-9) of the batch.
    examples_data, examples_target = next(iter(loader))
    imshow(torchvision.utils.make_grid(examples_data))
    print('--------------测试examples------------')
    print('examples_target.shape:{}'.format(examples_target.shape))
    print('examples_target[0]:{}'.format(examples_target[0]))
    print('examples_data.shape:{}'.format(examples_data.shape))


class CNNNet(nn.Module):
    """Two conv/max-pool stages followed by two fully connected layers.

    The flatten step assumes the second pooling stage yields 36 x 6 x 6
    features, which is what a 3 x 32 x 32 input produces
    (32 -> 28 -> 14 -> 12 -> 6).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16,
                               kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36,
                               kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(1296, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 36 * 6 * 6)
        x = F.relu(self.fc1(x))
        # No activation on the output layer: CrossEntropyLoss applies
        # log-softmax itself, and a ReLU here would clamp every logit to
        # be non-negative.
        return self.fc2(x)


def train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one training pass over ``loader``.

    Prints the mean loss of every ``log_interval`` batches.

    Returns:
        Training accuracy over the samples seen, as a plain Python float.
    """
    model.train()
    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    for i, (img, label) in enumerate(loader):
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()  # clear gradients left from the previous step
        out = model(img)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # .item() keeps the counter a Python int, so the resulting accuracy
        # can be plotted even when the model runs on CUDA.
        num_correct += (out.argmax(dim=1) == label).sum().item()
        num_seen += label.size(0)
        if i % log_interval == log_interval - 1:
            print('[%d,%5d] loss : %.3f'
                  % (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0
    return num_correct / num_seen if num_seen else 0.0


def evaluate(model, loader, criterion, num_classes=len(classes)):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Args:
        model: module mapping a batch of inputs to (batch, num_classes) logits.
        loader: iterable of (inputs, integer labels) batches of any size.
        criterion: loss returning the mean loss of a batch.
        num_classes: number of distinct labels.

    Returns:
        (mean loss per sample, overall accuracy, per-class correct counts,
        per-class sample counts).
    """
    loss_sum = 0.0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    model.eval()
    with torch.no_grad():
        for img, label in loader:
            img, label = img.to(device), label.to(device)
            out = model(img)
            # Weight the batch mean by the batch size so that a smaller
            # final batch does not skew the average.
            loss_sum += criterion(out, label).item() * label.size(0)
            hits = out.argmax(dim=1) == label
            # Iterate over the real batch contents instead of assuming a
            # fixed batch size of 4.
            for target, hit in zip(label.tolist(), hits.tolist()):
                class_total[target] += 1
                class_correct[target] += int(hit)
    total = sum(class_total)
    if total == 0:
        return 0.0, 0.0, class_correct, class_total
    return loss_sum / total, sum(class_correct) / total, class_correct, class_total


def plot_accuracy(acces):
    """Plot the per-epoch training accuracy."""
    import matplotlib.pyplot as plt

    plt.title('Train Acc')
    plt.plot(np.arange(len(acces)), acces)
    plt.legend(['Train Acc'], loc='upper right')
    plt.show()


def main():
    """Load data, train the network, then report test metrics."""
    train_loader, test_loader = load_data()
    show_examples(train_loader)

    model = CNNNet().to(device)
    print('--------------查看网络结构-----------')
    print(model)

    print('-----训练优化器-------')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    print("----------正式训练模型---------")
    acces = [
        train_one_epoch(model, train_loader, criterion, optimizer, epoch)
        for epoch in range(num_epochs)
    ]
    plot_accuracy(acces)

    eval_loss, eval_acc, class_correct, class_total = evaluate(
        model, test_loader, criterion)
    print("total:{}".format(sum(class_total)))
    print("len(test_loader):{}".format(len(test_loader)))
    print('eval_loss:{:.4f},eval_acc:{:.4f}'.format(eval_loss, eval_acc))
    for name, hit, seen in zip(classes, class_correct, class_total):
        # Guard against a class that never occurs in the evaluated data.
        pct = 100 * hit / seen if seen else float('nan')
        print("accuracy of {}:{}%".format(name, pct))
    print("----------------")


if __name__ == "__main__":
    main()
输出结果

部分输出结果解释
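日志格式为 [epoch, i](例如 [10, 8000]):第一个数 10 表示当前是第 10 个 epoch(训练轮数),第二个数 8000 表示在当前 epoch 中已经处理到第 8000 个 batch;每处理 2000 个 batch 打印一次,打印的 loss 是最近 2000 个 batch 的平均损失。CIFAR-10 训练集共 50000 张图片,batch_size 为 4,因此 train_loader 每个 epoch 共有 50000 / 4 = 12500 个 batch。所以每个 epoch 最后一次打印出现在第 12000 个 batch,不会出现 14000;但训练实际上仍会跑完全部 12500 个 batch,只是最后 500 个 batch 的损失不会被打印出来。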
