参考来源:
CSDN:在 Pytorch 中实现 early stopping

使用了这个 github 仓库中提供的 **spytorch_tools**,而且该仓库中还有这个工具的使用案例,建议读者前往查看。在此感谢作者的代码分享。

注意:不要使用 **pip install pytorchtools** 获取相关代码,这是不对的可以直接去上述的仓库里把其中的 **pytorchtools.py** 下载(或复制)下来并放置在项目中,这样就可以正常执行 **from pytorchtools import EarlyStopping** 并使用 EarlyStopping 了。

实现

有了 **pytorch_tools** 工具后,使用 **early stopping** 就很简单了。
先从该工具类中导入 **EarlyStopping**

  1. # import EarlyStopping
  2. from pytorchtools import EarlyStopping
  3. import torch.utils.data as Data # 用于创建 DataLoader
  4. import torch.nn as nn

为了方便描述,这里还是会使用一些伪代码,如果你想阅读详细案例的话,不用犹豫直接看上述工具自带的案例代码

  1. model = yourModel() # 伪
  2. # 指定损失函数,可以是其他损失函数,根据训练要求决定
  3. criterion = nn.CrossEntropyLoss() # 交叉熵,注意该损失函数对自动对批量样本的损失取平均
  4. # 指定优化器,可以是其他
  5. optimizer = torch.optim.Adam(model.parameters())
  6. # 初始化 early_stopping 对象
  7. patience = 20 # 当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,以防止模型过拟合
  8. early_stopping = EarlyStopping(patience, verbose=True) # 关于 EarlyStopping 的代码可先看博客后面的内容
  9. batch_size = 64 # 或其他,该参数属于超参,对于如何选择超参,你可以参考下我的上一篇博客
  10. n_epochs = 100 # 可以设置大一些,毕竟你是希望通过 early stopping 来结束模型训练
  11. #----------------------------------------------------------------
  12. # 训练模型,直到 epoch == n_epochs 或者触发 early_stopping 结束训练
  13. for epoch in range(1, n_epochs + 1):
  14. # 建立训练数据的 DataLoader
  15. training_dataset = Data.TensorDataset(X_train, y_train)
  16. # 把dataset放到DataLoader中
  17. data_loader = Data.DataLoader(
  18. dataset=training_dataset,
  19. batch_size=batch_size, # 批量大小
  20. shuffle=True # 是否打乱数据顺序
  21. )
  22. #---------------------------------------------------
  23. model.train() # 设置模型为训练模式
  24. # 按小批量训练
  25. for batch, (data, target) in enumerate(data_loader):
  26. optimizer.zero_grad() # 清楚所有参数的梯度
  27. output = model(data) # 输出模型预测值
  28. loss = criterion(output, target) # 计算损失
  29. loss.backward() # 计算损失对于各个参数的梯度
  30. optimizer.step() # 执行单步优化操作:更新参数
  31. #----------------------------------------------------
  32. model.eval() # 设置模型为评估/测试模式
  33. # 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错
  34. valid_output = model(X_val)
  35. valid_loss = criterion(valid_output, y_val) # 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点
  36. early_stopping(valid_loss, model)
  37. # 若满足 early stopping 要求
  38. if early_stopping.early_stop:
  39. print("Early stopping")
  40. # 结束模型训练
  41. break
  42. # 获得 early stopping 时的模型参数
  43. model.load_state_dict(torch.load('checkpoint.pt'))

以下是 **pytorch_tools** 工具的代码:

  1. import numpy as np
  2. import torch
  3. class EarlyStopping:
  4. """Early stops the training if validation loss doesn't improve after a given patience."""
  5. def __init__(self, patience=7, verbose=False, delta=0):
  6. """
  7. Args:
  8. patience (int): How long to wait after last time validation loss improved.
  9. Default: 7
  10. verbose (bool): If True, prints a message for each validation loss improvement.
  11. Default: False
  12. delta (float): Minimum change in the monitored quantity to qualify as an improvement.
  13. Default: 0
  14. """
  15. self.patience = patience
  16. self.verbose = verbose
  17. self.counter = 0
  18. self.best_score = None
  19. self.early_stop = False
  20. self.val_loss_min = np.Inf
  21. self.delta = delta
  22. def __call__(self, val_loss, model):
  23. score = -val_loss
  24. if self.best_score is None:
  25. self.best_score = score
  26. self.save_checkpoint(val_loss, model)
  27. elif score < self.best_score + self.delta:
  28. self.counter += 1
  29. print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
  30. if self.counter >= self.patience:
  31. self.early_stop = True
  32. else:
  33. self.best_score = score
  34. self.save_checkpoint(val_loss, model)
  35. self.counter = 0
  36. def save_checkpoint(self, val_loss, model):
  37. '''Saves model when validation loss decrease.'''
  38. if self.verbose:
  39. print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
  40. torch.save(model.state_dict(), 'checkpoint.pt') # 这里会存储迄今最优模型的参数
  41. self.val_loss_min = val_loss

不过这些代码还是需要读者根据自己的模型做出改动。