笔者碎碎念:本文是因为在看SWAV的时候不懂什么是Sinkhorn而写,提醒自己不懂的地方太多。 限于笔者水平有限,或出现错误,还请看官怒斥。 2022/4/6

参考链接:
https://zhuanlan.zhihu.com/p/458312488
https://zhuanlan.zhihu.com/p/441197063


最优传输

下面给出一些定义:传输方案Sinkhorn算法简介与推导 - 图1,代价函数Sinkhorn算法简介与推导 - 图2
image.png
假设这里分布是一堆沙子,我们要将位于Sinkhorn算法简介与推导 - 图4处的沙子搬到Sinkhorn算法简介与推导 - 图5处,因此搬运过程Sinkhorn算法简介与推导 - 图6付出的代价为Sinkhorn算法简介与推导 - 图7。这里考虑的是一粒沙子,那么对于整堆沙子来说,记某一微元的沙粒总量为Sinkhorn算法简介与推导 - 图8Sinkhorn算法简介与推导 - 图9是红色沙堆的概率密度函数(PDF)
v2-ab82f0b7f8cdce6e6c1032f4beaf119a_720w.jpg
因此对于搬运方法Sinkhorn算法简介与推导 - 图11所需要的总代价为Sinkhorn算法简介与推导 - 图12
对于很小的Sinkhorn算法简介与推导 - 图13,总的代价可以写成积分形式Sinkhorn算法简介与推导 - 图14
所以,我们的任务是找到一个最优传输映射,使得总的代价最小,可以表示为:
Sinkhorn算法简介与推导 - 图15
Sinkhorn算法简介与推导 - 图16表示Sinkhorn算法简介与推导 - 图17的所有取值集合。

但是对于式1解的存在性、唯一性和正则性都很难直接计算,Kantorovich将其拓展为OT问题,Kantorovich的思路点Sinkhorn算法简介与推导 - 图18处的沙粒可以被分配到任意一个地方。

为此我们给出下面的定义:
Sinkhorn算法简介与推导 - 图19的联合概率密度函数Sinkhorn算法简介与推导 - 图20,将Sinkhorn算法简介与推导 - 图21视为Sinkhorn算法简介与推导 - 图22处沙堆搬运到Sinkhorn算法简介与推导 - 图23处沙的量。
因此可知,Sinkhorn算法简介与推导 - 图24分别是x处搬运前的沙的量,y处搬运后的量。
我们记Sinkhorn算法简介与推导 - 图25分别为Sinkhorn算法简介与推导 - 图26处和Sinkhorn算法简介与推导 - 图27处的沙堆分布,所以有Sinkhorn算法简介与推导 - 图28

记这样的联合概率分布的集合为Sinkhorn算法简介与推导 - 图29,称之为Sinkhorn算法简介与推导 - 图30的传输方案,Kantorovich考虑下面的OT问题:
Sinkhorn算法简介与推导 - 图31
可以看出式2是关于传输方案Sinkhorn算法简介与推导 - 图32的凸函数。

Sinkhorn算法概述

什么是Sinkhorn,它用来干嘛

Sinkhorn是一种OT(Optimal Transport)算法,你可以将其建模为两个分布Sinkhorn算法简介与推导 - 图33,将分布x变换为y的任务。Sinkhorn就是为了找到最优的传输方案(将Sinkhorn算法简介与推导 - 图34分布转换为Sinkhorn算法简介与推导 - 图35分布),使得消耗最少。

有兴趣的读者可以自行搜索Wasserstein距离

算法推导

问题定义

因为计算机只能处理离散的数据,我们分布离散为n点点集得到位置向量Sinkhorn算法简介与推导 - 图36,我们在第一节的提到的密度Sinkhorn算法简介与推导 - 图37归一化并用Sinkhorn算法简介与推导 - 图38重新表示为
Sinkhorn算法简介与推导 - 图39
Sinkhorn算法简介与推导 - 图40可以看做Sinkhorn算法简介与推导 - 图41处分布初始状态的量,Sinkhorn算法简介与推导 - 图42表示Sinkhorn算法简介与推导 - 图43处分布终止状态的量。

我们使用矩阵Sinkhorn算法简介与推导 - 图44表示搬运沙的消耗,比如Sinkhorn算法简介与推导 - 图45表示将Sinkhorn算法简介与推导 - 图46的单位量的沙搬到Sinkhorn算法简介与推导 - 图47处的消耗(cost)。

我们使用矩阵Sinkhorn算法简介与推导 - 图48表示将向量Sinkhorn算法简介与推导 - 图49表示的沙的分布搬到Sinkhorn算法简介与推导 - 图50表示的沙的分布的最优传输方案,比如Sinkhorn算法简介与推导 - 图51表示为将Sinkhorn算法简介与推导 - 图52处的沙子的Sinkhorn算法简介与推导 - 图53的量搬运到Sinkhorn算法简介与推导 - 图54处。
所以这里就有Sinkhorn算法简介与推导 - 图55,其中Sinkhorn算法简介与推导 - 图56

