今天在推特上看到了 consistency models 代码的开源

一步生成的扩散模型:Consistency Models - 知乎 - 图1

这篇工作的论文在今年3月放出,那时候还没有开源代码,但是其展现的潜力让人印象深刻。因为 Diffusion Models 在生成一张图片时需要多次进行模型推理,对于实时性较强的应用,就很难让人满意了。虽然也有后续一些采样相关的工作比如 DDIM,DPM-Solver,Uni-PC 将推理步数缩减到10步左右,但是还没有像这篇文章所claim的一步采样即能达到较好的效果。

笔者和大家一起来看看这篇文章。

Arxiv: https://arxiv.org/pdf/2303.01469.pdf

Code: https://github.com/openai/consistency_models

背景:Diffusion Models 的 SDE 形式和 ODE 形式

首先回顾一下diffusion的算法原理,假设我们有数据分布 pdata(x)p_{data}(x) , 扩散模型通过如下的 SDE 对数据分布进行

一步生成的扩散模型:Consistency Models - 知乎 - 图2

stochastic differential equation

songyang 推导出,上述 SDE 存在一个 ODE 形式的解轨迹

一步生成的扩散模型:Consistency Models - 知乎 - 图3

ODE trajectory

其中 \nabla \log p_t(x) 是 p_t(x) 的得分函数。得分函数是diffusion model 的直接或者间接学习目标。

采用 EDM 中的 setting,设置 \mu (x, t) = 0 , \sigma(t) = \sqrt{2t} , 训练一个得分模型 s_ {\phi}(x, t) \approx \nabla \log p_t(x)

上述 ODE 转为

一步生成的扩散模型:Consistency Models - 知乎 - 图4

得到 ode 的具体形式后,利用现有的数值 ODE solver,如 Euler, Heun, Lms 等,即可解出 x(.). 考虑到数值精确性,我们往往不会直接求出原图即 x(0) ,而是计算出一个 x(t-\delta_t) ,持续这个过程来求出 x(0) 。

Consistency Models 的概念

回顾一下 diffusion 的采样过程,从先验分布 (x_{t_N}, t_N) 出发,推导采样过程

(x{t_N}, t_N) \rightarrow (x{t{N-1}}, t{N-1})\rightarrow … \rightarrow (x_{t_0}, t_0)

Consistency Models 假设存在一个函数 f ,对于上述过程中的每个点,f都能输出一个相同的值

一步生成的扩散模型:Consistency Models - 知乎 - 图5

并且对于轨迹的起点 x_0 = \epsilon ,我们有

一步生成的扩散模型:Consistency Models - 知乎 - 图6

那么对于轨迹中任意一点,我们代入先验分布, 即可得到 f(xT, T) = x{\epsilon} 。这样也就完成了一步采样。

自然想到训练一个神经网络来拟合 f,但是这里要满足两个条件,一个是轨迹上的点输出值一致,一个是在起始时间点 f 为一个对于x的恒等函数。

作者做了如下的设计,巧妙的实现了上述目标

一步生成的扩散模型:Consistency Models - 知乎 - 图7

其中 c{skip} 和 c{out}为可微函数,满足 c{skip}(\epsilon) = 1, c{out}(\epsilon) = 0. F_{\theta} 为深度神经网络,输出维度同 x .

这样,第二个条件自然满足,因为有

一步生成的扩散模型:Consistency Models - 知乎 - 图8

F_{\theta} 可使用一致性损失来学习。

观察 f 的性质,显然, f(xT, T) = x{\epsilon} 可以得到我们想要的生成结果。但一般认为,这样的生成误差会比较大。就像 DDPM 也可以通过预测噪声直接从 xt 预测 x_0 , 但我们会依赖 x_t, x_0 预测 x{t-1} ,依次向下采样来获得 x_0 来减小误差。

同样的,我们每次从 x{\tau_n} 预测出初始的 x 后,回退一步来预测 x{\tau_{n-1}} 来减小误差。因此有如下多步采样的算法。

一步生成的扩散模型:Consistency Models - 知乎 - 图9

Consistency Models 的训练方式

考虑到 Consistency Models 的性质,对采样轨迹上的不同点,f应该有一个相同的输出,自然我们需要找到采样的轨迹。

如果我们有预训练的diffusion model,自然可以构建轨迹。

假设采样轨迹的时间序列为

\epsilon = t_1<t_2<…<t_N = T

ODE sampler 的采样可以 formulate 成

一步生成的扩散模型:Consistency Models - 知乎 - 图10

其中 \Phi( x{t{n+1}}, t_{n+1}; \phi) 为 ODE solver。

使用 Euler Solver, 代入求解对象式(3),上式转为

一步生成的扩散模型:Consistency Models - 知乎 - 图11

当然也可以采用其他的 Solver,一般来说阶数越高的 Solver 求解精度越高。

对于处于同一轨道的 (x{t{n+1}}, t{n+1}), (\hat{x}^{\phi}{t{n}}, t{n}) ,f应该有相同的输出,我们用距离函数 d 来衡量输出是否相同。

因而有如下训练损失

一步生成的扩散模型:Consistency Models - 知乎 - 图12

其中 \lambda 用来对不同时间步赋予不同重要性, \theta^{-} 为 EMA 版本的权重。

综合上述过程,蒸馏(Consistency Distillation)的算法为

一步生成的扩散模型:Consistency Models - 知乎 - 图13

在蒸馏的过程中,我们实际上用预训练模型来估计得分 \nabla \log p_t(x) 。 如果从头训练,需要找一个不依赖于预训练模型的估计方法。

作者在论文中证明了一种新得分函数的估计

一步生成的扩散模型:Consistency Models - 知乎 - 图14

利用该得分估计,作者为从 Isolation Training 构建了一个新的训练损失,

一步生成的扩散模型:Consistency Models - 知乎 - 图15

并且证明了该 Loss 和 Distillation Loss 在最大间隔趋于0时相等。即

一步生成的扩散模型:Consistency Models - 知乎 - 图16

从而可以利用上述 loss 训练一个 Consistency Models,并且不依赖于已有 Diffusion Model。具体算法如下

一步生成的扩散模型:Consistency Models - 知乎 - 图17

实验结果

作者分析了总时间步 N ,EMA decay rate \mu 随时间轮次变化的Scheduler,以及 ODE solver, 距离函数 d 的选取对结果的影响,这里不做介绍。不过有意思的是,实验结果表明,使用 lpips 距离的效果会超出 l_1, l_2 函数,并且这种选取也会提升对比的蒸馏方法 Progressive Distillization 的效果。

这让我想起了两篇论文,Projected GANs Converge FasterThe Role of ImageNet Classes in Fréchet Inception Distance, 笔者认为使用 lpips 作为距离函数某种程度 hack 了 FID,从而取得了更好的指标。当 NFE = 1 时,这种效应尤其明显.

一步生成的扩散模型:Consistency Models - 知乎 - 图18

笔者简单看了一眼实验结果,CD(consistency distillatio)和 Diffusion 的其他蒸馏工作相比,该方法在1,2步采样下,结果是最好的,但似乎仍然无法与当前最好的GAN方法相比。CT (consistency distillation)点数则相差比较大,在 ImageNet 上 NFE =1 的 fid 甚至超过了 10。

但笔者仍认为该方法具有很高的价值,目前的diffusion models,采样速度确实是一个痛点,而现在有大量基于diffusion的工作,如果能把现有工作如 stable diffusion 用该方法进行蒸馏,有助于满足高实时性的需求。