A unifying mutual information view of metric learning: cross-entropy vs. pairwise losses

总的来看,这里作者认为当今DML研究者都在精力花费在了寻找复杂的对间损失函数,而忽视了原有的交叉熵损失。作者认为交叉熵损失在DML中仍有用武之地。作者将从优化的角度和相互信息的判别和生成的角度对于交叉熵损失进行分析,并认为交叉熵损失函数与带有样本挖掘的对间损失一样,可以起到相似的的效果。

Introduction

作者回顾了DML领域的进展和局限性。

作者指出,尽管一眼看上去交叉熵损失似乎和目前使用较多的对间损失没关系,但作者将一些对间损失和交叉熵相关联并提出理论解释,关联的角度为:交叉熵损失和对间损失都在优化一个在样本标签和需要学习的嵌入之间的互信息。

On the two views of the mutual information

互信息可以用来衡量两个随机变量之间共享的信息的关系:

  • 互信息的意义
    它可以看成是一个随机变量中包含的关于另一个随机变量的信息量,或者说是一个随机变量由于已知另一个随机变量而减少的不肯定性 。

在本文中重点检查的是模型学习的特征Metric Learning: CE vs PL - 图1与数据的标签Metric Learning: CE vs PL - 图2之间的互信息:

Metric Learning: CE vs PL - 图3

