一、优化器简介

在网络中要使用优化器进行寻优,在PyTorch框架中可以使用torch.optim快速实现,例如下面就选择了SGD和Adam作为整个网络的优化器

  1. # 针对网络model的参数进行寻优
  2. optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
  3. # 针对变量var1和var2进行寻优
  4. optimizer = optim.Adam([var1, var2], lr=0.0001)

§ PyTorch的optimizer - 图1 设定optimizer的参数

我们可以通过设定优化器的参数来定制适合自己的优化器,优化器的参数是字典类型的数据,并且这个字典数据应该含有params这个关键字。比如,我们可以将每层网络的学习率设置成不同大小的数据

  1. # model.classifier层的学习率为1e-3,其他层的学习率为1e-2,所有层的momentum大小均为0.9
  2. optim.SGD([
  3. {'params': model.base.parameters()},
  4. {'params': model.classifier.parameters(), 'lr': 1e-3}
  5. ], lr=1e-2, momentum=0.9)

§ PyTorch的optimizer - 图2 设定更新优化器

使用optimizer.step()来更新网络中的参数,这个过程发生在梯度计算以后,即loss.backward()过程发生完成,就比如下面这个例子

  1. for input, target in dataset:
  2. optimizer.zero_grad() # 设定所有被优化的tensor张量参数为0
  3. output = model(input)
  4. loss = loss_fn(output, target)
  5. loss.backward() # 计算网络的梯度
  6. optimizer.step() # 更新网络参数

§ PyTorch的optimizer - 图3 优化器的种类

  1. [Adadelta](https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html#torch.optim.Adadelta)
  2. [Adagrad](https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html#torch.optim.Adagrad)
  3. [Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
  4. [AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW)
  5. [SparseAdam](https://pytorch.org/docs/stable/generated/torch.optim.SparseAdam.html#torch.optim.SparseAdam)
  6. [Adamax](https://pytorch.org/docs/stable/generated/torch.optim.Adamax.html#torch.optim.Adamax)
  7. [ASGD](https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html#torch.optim.ASGD)
  8. [LBFGS](https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html#torch.optim.LBFGS)
  9. [RMSprop](https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html#torch.optim.RMSprop)
  10. [Rprop](https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html#torch.optim.Rprop)
  11. [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD)

二、优化器:SGD

优化器SGD通过torch.optim.SGD来调用,这个类的接口如下

  1. torch.optim.SGD(params, lr=<required parameter>, momentum=0,
  2. dampening=0, weight_decay=0, nesterov=False)

上述的参数含义为:

  • params:设定的需要迭代的网络的参数
  • lr(float):网络的学习率
  • momentum(float, optional):动量因子,默认为0
  • weight_decay(float, optional):L2正则化的惩罚因子,默认为0
  • dampening(float, optional):动量因子的阻尼,默认为0
  • nesterov(bool, optional):是否使用Nesterov动量,默认为False

使用的一个简单例子

  1. optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  2. optimizer.zero_grad()
  3. loss_fn(model(input), target).backward()
  4. optimizer.step()

上述给了一堆代码,但是上述代码中的momentum动量是什么,怎么用的;lr是什么,怎么用的。最后还需要通过公式来知道,下面看一下参数更新的过程

§ PyTorch的optimizer - 图4

其中,p就是要更新的网络参数,v表示的就是更新的速率,g表示的就是梯度,§ PyTorch的optimizer - 图5表示的动量因子。