做深度学习时,最麻烦的一件事情就是等待,等待训练过程的结束往往要耗费大量时间,如果我们曾经训练过某个模型,下次使用时又需要重新训练一遍,为此不得不再等上好久,这部分等待时间是毫无意义的。因此训练过模型之后,及时地添加代码然后保存模型就可以节省很多时间,下一次要用到模型的时候,直接加载即可。
直接保存/加载模型本身
下面以利用内置数据集一览Pytorch实现图像识别为例,展示模型保存和模型加载的用法。
保存模型
# ...模型训练完成
torch.save(model, '/mnt/main/pkl_checkpoint_file/plain_CNNet.pkl') #保存速度很慢
加载模型并测试
# 导入模块和加载数据集部分略
model = torch.load('/mnt/main/pkl_checkpoint_file/plain_CNNet.pkl')
model = model.to(device)
print('--------------查看网络结构-----------')
print(model)
#测试模型
eval_loss = 0
eval_acc = 0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
total = 0
model.eval()
with torch.no_grad():
for img, label in test_loader:
img, label = img.to(device), label.to(device)
out = model(img)
#计算损失值
loss = criterion(out,label)
eval_loss += loss.item()
#计算准确率
_, pred = out.max(1)
#print("len(label):{}".format(len(label)))
num_correct += (pred == label).sum()
c = (pred == label).squeeze()
acc = num_correct/len(label)
eval_acc += acc
total += label.size(0)
#计算各类别准确率
for i in range(4):
class_correct[label[i]] += c[i].item()
class_total[label[i]] += 1
eval_losses.append(eval_loss/total)
eval_acces.append(eval_acc/total)
print("total:{}".format(total))
print("len(test_loader):{}".format(len(test_loader)))
for i in range(10):
print("accuracy of {}:{}%".format(classes[i],100*class_correct[i]/class_total[i]))
print("----------------")
输出结果
保存/加载模型的参数
这种只保存参数未保存结构的方法的优点就是保存、加载都很快。但是由于没有保存结构,因此加载之前还是需要先定义一下结构。加载的时候,是把参数加载到空的网络之中。
示例:(莫烦习惯用这个)
torch.save(net1.state_dict(),'net_params.pkl')
net2 = torchvision.models.resnet18()
net2.load_state_dict(torch.load('net_params.pkl'))
下面同样以利用内置数据集一览Pytorch实现图像识别为例,展示只保存模型参数和加载模型参数的用法。
保存模型参数
# ...模型训练完成
torch.save(model.state_dict(),'/mnt/main/pkl_checkpoint_file/plain_CNNet_params.pkl') #保存速度极快
加载模型参数
# 导入模块和加载数据集部分略
#定义网络结构
class CNNNet(nn.Module):
def __init__(self):
super(CNNNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1)
self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)
self.conv2 = nn.Conv2d(in_channels=16,out_channels=36,kernel_size=3,stride=1)
self.pool2 = nn.MaxPool2d(kernel_size=2,stride=2)
#self.aap = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(1296,128)
self.fc2 = nn.Linear(128,10)
#self.fc3 = nn.Linear(36,10)
def forward(self,x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
#x = self.aap(x)
#x = x.view(x.shape[0],-1)
#x = self.fc3(x)
x = x.view(-1,36*6*6)
#print("x.shape:{}".format(x.shape))
x = F.relu(self.fc2(F.relu(self.fc1(x))))
return x
model2 = CNNNet()
model2 = model2.to(device)
print('--------------查看网络结构-----------')
print(model2)
model2.load_state_dict(torch.load('/mnt/main/pkl_checkpoint_file/plain_CNNet_params.pkl'))
#测试模型
eval_loss = 0
eval_acc = 0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
total = 0
model2.eval()
with torch.no_grad():
for img, label in test_loader:
img, label = img.to(device), label.to(device)
out = model2(img)
#计算损失值
loss = criterion(out,label)
eval_loss += loss.item()
#计算准确率
_, pred = out.max(1)
#print("len(label):{}".format(len(label)))
num_correct += (pred == label).sum()
c = (pred == label).squeeze()
acc = num_correct/len(label)
eval_acc += acc
total += label.size(0)
#计算各类别准确率
for i in range(4):
class_correct[label[i]] += c[i].item()
class_total[label[i]] += 1
eval_losses.append(eval_loss/total)
eval_acces.append(eval_acc/total)
print("total:{}".format(total))
print("len(test_loader):{}".format(len(test_loader)))
for i in range(10):
print("accuracy of {}:{}%".format(classes[i],100*class_correct[i]/class_total[i]))
print("----------------")
输出结果
[