对于最大化公式1的解释有两种:

  1. 区分度观点
    在该角度下,标签Metric Learning: CE vs PL - 图4应当是平衡的(个人的理解就是尽可能地分散在各个类别中,以实现Metric Learning: CE vs PL - 图5#card=math&code=H%28Y%29&id=GZZ9z)的最大化,而这不是我们所可以控制的),并且标签应当可以和特征区分开来(因为这要使得Metric Learning: CE vs PL - 图6#card=math&code=H%28Y%7C%5Chat%7BZ%7D%29&id=D7XA3)尽可能地小,也就是二者的相关性要尽可能地低)。
  2. 统一性的观点
    学习获得的特征应当尽可能地分散(也就是应当最大化Metric Learning: CE vs PL - 图7#card=math&code=H%28%5Chat%7BZ%7D%29&id=vrrMT)),并且保证同类相近的特点更加接近。
    • 条件熵的意义
      条件熵 Metric Learning: CE vs PL - 图8#card=math&code=H%28Y%7CX%29&id=NjEV3) 表示在已知随机变量 Metric Learning: CE vs PL - 图9 的条件下随机变量 Metric Learning: CE vs PL - 图10 的不确定性

因此,最大化Metric Learning: CE vs PL - 图11#card=math&code=%5Cmathbb%7BI%7D%28%5Chat%7BZ%7D%3BY%29&id=eu5MS)要求需要最小化统一性中的Metric Learning: CE vs PL - 图12#card=math&code=H%28%5Chat%7BZ%7D%7CY%29&id=BrD5B),也就是需要在Metric Learning: CE vs PL - 图13确定的条件下尽可能减少学习的特征的不确定性,因此学习的特征中的相似标签的样本应当接近,也就是同类点更加接近。

总的来看,区分度观点更加关注在标签的识别上,而统一性观点更加关注模型学习的特征分布的塑造上。这也就是这篇文章所要探讨的classification loss(比如交叉熵)和feature-shaping loss(比如对比学习中常用的各种对间损失)。

Pairwise losses and the generative view of the MI

文章的这一节重点从上文所说的统一性的观点回顾了一些度量学习中的对间损失函数。

The example of contrastive loss

Contrastive Loss的形式如下:

Metric Learning: CE vs PL - 图14

其中的Metric Learning: CE vs PL - 图15是欧氏距离,Metric Learning: CE vs PL - 图16是边界值margin,Metric Learning: CE vs PL - 图17是Hinge Loss。整个损失函数分为两部分,分别为左侧的紧致部分和右侧的对比部分。左侧紧致部分使得同类型的数据形成更加紧致的聚类,右侧对比部使得不同类别的数据在嵌入特征空间中彼此分离。

在公式2中的Metric Learning: CE vs PL - 图18可以被等价转化成Center Loss:

Metric Learning: CE vs PL - 图19

其中的Metric Learning: CE vs PL - 图20,是类别Metric Learning: CE vs PL - 图21在嵌入空间Metric Learning: CE vs PL - 图22中的特征数据的均值,因此该式即为衡量了所有的类别的样本在嵌入空间中的实际特征值与该类别的平均特征值之间的距离的均数,这其实可以转换成一种交叉熵的形式:

Metric Learning: CE vs PL - 图23%3DH(%5Chat%7BZ%7D%7CY)%2BD%7BKL%7D(%5Chat%7BZ%7D%7C%7C%5Coverline%7BZ%7D%7CY)%5Ctag%7B4%7D%0A#card=math&code=T%7Bcontrast%7D%5Coverset%7Bc%7D%7B%3D%7DH%28%5Chat%7BZ%7D%3B%5Coverline%7BZ%7D%7CY%29%3DH%28%5Chat%7BZ%7D%7CY%29%2BD_%7BKL%7D%28%5Chat%7BZ%7D%7C%7C%5Coverline%7BZ%7D%7CY%29%5Ctag%7B4%7D%0A&id=yiERR)

也就是样本数据经过神经网络的卷积层的后获得嵌入的估计Metric Learning: CE vs PL - 图24和在样本标签Metric Learning: CE vs PL - 图25已知的条件下的嵌入空间中的特征的均值交叉熵,其中Metric Learning: CE vs PL - 图26#card=math&code=c_Y%3A%5Coverline%7BZ%7D%7CY%20%5Csim%5Cmathbb%7BN%7D%28c_Y%2CI%29&id=eF8tM),即随机变量Metric Learning: CE vs PL - 图27 相对于Metric Learning: CE vs PL - 图28的分布为一个中心为Metric Learning: CE vs PL - 图29的高斯分布。

因此可以将对比项Metric Learning: CE vs PL - 图30看成是一个条件熵的上界:

Metric Learning: CE vs PL - 图31%20%5Ctag%7B5%7D%0A#card=math&code=T_%7Bcontrast%7D%5Cgeq%20H%28%5Chat%7BZ%7D%7CY%29%20%5Ctag%7B5%7D%0A&id=OHCLR)

在满足条件Metric Learning: CE vs PL - 图32#card=math&code=%5Coverline%7BZ%7D%7CY%20%5Csim%5Cmathbb%7BN%7D%28c_Y%2CI%29&id=GUgAv)时可以将最小化紧致项近四成最小化Metric Learning: CE vs PL - 图33#card=math&code=H%28%5Chat%7BZ%7D%7CY%29&id=BB90p),这也就是鼓励神经网络去学习一个可以使得在样本的嵌入特征与实际标签之间的交叉熵最小化的模型。

由于上述的紧致项将会在求解过程中陷入一个显然的、所有数据均在一点上的局部最优,因此需要额外添加一个约束项也就是对比项。类似地:

Metric Learning: CE vs PL - 图34

以及:

Metric Learning: CE vs PL - 图35%3D%5Cfrac%7Bd%7D%7Bn(n-1)%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3D1%7D%5En%5Clog%20D%7Bij%7D%5E2%5Coverset%7Bc%7D%7B%3D%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3D1%7D%5En%5Clog%20D%7Bij%7D%20%5Ctag%7B7%7D%0A#card=math&code=%5Chat%7BH%7D%28%5Chat%7BZ%7D%29%3D%5Cfrac%7Bd%7D%7Bn%28n-1%29%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3D1%7D%5En%5Clog%20D%7Bij%7D%5E2%5Coverset%7Bc%7D%7B%3D%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3D1%7D%5En%5Clog%20D%7Bij%7D%20%5Ctag%7B7%7D%0A&id=OSjcX)

因此第二项对比项也可以看作是一个交叉熵的估计。

基于上述工作,Contrastive Loss可以被看作以下形式:

Metric Learning: CE vs PL - 图36%7D%7B%5Cpropto%20H(%5Chat%7BZ%7D%7CY)%7D-%5Cunderbrace%7B%5Cfrac%7B2m%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3D1%7D%5EnD%7Bij%7D%7D%20%7B%5Cpropto%20H(%5Chat%7BZ%7D)%7D%5Cquad%20%5Cquad%20%5Cpropto-%5Cmathbb%7BI%7D(%5Chat%7BZ%7D%3BY)%5Ctag%7B8%7D%0A#card=math&code=L%7Bcontrast%7D%3D%5Cunderbrace%7B%5Cfrac%7B1%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3Ayj%3Dy_i%7D%5En%28D%7Bij%7D%5E2%2B2mD%7Bij%7D%29%7D%7B%5Cpropto%20H%28%5Chat%7BZ%7D%7CY%29%7D-%5Cunderbrace%7B%5Cfrac%7B2m%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3D1%7D%5EnD%7Bij%7D%7D%20%7B%5Cpropto%20H%28%5Chat%7BZ%7D%29%7D%5Cquad%20%5Cquad%20%5Cpropto-%5Cmathbb%7BI%7D%28%5Chat%7BZ%7D%3BY%29%5Ctag%7B8%7D%0A&id=jMLc7)

针对其他类型的对间损失也有类似的转换方式。

由各种对间损失之间的关系可以推算获得引理1:

  • Lemma 1
    Metric Learning: CE vs PL - 图37代表损失对间损失函数中的紧致项,并假设所有的特征都经过了L2正则化,并且类别已经经过平衡,则针对Center Loss,Contrastive Loss,SNCA Loss,和MS Loss有如下关系: Metric Learning: CE vs PL - 图38

Cross-ectropy does it all

在这个章节,作者继续从两个角度,也就是第一节和第二节,揭示对间损失与交叉熵之间的关系。

The pairwise loss behind unary cross-entropy

当给出一个难以优化的函数时,边界优化就是一种可以通过优化一个辅助函数(比如该函数的上界)来迭代地运算。令Metric Learning: CE vs PL - 图39为当前的迭代轮次,Metric Learning: CE vs PL - 图40是一个辅助函数当:

Metric Learning: CE vs PL - 图41%5Cleq%20a_t(W)%2C%5Cquad%20%5Cforall%20W%5C%5Cf(W_t)%3Da_t(W_t)%20%5Ctag%7B10%7D%0A#card=math&code=f%28W%29%5Cleq%20a_t%28W%29%2C%5Cquad%20%5Cforall%20W%5C%5Cf%28W_t%29%3Da_t%28W_t%29%20%5Ctag%7B10%7D%0A&id=nCJ8Y)

基于边界辅助函数Metric Learning: CE vs PL - 图42的优化分为两步,首先计算Metric Learning: CE vs PL - 图43,并最小化Metric Learning: CE vs PL - 图44

Metric Learning: CE vs PL - 图45

这是为了保证原函数亦减少:

Metric Learning: CE vs PL - 图46%5Cleq%20at(W%7Bt%2B1%7D)%5Cleq%20at(W_t)%3Df(W_T)%20%5Ctag%7B12%7D%0A#card=math&code=f%28W%7Bt%2B1%7D%29%5Cleq%20at%28W%7Bt%2B1%7D%29%5Cleq%20a_t%28W_t%29%3Df%28W_T%29%20%5Ctag%7B12%7D%0A&id=T7t2M)

CCCP,EM和SSP算法都可以用来解决这个问题。

基于上述理论基础,作者认为优化编码器参数Metric Learning: CE vs PL - 图47与分类器的权重Metric Learning: CE vs PL - 图48可以被近似为优化一个对间交叉熵损失函数(PCE Loss),其定义如下:

Metric Learning: CE vs PL - 图49

这里,Metric Learning: CE vs PL - 图50是类别Metric Learning: CE vs PL - 图51的嵌入特征的加权平均,其中的Metric Learning: CE vs PL - 图52是数据点Metric Learning: CE vs PL - 图53属是否属于第Metric Learning: CE vs PL - 图54类的softmax概率。Metric Learning: CE vs PL - 图55且依赖于编码器Metric Learning: CE vs PL - 图56,需要在每次迭代时计算一个矩阵的特征值。

获得这一PCE损失的思路是通过将交叉熵损失函数进行一定程度的变形以分为两部分:

Metric Learning: CE vs PL - 图57%7D%2B%5Cunderbrace%7B%5Cfrac%7B1%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Clog%20%5Csum%7Bk%3D1%7D%5EKe%5E%7B%5Ctheta%5ETkz_i%7D-%5Cfrac%7B%5Clambda%7D%7B2%7D%5Csum%7Bk%3D1%7D%5EK%5Ctheta%5ETk%5Ctheta_k%7D%7Bf2(%5Ctheta)%7D%20%5Ctag%7B14%7D%0A#card=math&code=L%7BCE%7D%3D%5Cunderbrace%7B-%5Cfrac%7B1%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Ctheta%5ET%7Byi%7Dz_i%2B%5Cfrac%7B%5Clambda%7D%7B2%7D%5Csum%7Bk%3D1%7D%5EK%5Ctheta%5ETk%5Ctheta_k%7D%7Bf1%28%5Ctheta%29%7D%2B%5Cunderbrace%7B%5Cfrac%7B1%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Clog%20%5Csum%7Bk%3D1%7D%5EKe%5E%7B%5Ctheta%5ET_kz_i%7D-%5Cfrac%7B%5Clambda%7D%7B2%7D%5Csum%7Bk%3D1%7D%5EK%5Ctheta%5ETk%5Ctheta_k%7D%7Bf_2%28%5Ctheta%29%7D%20%5Ctag%7B14%7D%0A&id=LqVGO)

并选择合适的Metric Learning: CE vs PL - 图58 来使得两部分均为凸函数,且插入一些均值,如Metric Learning: CE vs PL - 图59与加权均值Metric Learning: CE vs PL - 图60作为Metric Learning: CE vs PL - 图61的代替来得到函数的一个上界估计。因此,寻找CE Loss的最优Metric Learning: CE vs PL - 图62将被解释为在PCE损失上构造一个辅助函数Metric Learning: CE vs PL - 图63%3DL%7BCE%7D(W%2C%5Ctheta%5E*)#card=math&code=a_t%28W%29%3DL%7BCE%7D%28W%2C%5Ctheta%5E%2A%29&id=GtOVi)。

为了避免计算过程复杂的Metric Learning: CE vs PL - 图64,作者进一步优化了PCE Loss,演化出一个简化版本的PCE:SPCE Loss

Metric Learning: CE vs PL - 图65%7D%7BCONTRASTIVE%7D%20%5Ctag%7B15%7D%0A#card=math&code=L%7BSPCE%7D%3D%5Cunderbrace%7B-%5Cfrac%7B1%7D%7Bn%5E2%7D%5Csum%7Bi%3D1%7D%5En%5Csum%7Bj%3Ayj%3Dy_i%7Dz%5ET_iz_j%7D%7BTIGHTNESS%7D%2B%5Cunderbrace%7B%5Cfrac%7B1%7D%7Bn%7D%5Csum%7Bi%3D1%7D%5En%5Clog%5Csum%7Bk%3D1%7D%5EKexp%28%5Cfrac%7B1%7D%7Bn%7D%5Csum%7Bi%3Ay_j%3Dk%7Dz%5ET_iz_j%29%7D%7BCONTRASTIVE%7D%20%5Ctag%7B15%7D%0A&id=gwilh)

SPCE与PCE的区别在于SPCE对比项中使用加权平均替换了PCE之中的平均数,从而大大减少计算量。

A discriminative view of mutual information

首先作者给出引理2:

  • Lemma 2
    最小化条件交叉熵Metric Learning: CE vs PL - 图66#card=math&code=H%28Y%3B%5Chat%7BY%7D%7C%5Chat%7BZ%7D%29&id=mYFFE)相当于最大化互信息Metric Learning: CE vs PL - 图67#card=math&code=%5Cmathbb%7BI%7D%28%5Chat%7BZ%7D%3BY%29&id=Yy6dm)。

这个信息论论证强化了从命题 1 得出的结论,即交叉熵和前文描述的对间损失本质上是在做同样的工作。

Then why would cross-entropy work better?

在文中的实验描述部分,交叉熵损失表现比绝大多数对间损失都要好,因此作者在这里给出了他们的分析。

作者认为,交叉熵损失的的优化过程更加优秀。一方面,对间损失需要仔细的样本挖掘和加权策略来获得信息含量最多的对,尤其是在考虑小批次时,以便在合理的时间内使用合理的内存量实现收敛。 另一方面,优化交叉熵要容易得多,因为它只意味着最小化一个一元项。 本质上,交叉熵可以在不处理成对项的困难的情况下完成这一切。 它不仅使优化更容易,而且简化了实现,从而增加了它在现实世界问题中的潜在适用性。