参考来源:
Pytorch 使用 ReduceLROnPlateau 来更新学习率

解析说明

  1. torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
  2. mode='min',
  3. factor=0.1,
  4. patience=10,
  5. verbose=False,
  6. threshold=0.0001,
  7. threshold_mode='rel',
  8. cooldown=0,
  9. min_lr=0,
  10. eps=1e-08
  11. )

在发现 loss 不再降低或者 acc 不再提高之后,降低学习率。各参数意义如下:

  • **mode**:'min' 模式检测 metric 是否不再减小,'max' 模式检测 metric 是否不再增大;
  • **factor**:触发条件后 lr*=factor
  • **patience**:允许 metric 连续不再减小(或增大)的次数,超过该次数后才降低 lr;
  • **verbose**:触发条件后 print
  • **threshold**:只关注超过阈值的显著变化;
  • **threshold_mode**:有 rel 和 abs 两种阈值计算模式;
    • **rel** 规则:max 模式下如果超过 best*(1+threshold) 为显著,min 模式下如果低于 best*(1-threshold) 为显著;
    • **abs** 规则:max 模式下如果超过 best+threshold 为显著,min 模式下如果低于 best-threshold 为显著;
  • **cooldown**:触发一次条件后,等待一定 epoch 再进行检测,避免 lr 下降过速;
  • **min_lr**:最小的允许 lr
  • **eps**:如果新旧 lr 之间的差异小于 eps(默认 1e-8),则忽略此次更新。

例子:如图所示,y 轴为 lr,x 为调整的次序,初始的学习率为 0.0009575,乘积系数(factor)为 0.35。则学习率的方程为:$lr = 0.0009575 \times 0.35^{x}$
使用 ReduceLROnPlateau 来更新学习率——Pytorch - 图1

  import math


  def decayed_lrs(initial_lr=0.0009575, factor=0.35, steps=8):
      """Return the learning rates after 0, 1, ..., ``steps`` reductions.

      Entry ``x`` of the result is ``initial_lr * factor ** x``, i.e. the
      learning rate once it has been multiplied by ``factor`` ``x`` times.

      Args:
          initial_lr: learning rate before any reduction.
          factor: multiplicative decay applied at every reduction.
          steps: number of reductions to tabulate.

      Returns:
          A list of ``steps + 1`` floats, starting with ``initial_lr``.
      """
      return [initial_lr * math.pow(factor, x) for x in range(steps + 1)]


  def main():
      """Print the decayed learning rates and plot them against the step index."""
      # Imported here so that `decayed_lrs` stays usable without matplotlib.
      import matplotlib.pyplot as plt
      # In a Jupyter notebook, additionally run: %matplotlib inline

      p = decayed_lrs()
      o = list(range(len(p)))
      # Step 0 (the initial learning rate) is plotted but not printed.
      for x, y in zip(o[1:], p[1:]):
          print('%d: %.50f' % (x, y))
      # x axis: reduction index, y axis: learning rate.
      plt.plot(o, p, c='red', label='test')
      plt.legend(loc='best')  # let matplotlib pick the legend position
      plt.show()


  if __name__ == "__main__":
      main()

难点

我感觉这里面最难的是这几个参数的选择,第一个是初始的学习率(我目前接触的 MNIST 和下面的图像分类貌似都是 0.001,我这里训练调整时才发现自己设置的为 0.0009575,这个值是上一个实验忘更改了,但发现结果不错,第一次运行该代码接近到 0.001 这么小的损失值),这里面的乘积系数以及判断多少次没有减少(增加)后决定变换学习率都是难以估计的。我自己的最好方法是先按默认不变的 0.001 来训练一下(结合 TensorBoard)观察从哪里开始出现问题,就可以从这里来确定次数;而乘积系数,个人感觉还是用上面的代码来获取一个较为平滑且变化极小的数字来作为选择。建议在做这种测试时可以把模型先备份一下,以免浪费过多的时间!

例子

