笔者碎碎念:本文是因为在看SWAV的时候不懂什么是Sinkhorn而写,提醒自己不懂的地方太多。 限于笔者水平有限,或出现错误,还请看官怒斥。 2022/4/6
参考链接:
https://zhuanlan.zhihu.com/p/458312488
https://zhuanlan.zhihu.com/p/441197063
最优传输
下面给出一些定义:传输方案,代价函数
假设这里分布是一堆沙子,我们要将位于处的沙子搬到处,因此搬运过程付出的代价为。这里考虑的是一粒沙子,那么对于整堆沙子来说,记某一微元的沙粒总量为,是红色沙堆的概率密度函数(PDF)
因此对于搬运方法所需要的总代价为
对于很小的,总的代价可以写成积分形式
所以,我们的任务是找到一个最优传输映射,使得总的代价最小,可以表示为:
表示的所有取值集合。
但是对于式1解的存在性、唯一性和正则性都很难直接计算,Kantorovich将其拓展为OT问题,Kantorovich的思路点处的沙粒可以被分配到任意一个地方。
为此我们给出下面的定义:
的联合概率密度函数,将视为处沙堆搬运到处沙的量。
因此可知,分别是x处搬运前的沙的量,y处搬运后的量。
我们记分别为处和处的沙堆分布,所以有
记这样的联合概率分布的集合为,称之为的传输方案,Kantorovich考虑下面的OT问题:
可以看出式2是关于传输方案的凸函数。
Sinkhorn算法概述
什么是Sinkhorn,它用来干嘛
Sinkhorn是一种OT(Optimal Transport)算法,你可以将其建模为两个分布,将分布x变换为y的任务。Sinkhorn就是为了找到最优的传输方案(将分布转换为分布),使得消耗最少。
有兴趣的读者可以自行搜索Wasserstein距离
算法推导
问题定义
因为计算机只能处理离散的数据,我们分布离散为n点点集得到位置向量,我们在第一节的提到的密度归一化并用重新表示为
可以看做处分布初始状态的量,表示处分布终止状态的量。
我们使用矩阵表示搬运沙的消耗,比如表示将的单位量的沙搬到处的消耗(cost)。
我们使用矩阵表示将向量表示的沙的分布搬到表示的沙的分布的最优传输方案,比如表示为将处的沙子的的量搬运到处。
所以这里就有,其中
我们下面将满足条件的所有矩阵矩阵的几个记为,类似于式2,我们将满足分布的的沙堆搬运至分布的沙堆的最优传输方案写为下面最优化形式:
可以注意到式3是一个线性的最优化问题,但是可能解不唯一(搬运沙堆的方法不唯一)。
熵正则化
为了解决解不唯一的问题,我们使用熵正则化来选择一个唯一解。对于熵正则化后的问题,我们能使用比单纯形法描述更简单的Sinkhorn算法来求解。同时,Sinkhorn算法也能更适合GPU的并行计算。
下面,我们定义熵函数,规定如果有小于等于0的数,那么。
下面我们将式3近似为
可以证明,问题4的解是唯一的。并且,当的时候,问题4的最优解会收敛到问题3的解的集合中具有最大熵的解。
笔者小猜想:这里的熵最大可否理解为最优化的值最小?
喜闻乐见拉格朗日
那么就到了大家最喜欢的拉格朗日乘子法了。我们这里构造拉格朗日量,给定向量和:
求导:
化简得到最优解为
我们记,于是式6就可以写为
其中为对角元素为向量对应值的矩阵,即
迭代形式
由于,满足
改写为逐元素相乘的形式
这里的表示为逐元素乘法。
为了求解式8,Sinkhorn使用迭代算法,首先初始化,使用迭代式:
代码
当然聪明如你,推导完这些数学公式之后,一定对代码怎么写已经有了想法,下面给出一个例程。
import numpy as np
import matplotlib.pyplot as plt
import scipy as scp
N = 200
t = np.arange(0, N)/N
Gaussian = lambda t0, sigma: np.exp(-(t-t0)**2/(2*sigma**2))
normalize = lambda p: p/np.sum(p)
sigma = .06
a = Gaussian(.25, sigma)
b = Gaussian(.8, sigma) + 3 * Gaussian(.6, sigma) + Gaussian(0.4, sigma)
vmin = .02
a = normalize(a+np.max(a)*vmin)
b = normalize(b+np.max(b)*vmin)
plt.figure(figsize = (10,7))
plt.subplot(2, 1, 1)
plt.bar(t, a, width = 1/len(t), color = "darkblue")
plt.subplot(2, 1, 2)
plt.bar(t, b, width = 1/len(t), color = "darkblue")
epsilon = (.03)**2
[Y, X] = np.meshgrid(t,t)
K = np.exp(-(X-Y)**2/epsilon)
v = np.ones(N)
niter = 4000
Err_p = np.zeros(niter)
Err_q = np.zeros(niter)
for i in range(niter):
u = a / (np.dot(K, v))
r = v * (np.dot(K, u))
Err_q[i] = np.linalg.norm(r - b, ord=1)
v = b / (np.dot(K, u))
s = u * (np.dot(K, v))
Err_p[i] = np.linalg.norm(s - a, ord=1)
plt.figure(figsize = (10, 7))
plt.subplot(2, 1, 1)
plt.title("$||P1 -a||_1$")
plt.plot(np.log(np.asarray(Err_p)), linewidth=2)
plt.subplot(2, 1, 2)
plt.title("$||P^T 1 -b||_1$")
plt.plot(np.log(np.asarray(Err_q)), linewidth=2)
P = np.dot(np.dot(np.diag(u),K),np.diag(v))
plt.figure(figsize=(5,5))
plt.imshow(np.log(P+1e-5))
plt.axis('off')
plt.show()
输出如下:
当然也完全等价于这么写:
def sinkhorn(K,a,b,nither=4000):
for _ in range(nither):
K *= (a/K.sum(1))[:, np.newaxis]
K *= (b/K.sum(0))[np.newaxis, :]
return K