我们下面将满足条件的所有矩阵Sinkhorn算法简介与推导 - 图57矩阵的几个记为Sinkhorn算法简介与推导 - 图58,类似于式2,我们将满足分布的Sinkhorn算法简介与推导 - 图59的沙堆搬运至分布Sinkhorn算法简介与推导 - 图60的沙堆的最优传输方案写为下面最优化形式:
Sinkhorn算法简介与推导 - 图61
可以注意到式3是一个线性的最优化问题,但是可能解不唯一(搬运沙堆的方法不唯一)。

熵正则化

为了解决解不唯一的问题,我们使用熵正则化来选择一个唯一解。对于熵正则化后的问题,我们能使用比单纯形法描述更简单的Sinkhorn算法来求解。同时,Sinkhorn算法也能更适合GPU的并行计算。
下面,我们定义熵函数Sinkhorn算法简介与推导 - 图62,规定如果Sinkhorn算法简介与推导 - 图63有小于等于0的数,那么Sinkhorn算法简介与推导 - 图64
下面我们将式3近似为
Sinkhorn算法简介与推导 - 图65
可以证明,问题4的解是唯一的。并且,当Sinkhorn算法简介与推导 - 图66的时候,问题4的最优解Sinkhorn算法简介与推导 - 图67会收敛到问题3的解的集合中具有最大熵的解。

笔者小猜想:这里的熵最大可否理解为最优化的值最小?

喜闻乐见拉格朗日

那么就到了大家最喜欢的拉格朗日乘子法了。我们这里构造拉格朗日量,给定向量Sinkhorn算法简介与推导 - 图68Sinkhorn算法简介与推导 - 图69
Sinkhorn算法简介与推导 - 图70
求导:
Sinkhorn算法简介与推导 - 图71
化简得到最优解为
Sinkhorn算法简介与推导 - 图72
我们记Sinkhorn算法简介与推导 - 图73,于是式6就可以写为
Sinkhorn算法简介与推导 - 图74
其中Sinkhorn算法简介与推导 - 图75为对角元素为向量Sinkhorn算法简介与推导 - 图76对应值的矩阵,即Sinkhorn算法简介与推导 - 图77

迭代形式

由于Sinkhorn算法简介与推导 - 图78Sinkhorn算法简介与推导 - 图79满足Sinkhorn算法简介与推导 - 图80
改写为逐元素相乘的形式
Sinkhorn算法简介与推导 - 图81
这里的Sinkhorn算法简介与推导 - 图82表示为逐元素乘法。
为了求解式8,Sinkhorn使用迭代算法,首先初始化Sinkhorn算法简介与推导 - 图83,使用迭代式:
Sinkhorn算法简介与推导 - 图84

代码

当然聪明如你,推导完这些数学公式之后,一定对代码怎么写已经有了想法,下面给出一个例程。

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import scipy as scp
  4. N = 200
  5. t = np.arange(0, N)/N
  6. Gaussian = lambda t0, sigma: np.exp(-(t-t0)**2/(2*sigma**2))
  7. normalize = lambda p: p/np.sum(p)
  8. sigma = .06
  9. a = Gaussian(.25, sigma)
  10. b = Gaussian(.8, sigma) + 3 * Gaussian(.6, sigma) + Gaussian(0.4, sigma)
  11. vmin = .02
  12. a = normalize(a+np.max(a)*vmin)
  13. b = normalize(b+np.max(b)*vmin)
  14. plt.figure(figsize = (10,7))
  15. plt.subplot(2, 1, 1)
  16. plt.bar(t, a, width = 1/len(t), color = "darkblue")
  17. plt.subplot(2, 1, 2)
  18. plt.bar(t, b, width = 1/len(t), color = "darkblue")
  19. epsilon = (.03)**2
  20. [Y, X] = np.meshgrid(t,t)
  21. K = np.exp(-(X-Y)**2/epsilon)
  22. v = np.ones(N)
  23. niter = 4000
  24. Err_p = np.zeros(niter)
  25. Err_q = np.zeros(niter)
  26. for i in range(niter):
  27. u = a / (np.dot(K, v))
  28. r = v * (np.dot(K, u))
  29. Err_q[i] = np.linalg.norm(r - b, ord=1)
  30. v = b / (np.dot(K, u))
  31. s = u * (np.dot(K, v))
  32. Err_p[i] = np.linalg.norm(s - a, ord=1)
  33. plt.figure(figsize = (10, 7))
  34. plt.subplot(2, 1, 1)
  35. plt.title("$||P1 -a||_1$")
  36. plt.plot(np.log(np.asarray(Err_p)), linewidth=2)
  37. plt.subplot(2, 1, 2)
  38. plt.title("$||P^T 1 -b||_1$")
  39. plt.plot(np.log(np.asarray(Err_q)), linewidth=2)
  40. P = np.dot(np.dot(np.diag(u),K),np.diag(v))
  41. plt.figure(figsize=(5,5))
  42. plt.imshow(np.log(P+1e-5))
  43. plt.axis('off')
  44. plt.show()

输出如下:
image.png
image.png
image.png
当然也完全等价于这么写:

def sinkhorn(K,a,b,nither=4000):
    for _ in range(nither):  
        K *= (a/K.sum(1))[:, np.newaxis]
        K *= (b/K.sum(0))[np.newaxis, :]  
    return K