在 pytorch 框架中,损失函数其实和网络是等价的。它可以看做是网络的一个特殊层。编写自己的损失函数和编写自己的网络模型是类似的。模型和损失函数都需要求导,那么如果采用框架内部的对应操作,可以省去对求导代码的编写。

损失函数要求是凸函数,凸函数利用优化。
损失函数用于评价模型当前输出和实际结果之间的差异,为后续优化提供指导。

分类问题损失函数选择

Picking Loss Functions - A comparison between MSE, Cross Entropy, and Hinge Loss

交叉熵损失函数

交叉熵损失函数原理详解

交叉熵

用于度量两个概率分布之间的差异性。

信息量

信息奠基人香农认为“信息是用来消除随机不确定性的东西”,也就是说信息量的大小和其消除不确定性的程度成正相关。而不确定性表现在概率上则:发生概率越小的信息,其不确定性更大。
设某件事发生概率为,则其信息量定义为:
信息量和信息发生概率成反比(对数意义下)

信息熵

热力学熵

信息熵的定义类比于玻尔兹曼公式:
常见损失函数 - 图1
常见损失函数 - 图2
假设一个盒子分为两格,记为常见损失函数 - 图3,里面有常见损失函数 - 图4个分子,以分子在盒子中的分布来表征系统的状态,则每一种状态对应一种熵,其中熵最大的情况是,,换句话说,两个格子中分子数目相同是系统演化趋势。

信息熵
信息熵表示所有信息量的期望:
常见损失函数 - 图5

  • 常见损失函数 - 图6 表示的是一个分布;通常情况下它都是一个离散的分布,这时候它可以用一个向量进行表示,例如:,此分布对应 8 种情景,对应的值表示每种情形发生的概率。
  • 对信息熵的理解,从含义以及形式上定性推导信息熵的表达式

相对熵(KL散度)

  • 物理中的散度用于矢量场的发散程度(和矢量场有源性相关)
  • 散度:针对向量场,
  • 旋度:针对向量场,常见损失函数 - 图7
  • 梯度:针对数量场,常见损失函数 - 图8

设对于同一个随机变量有两种单独的概率分布
对应一个多分类问题,假设类别用 常见损失函数 - 图9 表示,则散度定义如下:(两者对数差值的期望)
常见损失函数 - 图10
常见损失函数 - 图11
其中:表示交叉熵。
对于已知输入数据的情况下,往往是一个常数,比如一个样本输入,那么此时信息量为0;对于一个batch的输入,那么可以以一个batch样本分布作为概率分布进行计算。计算交叉熵时以输出结果生成输出的概率分布,进行计算。

  • KL散度越小,则说明两者越接近;
  • 常见损失函数 - 图12
  • 证明的话,通过求极值即可

交叉熵计算

  • 交叉熵越小,则说明两者越接近;
  • 最小化交叉熵和最大化对数似然函数对应;

常见损失函数 - 图13

常见损失函数 - 图14 表示第 常见损失函数 - 图15 类,P 和 Q 分别表示期望的概率和预测的概率

那么对于一个batch的采用直接叠加的方式产生:
常见损失函数 - 图16

如果是 0-1 分布的话,交叉熵可以表示为:,其中 常见损失函数 - 图17 表示期望的分布,常见损失函数 - 图18 表示真实的分布。通过求极值可以知道,当 常见损失函数 - 图19 时,可以取到最小值。

注意在以常见损失函数 - 图20为激活函数时,往往选用交叉熵作为损失函数。原因:求导,导数的大小。
它是凸函数么?

通常情况下 常见损失函数 - 图21是 one-hot 形式,所以对于多分类输出,最终只有真是类别对应的概率值参与 loss 的计算。所以很多时候 CE Loss 写作:(常见损失函数 - 图22表示真实类别,模型预测的概率)

softmax函数

由于交叉熵和概率分布相关,所以交叉熵往往和softmax函数关联。
常见损失函数 - 图23
将输出得分转为概率分布。

交叉熵讲解:https://www.cnblogs.com/noahzhixiao/p/10170087.html

在分类问题中,我们的标签往往是:0, 1, 2, … 之类的形式。例如 CrossEntropyLoss 的 label 并不需要是 one-hot 形式,在函数内部,会将对应的标签转为 one-hot 的形式。

均方误差损失函数(MSE)

就是直接求取标签和模型输出之间的欧式距离:
常见损失函数 - 图24

在分类问题中,交叉熵损失函数是要优于MSE损失函数的

NLL Loss

Pytorch 中 CE、NLL loss

实现原理:
假设分类问题模型输出的值通过 Softmax 函数之后转为了概率分布 常见损失函数 - 图25。NLL Loss 选取模型对真实类别预测的概率 常见损失函数 - 图26 做 -log 运算,得到的值就是 loss 值。一个 batch 里面的 loss 取平均值即可。

常见损失函数 - 图27 越大,那么 常见损失函数 - 图28 对应的值越小;反之,则越大。所以 loss 越小,则会约束模型最终输出 常见损失函数 - 图29 足够大。

Focal Loss

focal loss 可以用来解决样本不平衡问题。它可以对 hard samples 加大损失权重,从而提高对 hard samples 的预测能力。其公式如下所示:

通过控制其中的 常见损失函数 - 图30常见损失函数 - 图31 可以实现对高概率的降低损失,低概率的增大损失。

  1. class WeightedFocalLoss(nn.Module):
  2. "Non weighted version of Focal Loss"
  3. def __init__(self, alpha=.25, gamma=2):
  4. super(WeightedFocalLoss, self).__init__()
  5. self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
  6. self.gamma = gamma
  7. def forward(self, inputs, targets):
  8. BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
  9. targets = targets.type(torch.long)
  10. at = self.alpha.gather(0, targets.data.view(-1))
  11. pt = torch.exp(-BCE_loss)
  12. F_loss = at*(1-pt)**self.gamma * BCE_loss
  13. return F_loss.mean()