一、优化器简介
在网络中要使用优化器进行寻优,在PyTorch框架中可以使用torch.optim
快速实现,例如下面就选择了SGD和Adam作为整个网络的优化器
# 针对网络model的参数进行寻优
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 针对变量var1和var2进行寻优
optimizer = optim.Adam([var1, var2], lr=0.0001)
设定optimizer的参数
我们可以通过设定优化器的参数来定制适合自己的优化器,优化器的参数是字典类型的数据,并且这个字典数据应该含有params
这个关键字。比如,我们可以将每层网络的学习率设置成不同大小的数据
# model.classifier层的学习率为1e-3,其他层的学习率为1e-2,所有层的momentum大小均为0.9
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
设定更新优化器
使用optimizer.step()
来更新网络中的参数,这个过程发生在梯度计算以后,即loss.backward()
过程发生完成,就比如下面这个例子
for input, target in dataset:
optimizer.zero_grad() # 设定所有被优化的tensor张量参数为0
output = model(input)
loss = loss_fn(output, target)
loss.backward() # 计算网络的梯度
optimizer.step() # 更新网络参数
优化器的种类
[Adadelta](https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html#torch.optim.Adadelta)
[Adagrad](https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html#torch.optim.Adagrad)
[Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
[AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW)
[SparseAdam](https://pytorch.org/docs/stable/generated/torch.optim.SparseAdam.html#torch.optim.SparseAdam)
[Adamax](https://pytorch.org/docs/stable/generated/torch.optim.Adamax.html#torch.optim.Adamax)
[ASGD](https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html#torch.optim.ASGD)
[LBFGS](https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html#torch.optim.LBFGS)
[RMSprop](https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html#torch.optim.RMSprop)
[Rprop](https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html#torch.optim.Rprop)
[SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD)
二、优化器:SGD
优化器SGD通过torch.optim.SGD
来调用,这个类的接口如下
torch.optim.SGD(params, lr=<required parameter>, momentum=0,
dampening=0, weight_decay=0, nesterov=False)
上述的参数含义为:
params
:设定的需要迭代的网络的参数lr(float)
:网络的学习率momentum(float, optional)
:动量因子,默认为0weight_decay(float, optional)
:L2正则化的惩罚因子,默认为0dampening(float, optional)
:动量因子的阻尼,默认为0nesterov(bool, optional)
:是否使用Nesterov动量,默认为False
使用的一个简单例子
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
上述给了一堆代码,但是上述代码中的momentum
动量是什么,怎么用的;lr
是什么,怎么用的。最后还需要通过公式来知道,下面看一下参数更新的过程
其中,p就是要更新的网络参数,v表示的就是更新的速率,g表示的就是梯度,表示的动量因子。