序列化与反序列化
序列化与反序列化一般描述的是内存与硬盘之间的数据转换关系。
在内存中,模型中的数据是以对象的形式存储的,而在硬盘中是以二进制序列的形式进行保存的。
所以序列化就是将内存中模型的对象以二进制的形式存储在硬盘中;反序列化就是将二进制序列转化成内存中的对象从而可以使用模型。
序列化是为了能够长久地保存模型。
torch.save
主要参数:
- obj:对象
-
torch.load
主要参数:
f:文件路径
map_location:指定存放位置,cpu or gpu。
- 因为gpu模式下训练保存的模型不能采用default加载进来。如果是cpu就不用关心
模型保存与加载的两种方式
法1:保存整个Module
torch.save(net,path)
参数:
- 因为gpu模式下训练保存的模型不能采用default加载进来。如果是cpu就不用关心
net:保存的网络
- path:保存路径
因为这里保存的是整个Module(包括8个有序字典),所以对应的加载方式就是直接将Module加载出来即可:
path_model = "./model.pkl"net_load = torch.load(path_model)
就加载好了net
法2:保存模型参数(官方推荐)
state_dict=net.state_dict()
torch.save(state_dict,path)
参数:
- state_dict:保存的对象,这里放的就是通过net.state_dict()获得的可学习参数
- path:保存路径
因为这里只保存了可学习参数,所以对应的加载方式为:
path_state_dict = "./model_state_dict.pkl"state_dict_load = torch.load(path_state_dict) # 这里只加载了可学习参数net_new = LeNet2(classes=2019) # 新建一个网络net_new.load_state_dict(state_dict_load) # 这里将可学习参数加载进模型中
模型断点续训练

首先我们知道Pytorch的模型训练流程包括以上五个模块。其中数据会变动的只有模型(可学习参数)和优化器(Momentum中之前的更新量)。
因此我们在模型保存的时候只需要保存模型和优化器中的参数。也就是构建一个如下的字典:
其中epoch记录这是第几个epoch的参数信息
代码示例
模型保存
这里每隔checkpoint_interval个epoch保存一次模型,注意写在epoch的循环里:
if (epoch+1) % checkpoint_interval == 0:checkpoint = {"model_state_dict": net.state_dict(),"optimizer_state_dic": optimizer.state_dict(),"loss": loss,"epoch": epoch}path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)torch.save(checkpoint, path_checkpoint)
断点恢复
path_checkpoint = "./checkpoint_4_epoch.pkl"checkpoint = torch.load(path_checkpoint)net.load_state_dict(checkpoint['model_state_dict']) # 加载Module权重optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器参数start_epoch = checkpoint['epoch']scheduler.last_epoch = start_epoch # 学习率调整时记录上一个epoch# ============================ step 5/5 训练 =====================================train_curve = list()valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step() # 更新学习率
