序列化与反序列化
序列化与反序列化一般描述的是内存与硬盘之间的数据转换关系。
在内存中,模型中的数据是以对象的形式存储的,而在硬盘中是以二进制序列的形式进行保存的。
所以序列化就是将内存中模型的对象以二进制的形式存储在硬盘中;反序列化就是将二进制序列转化成内存中的对象从而可以使用模型。
序列化是为了能够长久地保存模型。
torch.save
主要参数:
- obj:对象
-
torch.load
主要参数:
f:文件路径
map_location:指定存放位置,cpu or gpu。
- 因为gpu模式下训练保存的模型不能采用default加载进来。如果是cpu就不用关心
模型保存与加载的两种方式
法1:保存整个Module
torch.save(net,path)
参数:
- 因为gpu模式下训练保存的模型不能采用default加载进来。如果是cpu就不用关心
net:保存的网络
- path:保存路径
因为这里保存的是整个Module(包括8个有序字典),所以对应的加载方式就是直接将Module加载出来即可:
path_model = "./model.pkl"
net_load = torch.load(path_model)
就加载好了net
法2:保存模型参数(官方推荐)
state_dict=net.state_dict()
torch.save(state_dict,path)
参数:
- state_dict:保存的对象,这里放的就是通过net.state_dict()获得的可学习参数
- path:保存路径
因为这里只保存了可学习参数,所以对应的加载方式为:
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict) # 这里只加载了可学习参数
net_new = LeNet2(classes=2019) # 新建一个网络
net_new.load_state_dict(state_dict_load) # 这里将可学习参数加载进模型中
模型断点续训练
首先我们知道Pytorch的模型训练流程包括以上五个模块。其中数据会变动的只有模型(可学习参数)和优化器(Momentum中之前的更新量)。
因此我们在模型保存的时候只需要保存模型和优化器中的参数。也就是构建一个如下的字典:
其中epoch记录这是第几个epoch的参数信息
代码示例
模型保存
这里每隔checkpoint_interval个epoch保存一次模型,注意写在epoch的循环里:
if (epoch+1) % checkpoint_interval == 0:
checkpoint = {"model_state_dict": net.state_dict(),
"optimizer_state_dic": optimizer.state_dict(),
"loss": loss,
"epoch": epoch}
path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)
断点恢复
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict']) # 加载Module权重
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器参数
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch # 学习率调整时记录上一个epoch
# ============================ step 5/5 训练 =====================================
train_curve = list()
valid_curve = list()
for epoch in range(start_epoch + 1, MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
scheduler.step() # 更新学习率