参考来源:
CSDN:在 Pytorch 中实现 early stopping
使用了这个 github
仓库中提供的 **spytorch_tools**
,而且该仓库中还有这个工具的使用案例,建议读者前往查看。在此感谢作者的代码分享。
注意:不要使用 **pip install pytorchtools**
获取相关代码,这是不对的可以直接去上述的仓库里把其中的 **pytorchtools.py**
下载(或复制)下来并放置在项目中,这样就可以正常执行 **from pytorchtools import EarlyStopping**
并使用 EarlyStopping
了。
实现
有了 **pytorch_tools**
工具后,使用 **early stopping**
就很简单了。
先从该工具类中导入 **EarlyStopping**
。
# import EarlyStopping
from pytorchtools import EarlyStopping
import torch.utils.data as Data # 用于创建 DataLoader
import torch.nn as nn
为了方便描述,这里还是会使用一些伪代码,如果你想阅读详细案例的话,不用犹豫直接看上述工具自带的案例代码。
model = yourModel() # 伪
# 指定损失函数,可以是其他损失函数,根据训练要求决定
criterion = nn.CrossEntropyLoss() # 交叉熵,注意该损失函数对自动对批量样本的损失取平均
# 指定优化器,可以是其他
optimizer = torch.optim.Adam(model.parameters())
# 初始化 early_stopping 对象
patience = 20 # 当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,以防止模型过拟合
early_stopping = EarlyStopping(patience, verbose=True) # 关于 EarlyStopping 的代码可先看博客后面的内容
batch_size = 64 # 或其他,该参数属于超参,对于如何选择超参,你可以参考下我的上一篇博客
n_epochs = 100 # 可以设置大一些,毕竟你是希望通过 early stopping 来结束模型训练
#----------------------------------------------------------------
# 训练模型,直到 epoch == n_epochs 或者触发 early_stopping 结束训练
for epoch in range(1, n_epochs + 1):
# 建立训练数据的 DataLoader
training_dataset = Data.TensorDataset(X_train, y_train)
# 把dataset放到DataLoader中
data_loader = Data.DataLoader(
dataset=training_dataset,
batch_size=batch_size, # 批量大小
shuffle=True # 是否打乱数据顺序
)
#---------------------------------------------------
model.train() # 设置模型为训练模式
# 按小批量训练
for batch, (data, target) in enumerate(data_loader):
optimizer.zero_grad() # 清楚所有参数的梯度
output = model(data) # 输出模型预测值
loss = criterion(output, target) # 计算损失
loss.backward() # 计算损失对于各个参数的梯度
optimizer.step() # 执行单步优化操作:更新参数
#----------------------------------------------------
model.eval() # 设置模型为评估/测试模式
# 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错
valid_output = model(X_val)
valid_loss = criterion(valid_output, y_val) # 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点
early_stopping(valid_loss, model)
# 若满足 early stopping 要求
if early_stopping.early_stop:
print("Early stopping")
# 结束模型训练
break
# 获得 early stopping 时的模型参数
model.load_state_dict(torch.load('checkpoint.pt'))
以下是 **pytorch_tools**
工具的代码:
import numpy as np
import torch
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'checkpoint.pt') # 这里会存储迄今最优模型的参数
self.val_loss_min = val_loss
不过这些代码还是需要读者根据自己的模型做出改动。