- 原始语雀文档:https://www.yuque.com/lart/gw5mta/vhbggb
- 论文:https://arxiv.org/abs/2205.02399
代码:https://github.com/zju-vipa/spot-adaptive-pytorch
内容摘要
知识蒸馏(KD)已成为压缩深神经网络的一种表现良好的范式。进行知识蒸馏的典型方法是在教师网络的监督下训练学生网络,在教师网络中以一个或多个位置(spots,即 layers)来利用(harness)知识。在整个蒸馏过程中,一旦针对所有训练样本指定了蒸馏位置就不会再更改。
在这项工作中,我们认为蒸馏位置应该适应训练样本和蒸馏周期。因此,我们提出了一种新的蒸馏策略,称为 spot-adaptive KD(SAKD),以在整个蒸馏期间的每次训练迭代中,针对每个样本而在教师网络中自适应决定蒸馏位置。由于 SAKD 实际上专注于“在哪里蒸馏”,而不是大多数现有作品广泛研究的“要蒸馏什么”,因此可以将其无缝集成到现有的蒸馏方法中以进一步提高其性能。
在 10 种最先进的蒸馏算法上的实验验证了在同质和异质(homogeneous and heterogeneous,这里主要强调学生网络和教师网络是否是相同风格的架构)蒸馏的设置下 SAKD 在改善其蒸馏性能方面的有效性。相关工作
知识蒸馏
将 DNN 部署到资源有限的边缘设备上仍然存在 困难。为了使 DNN 更适用于这些现实情况,知识蒸馏(KD)被用于针对 DNN 的构建轻巧的替代。其主典型思想是采用一种“教师-学生”的学习形式,这里具有竞争力的轻量级替代品(称为学生)是通过模仿行为良好但结构繁琐的“教师”DNN 的行为而产生的。通过利用教师模型学到的隐式知识(dark knowledge),轻量 de 学生模型被期望可以实现可比的性能,但参数量却要少得多。
KD 已经逐渐成为一种成熟的模型压缩范式,已经出现了大量的相关工作。
除了经典的“教师-学生”范式,也存在一些其他的形式:mutual distillation:一个学生模型的集成形式会协作学习并在整个训练过程中相互教学,使教师和学生模型能够共同发展。
- self distillation:不适用独立的教师网络,而是通过使用更深层的输出来监督较浅的层,从而来提炼学生网络本身的知识。
- knowledge amalgamation:旨在通过融合来自多个教师模型的知识,从而构建单个多任务形式的学生模型。
- data-free distillation:会放松以下假设:教师模型的训练数据可用于训练学生。即应在没有任何原始训练数据的情况下将知识迁移给学生。
在这项工作中,我们仍然遵循传统设定,即教师网络和原始训练数据都可以用于训练学生网络。但是我们认为,这个一般的想法也可以应用于各种 KD 设置,这将留给将来的工作。
基于蒸馏位置的差异,可以将他们大致归为两类:
- one-spot 蒸馏:KD 仅发生在单一位置,典型的形式是在 logit 层。
- Hinton 等人关于蒸馏的工作中提出最大程度地减少教师的概率输出与学生网络输出之间的差异。
- 对比表示蒸馏(CRD)采用对比学习方法来提炼结构知识,即表征层不同输出维度之间的相互依赖关系。
- 关系知识蒸馏(RKD)将数据样本的相互关系从教师模型转移到学生模型中,其中相互关系是在单个表示层中产生的。
- multi-spot 蒸馏:通过从教师网络中的多层挖掘知识,从而监督学生模型的学习。
- Fitnets 不仅用了输出,还用了教师学到的中间表征,作为训练学生的提示。
- 注意转移(AT)策略,通过强迫学生模仿强大的教师网络的不同层的注意力图来提高学生模型的性能。
- 激活边界(AB)通过不同层的隐神经元形成的蒸馏激活边界传递教师模型的知识。
由于 multi-spot 蒸馏方法与 one-spot 方法相比,从教师模型中使用的信息更多,因此通常认为它们可以表现出较好迁移效果。
现有的蒸馏方法,无论是 one-spot 还是 mutli-spot,都共享一个共同的特征:蒸馏点通常是一种手动设计的选择,无法优化,尤其是对于具有数百或数千层的网络。一方面,如果蒸馏点过于稀疏,则学生模型不受教师的充分监督。另一方面,如果蒸馏斑的设置过于密集,例如,每个可能的层或神经元都被用上,学生模型的学习可能会过度正则化,这也会导致蒸馏性能恶化。
此外,当前方法采用了全局蒸馏策略,即蒸馏点一旦确定,就对于所有样本都是固定的。基本的假设是,这种蒸馏点对于整个数据分布都是最佳的,这在许多情况下是不正确的。理想情况下,我们希望可以为每个样本以及每个可能的位置自动确定蒸馏点。
这项工作提出了一种新的蒸馏策略 spot-adaptive KD(SAKD)。这使蒸馏点自适应与训练样本和蒸馏阶段。
为此,我们首先将学生模型和教师模型合并到一个多路径路由网络中,如图 2 所示。结构对数据流提供了多种可行的路径。当数据到达网络分支点时,一个轻巧的策略网络(policy network)会为每个样本选择最优的传播路径。
- 如果数据被路由教师模型的层,则表明学生模型中的对应层(缩写为学生层)无法替换教师模型中的层(缩写为教师层)。因此,这些老师层中的知识应被蒸馏到相应的学生层。
- 如果数据被路由到学生层,则表明这些学生层是相应的教师层的良好替代品,从而产生了更好的或至少可比的表现。这些层中不允许蒸馏。
由于策略网络是在路由网络之上设计的,并通过路由网络同时进行优化,因此在学生模型的训练迭代中它可以自动确定每个样本的最佳蒸馏点。
由此可以看出来,所提出的方法着重于“在哪里提炼”,这与当前工作不同且相互正交,他们主要研究了“什么要蒸馏”,即要蒸馏的知识的形式。因此,提出的方法可以与现有方法无缝结合,以进一步增强蒸馏性能。具体而言,所提出的方法自然与同质蒸馏兼容,其中学生模型与教师模型相同。但是,实验表明,在异质蒸馏设置下,所提出的方法也可以很好地发挥作用,其中学生模型与教师模型架构差异很大。此外,尽管所提出的方法主要用于 multi-spot 蒸馏,但它也可以通过动态确定每个训练样本的蒸馏来提高 one-spot 蒸馏的性能。
总体而言,主要的贡献为:
- 第一次引入了自适应蒸馏问题,其中蒸馏点应该自适应于不同的训练样本和蒸馏阶段。
- 针对这一问题,提出了一个 spot-adaptive 蒸馏策略,可以自动的确定蒸馏位置,使得蒸馏位置可以自适应于训练样本和时期。
- 实验结果证明了提出的方法对于提升现有的蒸馏策略的有效性。
路由网络(Routing Networks)
路由网络是一种高模块化的神经网络,这是鼓励任务分解、降低模型复杂性和改善模型泛化能力所需的关键属性。路由网络通常由两个可训练的组件组成:一组函数模块(function settings)和一个策略代理(policy agent)。
- 函数模块:在神经网络设置中,函数模块由子网络实现,并被用作处理输入数据的候选模块。
- 策略代理:对于每个样本,策略代理从这些候选者中选择一个函数模块的子集,将它们组装成一个完整的模型,并将组装模型应用于输入数据以进行任务预测。数种算法已经被提出用于优化策略模块,包括遗传算法,多代理强化学习,重参数化策略等。
路由网络与数种结构密切相关,例如条件计算、专家混合模型,以及他们基于现代注意力和稀疏结构的变体。它们已成功应用于多任务学习,迁移学习和语言模型等多个领域。在这项工作中,借助路由网络,我们提出了一种新型的蒸馏策略,以自动确定网络中的蒸馏位置。
位置自适应的知识蒸馏(SPOT-ADAPTIVE KNOWLEDGE DISTILLATION)
整个模型由两个主要组成部分组成:多路径路由网络和一个轻量级策略网络。
- 多路径路由网络由教师模型和学生模型组成,并具有适应层,以便在必要时彼此适应特征。
- 当数据到达路由网络中的分支点时,策略网络用于在数据流路径上的每个样本中做出路由决策。
所提出的蒸馏方法的一般思想是自动确定是否要在候选蒸馏点进行蒸馏,如图 2 所示。如果样本被策略网络路由到某些教师层,则表明对应的学生层不能替换这些老师层。因此,这些老师层中的知识应被蒸馏成相应的学生层。如果数据通过策略网络传输到某些学生层,则表明这些学生层是相应的教师层的良好替代品,可以产生优于或至少可比的表现。这些位置就不再需要蒸馏。
蒸馏的最终目标是使策略网络逐渐为路由数据而选择学生层,这意味着学生模型是教师网络的良好替代品。
多路径路由网络
不失一般性,假设用于视觉分类的卷积神经网络(CNN)由几个用于表征学习的卷积块、一个用于矢量化特征图的全连接的层以及用于做出概率预测的 softmax 层组成。每个卷积块由几个卷积层组成,每个卷积层随后是非线性激活层和批归一化层。一般而言,在每个块之后,特征图会被池化层或卷积层缩小了 2 个或更多。
由此,教师网络和学生网络可以大致表示为数个卷积块、一个线性层和一个 softmax 层的级联组合。
多路路由网络由教师和学生网络组成,其中间层相互关联。然而由于他们各个层之间存在维度不匹配的问题,所以本文也引入了用于对齐特征维度的 1x1 卷积适应层。由此,多路径路由网络同样可以看做是使用多个卷积层、一个线性层和 softmax 层的级联结构,但是不同于单一网络,这里的卷积层和线性层都是教师网络与学生网络对应结构的加权融合的结果(这里会用到适应层来对齐特征)。而用于融合所使用的权重就来自于策略网络,其取值范围为 0 到 1。当特征融合权重取离散值的情况时,网络实际上就成为了教师网络部分层和学生网络部分层的组合体。
使用路由网络,最终的目的是为了获得一个独立的学生模型,其可以对感兴趣任务表现的尽可能好。
策略网络
我们采用策略网络来为通过路由网络的数据流路径上的每个样本做出决策。在这里,我们只采用轻量的全连接层来实现策略网络。
- 其输入是拼接的教师和学生模型。
- 其输出是 N+1 个 2 维路由向量,这里的 N+1 表示分支点的数量,即候选蒸馏点的数量。每个路由向量是一个概率分布,我们从中绘制一个分类值(categorical value),以为路由网络中分支点上的数据流路径做出决定。
采样操作对于离散情况是不可微的。为了确保采样操作的可微分性,这里使用 Gumbel-Softmax 技术来实现策略网络。形式上来讲,对于第 i 个分支点,对应的路由向量是一个二维向量,这里第一个元素存储着表示第 i 个块中的教师网络层有多可能会被用于处理到来的数据。
前向传播中,该策略根据基于如下分布的分类分布中得出离散决策:
这里的 w 就是一个二维的 one-hot 向量,而“one_hot”函数表示返回 one-hot 向量的函数。最右侧量是一个二维向量,其中的元素都为从 Gumbel 分布中绘制的 i.i.d 样本,用于添加少量噪声,以避免 argmax 操作始终选择具有最高概率值的元素。
为了确保离散采样函数的可微分性,这里使用了 Gumbel-Softmax 技巧来在反向传播期间放松 w。
这里的 τ 是温度参数,用于近似后的分布的锐利程度。注意对于 w 中的每个向量,其中包含的两个元素之和都为 1。
Spot-adaptive Distillation
提出的位置自适应蒸馏通过同时训练路由网络和策略网络而构造。从策略网络和路由网络的角度来看,训练所提出的网络是非平稳的,因为最佳路由策略取决于模块参数,反之亦然。
在这项工作中,多路路由网络和策略网络以端到端的方式同时训练。
完整的目标函数包括四部分:
- 使用真值对学生模型监督的交叉熵损失。
- 使用教师模型预测对学生模型监督的 KL 散度,这与 Hinton 工作中提出的普通的蒸馏损失一致。使用系数 β1 加权。
- 现有的基于中间层特征的知识蒸馏损失。例如 FitNets、Attention Transfer 等工作中提出的损失。使用系数 β2 加权。
- 使用真值监督来自路由网络的预测的交叉熵形式的路由损失。使用 β3 加权。这里说的似乎让人有些迷惑,其实作者提供的伪代码可以提供一个直观的理解。路由网络可以被认为是独立于学生网络和教师网络进行前向传播的,它的各个子结构就是对教师网络和学生网络各部分的加权组合。
在整个训练阶段,教师模型的预训练参数被保持固定。可训练的参数仅包括学生模型的参数,适应层和策略网络。策略网络和适应层仅参与仅计算路由损失,它们的参数仅在路由损失的监督下进行训练。
学生网络和策略网络形成一个环,学生模型的输出进入策略网络,策略网络的输出再次进入学生网络。为了稳定学生网络的训练,我们不再将反传策略网络的梯度到学生网络中。
在训练早期,由于教师模型已经训练得当,因此策略网络更有可能将样本传递给教师层。在这种情况下,知识蒸馏发生在所有候选蒸馏点。随着训练进行,学生模型在不同层次的不同程度上逐渐掌握了教师的知识。在这种情况下,策略网络可以为每个样本规划一条路径,其中教师层和学生层都交织在一起。因此,知识蒸馏是在某些层适应性进行的,以推动仅涉及学生层的最佳策略。
优化算法
为了使提出的方法更清晰,伪代码在算法 1 中提供。
- 给定两个深神经网络,一个学生 S 和老师 T。令 x 为网络输入。
- 我们将来自教师和学生模型的中间表征的集合表示为 featT 和 featS,最终预测表示为 logitT 和 logitS。
- 策略网络 P 的输入是教师和学生特征的拼接。策略模型 P 的输出为 N+1 个二维的路由向量,表示为 w,它们是前向传播期间做出的离散决策,并将在向后传播期间使用 Gumbel-Softmax 来放松约束。
- 这里一个明显的困难是,学生的蒸馏损失 Ls 取决于路由决策 w,因此与策略网络一起优化学生模型是有问题的。我们通过截止梯度操作来避免这种困难。这意味着 d 在损失中被视为常数。
- 学生模型的完整目标函数显示在第 26 行〜28 行中,其中包括交叉熵损失、KL 散度和知识蒸馏损失。
- 之后开始多路路由网络的传播。这包含教师模型和学生模型,中间层相互关联。通过将学生模型设置为 eval 模式(避免 BN 和 dropout 的反复变动),来让路由网络更加稳定的工作,在获得最终预测之后,再将其恢复为 train 模式。
- 为了对齐教师与学生之间的特征,引用了适应层 Hst 和 Hts。
- 交叉熵损失最终被用来优化策略模块和适应模块的参数。
实验
实验设置
在实验当前方法与其他方法的组合效果时,所有方法都会与普通的 KD 损失(从老师和学生模型中的软化预测之间的差异)结合在一起,以提高其性能。因此,所有方法均至少涉及两个蒸馏点,无论它们最初是单点或多点蒸馏方法,都将变成多点版本。而且这些方法在本文的使用中,会在训练前确定不同的蒸馏点,并在整个蒸馏过程中保持相同。如果一种方法采用块 i(1≤i≤N+1)的知识,则将其蒸馏点称为 i。
软化预测分布的温度值设置为 4。gumbel-softmax 中的 τ 最初设置为 5,并在训练期间逐渐衰减,因此网络可以在早期阶段自由探索,并在后期利用收敛后的蒸馏策略。
为简单起见,超参数 β1 和 β3 被设置为 1。β2 是根据蒸馏方法设置的。我们在用 CRD 原始论文中的参数对大多数蒸馏方法设置 β2。除了 FitNets,为了更稳定的训练,所以 β2 设置为 1 而不是 1000。 β2 的详细设置如表 II 所示。与现有方法的对比
同质蒸馏范式
注意所有的对比方法都会搭配原始的 KD 策略,而且在我们的方案中,候选蒸馏点包括 softmax 层和一些中间层。对于中间层,不同方案会利用不同数量的的中间层,提出的自适应蒸馏策略仅确定是否在这些蒸馏点处进行蒸馏。它不会在标准蒸馏方法中添加任何其他候选蒸馏点。异质蒸馏范式
与同质蒸馏相似,候选蒸馏点包括 softmax 层和中间层。softmax 层始终是自适应方案中的候选蒸馏点。ImageNet 上的验证
作者们也在 ImageNet 这样的大型数据集上验证了提出的策略的扩展性。消融实验
策略网络是否可以提供有效的决策?
我们验证了策略网络做出决策的实用性。为此,我们介绍了四种基线蒸馏策略:
- always-distillation:每个蒸馏点上始终进行蒸馏的标准蒸馏策略。
- rand-distillation:随机决定是否在候选蒸馏点上进行蒸馏。
- anti-distillation:采用与提议的自适应蒸馏相反的蒸馏策略。如果在某个点进行自适应蒸馏策略会进行蒸馏,则该策略就不会蒸馏;否则,它会在此位置蒸馏。
- no-distillation:学生在没有任何蒸馏的情况下进行了轻率(trivially)的训练。
可以看出,所提出的自适应蒸馏始终优于其他基准,包括竞争性的 always-distillation。尽管对于某些蒸馏方法的改进有时微不足道,但几乎所有蒸馏的一致改进都验证了提出的策略网络确实为蒸馏做出了有用的路由决策。此外,anti-distillation 通常比 adaptive-,adaptive-和 rand-distillation 的性能要差得多,有时甚至比没有 no-distillation 的情况还要差。这些结果表明,在不适当的地方进行蒸馏可能对训练学生有害。
决策随着位置改变会有如何的变化?
在这里,我们研究了策略网络在不同的蒸馏点和蒸馏阶段(即训练时期)做出的蒸馏决策。在每个候选蒸馏位置,蒸馏的可能性是在此点蒸馏的样本数量与训练样本总数的比例。图 3 中描述了沿训练时期的不同点的概率曲线。
- 早期阶段由于教师网络训练有素,最佳路由决策应在路由网络中的所有分支点上选择教师层。因此,在所有蒸馏点,蒸馏概率应接近 100%。但由于策略网络是随机初始化的,并且尚未经过良好训练,因此它的决策是随机的,因此蒸馏概率很低。
- 随着训练进行,策略网络逐渐学习如何做出正确的决策,并发现教师层往往会更好,因此蒸馏概率迅速增加。
- 经过一段时间的蒸馏,学生模型掌握了老师的知识。有些样本对于训练学生模型的有用程度降低,因此蒸馏概率降低(例如 KD_1)。
一般而言,浅层对自适应蒸馏更敏感。深层的话,几乎所有样本都一直都需要蒸馏,正如 KD_4 和 KL 的曲线所展示的那样。这种现象的原因可能是浅层的特征对于蒸馏而言相对嘈杂。由于学生模型的能力比老师小得多,因此从这些嘈杂的功能中学习会使其在最终目标任务上的表现降低。
教师网络应该冻结还是可训练?
教师网络在提出的方法中一直是被冻结的。这里放松了这一约束,并介绍了两个替代设置:
- 教师网络被随机初始化并与学生网络共同训练;
- 教师网络用预训练的参数初始化,并与学生网络共同训练。
可训练的教师网络提高了多路路由网络的能力,但可能会损害将要独立部署的学生模型的训练。表 VIII 提供了实验结果。可以看出,无论是从头开始还是从预训练的参数中训练教师网络,都会降低蒸馏性能,从而验证我们的假设。更糟糕的是,训练教师网络会减慢蒸馏过程,因为更新教师参数需要更多的计算。
对 β3 和 τ 的敏感性
提出的方法涉及几个超参数。但是,其中大多数是在以前的作品中引入的。我们沿用了这些文献中的设置。这个工作还引入了两个新的超参数,即 τ 和 β3。这里对齐进行敏感性分析,观察下他们的影响。
实验结果可以看到他们都会在某种程度上影响结果。但是,它们使所提出的自适应方法在广泛的值中获得了比标准蒸馏方法高的结果。这一特点使提出的方法更具潜力,因为我们没有过多调整参数。
蒸馏决策的可视化
为了更好地了解策略网络做出的决策,这里提供了 tiny-ImageNet 中十种类别上的决策的可视化。
可以看到,大多数要被蒸馏的图像,有着比不会被蒸馏的图像更好的质量。我们把没有知识蒸馏的样本分为四类:内容缺失,主题模棱两可,对象组,以及形态异常。这些分别由图中的红色,黄色,紫色和绿色方框表示。
- 内容缺失(红色)。由于极端靠近或无特点的视角(extreme close-ups or uncharacteristic views),这种类型数据仅捕获物体的一部分。在其他内容缺失的图像中,物体与背景无法区分。
- 主题模棱两可(黄色)。这些图像包含多个物体,无法识别哪个对象是图像的焦点。使用这些输入图像,模型很容易学习不会属于目标类别并最终导致错误的特征。
- 对象组(紫色)。单个对象的特写可以详细揭示其特征,而对象组仅能提供总体特征。
- 形态异常(绿色)。其中一些图像与数据集中的大多数图像不同,这些特别的图像不会被蒸馏。这些图像的稀有性使它们提供了与一般特征不兼容的特征。例如,我们可以看到蓝色的龙虾,毛茸茸的企鹅和粉红色头发的猪,这与数据集中这些目标的常见特性有冲突。
这些低质量的特征可能产生嘈杂的特征或预测,这可能会因为学生模型能力有限,而损害模型的学习。我们承认,这些未蒸馏的图像可以从另一个角度为模型提供信息,但是它们引入的噪音也值得考虑。通常,具有判别性的图像可以提供有用特征,因此图中显示的蒸馏决策是合理的,因此这些图像的知识将很好地指导学生。