该例子初始学习率为 0.0009575,乘积项系数为 0.35,在我的例子中 x 变化的条件是:累计 125 次没有减小则 x 加 1;自己训练在第一次 lr 变化后(从 0.0009575 变化到 0.00011729)损失值慢慢趋向于 0.001(如第一张图所示),准确率达到 69%;
使用 ReduceLROnPlateau 来更新学习率——Pytorch - 图2

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.optim as optim
  9. from datetime import datetime
  10. from torch.utils.tensorboard import SummaryWriter
  11. from torch.optim import *
  # Path the trained weights are saved to / restored from.
  PATH = './cifar_net_tensorboard_net_width_200_and_chang_lr_by_decrease_0_35^x.pth'

  # Convert images to tensors, then normalize every RGB channel with
  # mean 0.5 and std 0.5 (undone later by `img / 2 + 0.5`).
  transform = transforms.Compose(
      [transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
  trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=0)

  testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)
  testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                           shuffle=False, num_workers=0)

  classes = ('plane', 'car', 'bird', 'cat',
             'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  # Assuming that we are on a CUDA machine, this should print a CUDA device:
  print(device)

  print("获取一些随机训练数据")
  # get some random training images
  dataiter = iter(trainloader)
  # DataLoader iterators have no `.next()` method; use the builtin instead.
  images, labels = next(dataiter)
  # functions to show an image
  def imshow(img):
      """Display a normalized channel-first image tensor with matplotlib.

      Args:
          img: tensor laid out as (channels, height, width) whose values were
              normalized with mean 0.5 / std 0.5.
      """
      img = img / 2 + 0.5  # unnormalize
      # `.cpu()` is a no-op for CPU tensors and makes CUDA tensors displayable.
      npimg = img.cpu().numpy()
      # matplotlib expects (height, width, channels).
      plt.imshow(np.transpose(npimg, (1, 2, 0)))
      plt.show()
  # show images
  imshow(torchvision.utils.make_grid(images))
  # print labels (one per image in the batch, whatever the batch size is)
  print(' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))
  print("**********************")
  # TensorBoard helpers
  # helper function to show an image
  # (used in the `plot_classes_preds` function below)
  def matplotlib_imshow(img, one_channel=False):
      """Draw a normalized image tensor on the current matplotlib axes.

      Args:
          img: channel-first image tensor normalized with mean 0.5 / std 0.5.
          one_channel: when True, average over the channel axis and draw the
              result with the "Greys" colormap instead of as RGB.
      """
      if one_channel:
          img = img.mean(dim=0)
      img = img / 2 + 0.5  # unnormalize
      npimg = img.cpu().numpy()
      if one_channel:
          plt.imshow(npimg, cmap="Greys")
      else:
          # matplotlib expects (height, width, channels).
          plt.imshow(np.transpose(npimg, (1, 2, 0)))
  # Set up TensorBoard.
  # default `log_dir` is "runs" - we'll be more specific here
  writer = SummaryWriter('runs/train')
  # get some random training images
  dataiter = iter(trainloader)
  # DataLoader iterators have no `.next()` method; use the builtin instead.
  images, labels = next(dataiter)
  # create grid of images
  img_grid = torchvision.utils.make_grid(images)
  # show images
  # matplotlib_imshow(img_grid, one_channel=True)
  imshow(img_grid)
  # write to tensorboard
  # writer.add_image('imag_classify', img_grid)
  # Tracking model training with TensorBoard
  # helper functions
  def images_to_probs(net, images):
      '''
      Generate predictions and their probabilities from a network and a
      batch of images.

      Args:
          net: callable mapping `images` to a 2-D tensor of per-class scores
              (one row per image).
          images: batch passed to `net` unchanged.

      Returns:
          A tuple `(preds, probs)`: `preds` is a 1-D NumPy array with the
          predicted class index of every image, `probs` is a list holding the
          softmax probability of that predicted class for every image.
      '''
      output = net(images)
      # convert output scores to predicted class
      _, preds_tensor = torch.max(output, 1)
      # `preds_tensor` is already 1-D. Squeezing it would collapse a batch of
      # one image into a 0-d array that cannot be iterated below.
      preds = preds_tensor.cpu().numpy()
      return preds, [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)]
  def plot_classes_preds(net, images, labels):
      """Build a figure showing images with predicted and true class names.

      Up to four images of the batch are drawn side by side. Each title shows
      the predicted class, its probability and the true label, coloured green
      when the prediction is correct and red otherwise.

      Args:
          net: network passed on to `images_to_probs`.
          images: batch of image tensors.
          labels: tensor with the true class index of every image.

      Returns:
          The matplotlib figure.
      """
      preds, probs = images_to_probs(net, images)
      # plot the images in the batch, along with predicted and true labels
      fig = plt.figure(figsize=(12, 48))
      # The layout has four columns; smaller batches leave trailing cells empty.
      for idx in range(min(4, len(images))):
          ax = fig.add_subplot(1, 4, idx + 1, xticks=[], yticks=[])
          matplotlib_imshow(images[idx], one_channel=True)
          ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
              classes[preds[idx]],
              probs[idx] * 100.0,
              classes[labels[idx]]),
              color=("green" if preds[idx] == labels[idx].item() else "red"))
      return fig
  # Network definition.
  class Net(nn.Module):
      """Small CNN: two conv + max-pool stages followed by three linear layers.

      With 3x32x32 inputs the feature map shrinks 32 -> 28 -> 14 -> 10 -> 5,
      which gives the 16 * 5 * 5 features expected by `fc1`. The first
      convolution is 200 channels wide. The output holds 10 raw class scores
      per image (no softmax applied).
      """

      def __init__(self):
          super().__init__()
          self.conv1 = nn.Conv2d(3, 200, 5)
          self.pool = nn.MaxPool2d(2, 2)
          self.conv2 = nn.Conv2d(200, 16, 5)
          self.fc1 = nn.Linear(16 * 5 * 5, 120)
          self.fc2 = nn.Linear(120, 84)
          self.fc3 = nn.Linear(84, 10)

      def forward(self, x):
          """Return the class scores for a batch of images `x`."""
          x = self.pool(F.relu(self.conv1(x)))
          x = self.pool(F.relu(self.conv2(x)))
          # Flatten every image's feature map into one row for the linear layers.
          x = x.view(-1, 16 * 5 * 5)
          x = F.relu(self.fc1(x))
          x = F.relu(self.fc2(x))
          x = self.fc3(x)
          return x
  net = Net()
  # Visualize the network structure in TensorBoard. This happens before the
  # network is moved to `device`, using the sample batch loaded earlier.
  writer.add_graph(net, images)
  net.to(device)
  try:
      # `map_location` lets a checkpoint saved on another device be restored.
      net.load_state_dict(torch.load(PATH, map_location=device))
      print("Modle file load successful !")
  except FileNotFoundError:
      print("no model file,it will creat a new file!")
  except Exception as err:
      # Unreadable or incompatible checkpoint: still start from scratch, as
      # the original best-effort code did, but report the reason.
      print("model file could not be loaded, training from scratch:", err)

  # Training
  print("训练")
  criterion = nn.CrossEntropyLoss()
  # Reduce the learning rate once the loss stops decreasing.
  # (The unmodified tutorial setting would be lr=0.001.)
  optimizer = torch.optim.SGD(net.parameters(), lr=0.0009575, momentum=0.9)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
      optimizer, 'min', factor=0.35, verbose=True, min_lr=0.0001, patience=125)
  startTime = datetime.now()
  for epoch in range(200):  # loop over the dataset multiple times
      running_loss = 0.0
      for i, data in enumerate(trainloader, 0):
          # get the inputs; data is a list of [inputs, labels]
          inputs, labels = data[0].to(device), data[1].to(device)

          # zero the parameter gradients
          optimizer.zero_grad()

          # forward + backward + optimize
          outputs = net(inputs)
          loss = criterion(outputs, labels)  # compute the loss
          loss.backward()                    # back-propagate: compute gradients
          optimizer.step()                   # update the parameters

          # print statistics
          running_loss += loss.item()
          if i % 2000 == 1999:  # every 2000 mini-batches
              now_loss = running_loss / 2000  # mean loss of those 2000 batches
              print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, now_loss))

              # The scheduler is stepped once per 2000-batch window, so its
              # `patience` counts windows, not epochs.
              scheduler.step(now_loss)

              # Log the running loss and current learning rate to TensorBoard.
              global_step = epoch * len(trainloader) + i
              writer.add_scalar('image training loss on net width 200 chang_lr_by_decrease',
                                now_loss,
                                global_step)
              writer.add_scalar('learning rate on net width 200 chang_lr_by_decrease',
                                optimizer.param_groups[0]['lr'],
                                global_step)
              running_loss = 0.0

      # Checkpoint after every epoch: an interrupted run keeps its progress,
      # and the last epoch's save is the final model.
      torch.save(net.state_dict(), PATH)

  # Make sure all logged scalars reach the event file.
  writer.flush()

  print('Finished Training')
  print("***************************")
  print("***************************")
  print("***************************")
  print("Time taken:", datetime.now() - startTime)
  print("***************************")
  print("***************************")
  print("***************************")
  # Fetch the first batch of test data (the test loader is not shuffled).
  print("获取一些随机测试数据")
  dataiter = iter(testloader)
  # DataLoader iterators have no `.next()` method; use the builtin instead.
  images, labels = next(dataiter)

  # print images
  imshow(torchvision.utils.make_grid(images))
  print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))

  # Restore the model and test it.
  net = Net()
  # The fresh network and the test batches live on the CPU, so load the
  # checkpoint onto the CPU even if it was saved from a CUDA run.
  net.load_state_dict(torch.load(PATH, map_location='cpu'))
  net.eval()

  outputs = net(images)
  _, predicted = torch.max(outputs, 1)
  print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                                for j in range(len(predicted))))
  print("**********************")
  print("输出训练得到的准确度")
  # Accuracy of the trained network over the whole test set.
  correct = 0
  total = 0
  with torch.no_grad():  # inference only, no gradients needed
      for data in testloader:
          images, labels = data
          outputs = net(images)
          # The class with the highest score is the prediction.
          _, predicted = torch.max(outputs, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the 10000 test images: %d %%' % (
      100 * correct / total))
  # Per-class accuracy over the test set.
  class_correct = [0.] * 10
  class_total = [0.] * 10
  with torch.no_grad():  # inference only, no gradients needed
      for data in testloader:
          images, labels = data
          outputs = net(images)
          _, predicted = torch.max(outputs, 1)
          hits = predicted == labels
          # Pair every label with its hit/miss flag instead of indexing a fixed
          # four entries, so any batch size (including a short last batch) works.
          for label, hit in zip(labels.tolist(), hits.tolist()):
              class_correct[label] += hit
              class_total[label] += 1
  for i in range(10):
      print('Accuracy of %5s : %2d %%' % (
          classes[i], 100 * class_correct[i] / class_total[i]))