参考来源:
CSDN:在 Pytorch 中实现 early stopping
使用了这个 github 仓库中提供的 **spytorch_tools**,而且该仓库中还有这个工具的使用案例,建议读者前往查看。在此感谢作者的代码分享。
注意:不要使用 **pip install pytorchtools** 获取相关代码,这是不对的可以直接去上述的仓库里把其中的 **pytorchtools.py** 下载(或复制)下来并放置在项目中,这样就可以正常执行 **from pytorchtools import EarlyStopping** 并使用 EarlyStopping 了。
实现
有了 **pytorch_tools** 工具后,使用 **early stopping** 就很简单了。
先从该工具类中导入 **EarlyStopping** 。
# import EarlyStoppingfrom pytorchtools import EarlyStoppingimport torch.utils.data as Data # 用于创建 DataLoaderimport torch.nn as nn
为了方便描述,这里还是会使用一些伪代码,如果你想阅读详细案例的话,不用犹豫直接看上述工具自带的案例代码。
model = yourModel() # 伪# 指定损失函数,可以是其他损失函数,根据训练要求决定criterion = nn.CrossEntropyLoss() # 交叉熵,注意该损失函数对自动对批量样本的损失取平均# 指定优化器,可以是其他optimizer = torch.optim.Adam(model.parameters())# 初始化 early_stopping 对象patience = 20 # 当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,以防止模型过拟合early_stopping = EarlyStopping(patience, verbose=True) # 关于 EarlyStopping 的代码可先看博客后面的内容batch_size = 64 # 或其他,该参数属于超参,对于如何选择超参,你可以参考下我的上一篇博客n_epochs = 100 # 可以设置大一些,毕竟你是希望通过 early stopping 来结束模型训练#----------------------------------------------------------------# 训练模型,直到 epoch == n_epochs 或者触发 early_stopping 结束训练for epoch in range(1, n_epochs + 1):# 建立训练数据的 DataLoadertraining_dataset = Data.TensorDataset(X_train, y_train)# 把dataset放到DataLoader中data_loader = Data.DataLoader(dataset=training_dataset,batch_size=batch_size, # 批量大小shuffle=True # 是否打乱数据顺序)#---------------------------------------------------model.train() # 设置模型为训练模式# 按小批量训练for batch, (data, target) in enumerate(data_loader):optimizer.zero_grad() # 清楚所有参数的梯度output = model(data) # 输出模型预测值loss = criterion(output, target) # 计算损失loss.backward() # 计算损失对于各个参数的梯度optimizer.step() # 执行单步优化操作:更新参数#----------------------------------------------------model.eval() # 设置模型为评估/测试模式# 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错valid_output = model(X_val)valid_loss = criterion(valid_output, y_val) # 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点early_stopping(valid_loss, model)# 若满足 early stopping 要求if early_stopping.early_stop:print("Early stopping")# 结束模型训练break# 获得 early stopping 时的模型参数model.load_state_dict(torch.load('checkpoint.pt'))
以下是 **pytorch_tools** 工具的代码:
import numpy as npimport torchclass EarlyStopping:"""Early stops the training if validation loss doesn't improve after a given patience."""def __init__(self, patience=7, verbose=False, delta=0):"""Args:patience (int): How long to wait after last time validation loss improved.Default: 7verbose (bool): If True, prints a message for each validation loss improvement.Default: Falsedelta (float): Minimum change in the monitored quantity to qualify as an improvement.Default: 0"""self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltadef __call__(self, val_loss, model):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model)elif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model)self.counter = 0def save_checkpoint(self, val_loss, model):'''Saves model when validation loss decrease.'''if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')torch.save(model.state_dict(), 'checkpoint.pt') # 这里会存储迄今最优模型的参数self.val_loss_min = val_loss
不过这些代码还是需要读者根据自己的模型做出改动。
