代码:https://github.com/cdtrans/cdtrans
论文:cdtrans_cross_domain_transform.pdf
Abstract
无监督域适应(UDA)旨在将从标记的源域学到的知识转移到不同的未标记目标域。大多数现有的 UDA 方法都专注于使用基于卷积神经网络 (CNN) 的框架从域级别或类别级别学习域不变的特征表示。基于类别级别的 UDA 的一个基本问题是为目标域中的样本生成伪标签,这些伪标签通常对于准确的域对齐来说过于嘈杂,不可避免地会影响 UDA 的性能。随着 Transformer 在各种任务中的成功,我们发现 Transformer 中的交叉注意力对嘈杂的输入对具有鲁棒性,可以更好地进行特征对齐,因此在本文中,Transformer 被用于具有挑战性的 UDA 任务。具体来说,为了生成准确的输入对,我们设计了一种双向中心感知标记算法来为目标样本生成伪标签。除了伪标签之外,还提出了一个权重共享的三分支转换器框架,以分别将自注意力和交叉注意力应用于源/目标特征学习和源-目标域对齐。这种设计明确地强制框架同时学习有区别的域特定和域不变表示。所提出的方法被称为 CDTrans(跨域转换器),它提供了使用纯转换器解决方案解决 UDA 任务的首次尝试之一。�
Introduction
研究发现Transformer中的cross-attention 擅长对齐不同的分布,甚至来自不同的模态,如视觉到视觉,视觉到文本,文本转语音等。它们在一定程度上对伪标签中的噪音具有鲁棒性。
本文实验得出结论,即使标签对中有噪声,由于注意力机制,交叉注意力仍然可以很好的对齐两个分布。
为了获得更准确的伪标签,本文为目标域中的样本设计了一种双向中心感知标签算法。伪标签是基于跨域相似度矩阵产生,且涉及到中心感知匹配来加权矩阵并将噪声减弱到可容忍范围内。在伪标签帮助下,本文为UDA设计了跨域转换器(CDTrans)。它由三个权重共享转换器分支组成,其中两个分支分别用于源数据和目标数据,第三个是特征对齐分支,其输入来自源-目标对。 自注意力应用于源/目标转换器分支,交叉注意力参与特征对齐分支以进行域对齐。 这种设计明确地强制框架同时学习有区别的域特定和域不变表示。�
contributions
- 提出一个权重共享的三分支转换器框架,即CDTrans,利用其对嘈杂标签数据的鲁棒性和强大的特征对其能力进行准确的无监督域适应。
- 为了产生高质量的伪标签,提出了一种双向中心感知标签方法,它提高了 CDTrans 上下文中的最终性能�
Proposed Method
The cross attention in transformer
Preliminary
Vision Transformer (ViT) (Dosovitskiy et al.,2020) 在计算机视觉任务上取得了相当甚至更出色的性能。 ViT 中最重要的结构之一是自注意力模块(Vaswani 等人,2017)。�在 ViT 中,图像 被重新整形为一系列扁平的 2D patches ,其中(H,W)是原始图像的分辨率,C 是通道数,(P,P)是每个图像patches的分辨率 , 是得到的patches数量�。对于 self-attention,首先将patches投影到三个向量中,即查询 、键 和值 。 dk 和 dv 表示它们的尺寸。 输出计算为值的加权和,其中分配给每个值的权重由查询与相应键的兼容性函数计算。 N 个补丁作为自注意力模块的输入,该过程可以表述如下。 自注意力模块旨在强调输入图像的patches之间的关系�
交叉注意模块源自自我注意模块。 不同的是,cross-attention 的输入是一对图像,即 Is 和 It 。 它的查询和键/值分别来自 Is 和 It 的补丁。 交叉注意力模块可以计算如下�:
其中 是来自图像 Is 的 M 个 patches 的 queries。 和 是来自图像 It 的 N 个 patches 的 keys 和 values。交叉注意模块的输出长度M与查询的数量相同。 �对于每个输出,它是通过将 Vt 与注意力权重相乘来计算的,注意力权重来自于 Is 中相应查询与 It 中所有键的相似度。 因此,在 It 中的所有 patches 中,与 Is 的查询更相似的 patches 将拥有更大的权重,并且对输出的贡献更大。 换句话说,交叉注意力模块的输出设法根据它们相似的 patches 来聚合两个输入图像。�
到目前为止,许多研究人员已经将交叉注意力用于特征融合,尤其是在多模态任务中(Tsai et al.,2019; Li et al.,2019; Hu & Singh,2021; Li et al.,2021e)。 在这些作品中,交叉注意力模块的输入来自两种模式,例如 视觉到文本(Tsai et al.,2019; Hu & Singh,2021)、文本到语音(Li et al.,2019)和视觉到视觉(Li et al.,2021e)。 他们应用交叉注意力来聚合和对齐来自两种模式的信息。 鉴于其在特征对齐方面的强大功能,我们建议使用交叉注意力模块来解决无监督域适应问题�
Robustness to noise
如上所述,cross-attention 模块的输入是一对图像,通常来自两个域,cross-attention 模块旨在对齐这两个图像。 如果存在标签噪声,则训练数据中将存在误报对。 假阳性对中的图像会有不同的外观,强行对齐它们的特征将不可避免地损害训练并影响性能。 我们假设误报对中的不同补丁比相似补丁对性能的危害更大。 在交叉注意模块中,两个图像根据它们的补丁相似度进行对齐。 如图 1a 所示,交叉注意模块会为假阳性对中的不同块分配较低的权重。 从而在一定程度上削弱了不同补丁对最终性能的负面影响。�
为了进一步分析这个问题,我们精心设计了一个实验。具体来说,我们从 VisDA-2017 数据集 (Peng et al., 2017) 中的源域和目标域中随机抽取真阳性对作为训练数据。然后我们手动将真阳性对替换为随机假阳性对以增加噪声,并观察性能的变化,如图 1b 所示。 x 轴表示训练数据中误报对的比率,y 轴表示不同方法在 UDA 任务上的表现。红色曲线表示将pairs与cross-attention模块对齐的结果,而绿色曲线是没有cross-attention的结果,即直接用pair中对应源数据的标签训练目标数据。可以看出,红色曲线比绿色曲线实现了更好的性能,这意味着交叉注意力模块对噪声的鲁棒性。我们还提供了另一个基线,如图 1b 中的蓝色曲线所示,即从训练数据中移除假阳性对,并仅使用真阳性对训练交叉注意力。如果没有嘈杂的数据,这个基线可以被认为是我们方法的上限。我们可以看到红色曲线非常接近蓝色曲线,而且两者都比绿色曲线好很多。这进一步意味着交叉注意模块对嘈杂的输入对具有鲁棒性。�
Two-way center-aware pseudo labeling
Two-way labeling
为了构建交叉注意力模块的训练对,一种直观的方法是,对于源域中的每个图像,我们设法从目标域中找到最相似的图像。 所选对的集合 是�:
其中 S,T 分别是源数据和目标数据。�d(fi , fj ) 表示图像 i 和 j 的特征之间的距离。 这种策略的优点是可以充分利用源数据,而其缺点是只涉及到目标数据的一部分,这一点很明显。 为了从目标数据中消除这种训练偏差,我们从相反的方向引入更多对 ,包括所有目标数据及其在源域中对应的最相似的图像。�
因此最终的集合 P 是两个集合的并集,即 ,使得训练对包含所有源数据和目标数据�
Center-aware Filtering
P 中的对是基于来自两个域的图像的特征相似性构建的,因此对的伪标签的准确性高度依赖于特征相似性。 受 Liang et al. (2020) 的启发,我们发现源数据的预训练模型也有助于进一步提高准确性。 首先,我们通过预训练模型发送所有目标数据,并从分类器中获得它们在源类别上的概率分布 δ。 与 Liang et al.(2020) 类似,这些分布可用于通过加权 k 均值聚类计算目标域中每个类别的初始中心�
其中 表示图像 t 在类别 k 上的概率。 目标数据的伪标签可以通过最近邻分类器产生�
其中 和 d(i, j) 是特征 i 和 j 的距离。 基于伪标签,我们可以计算新的中心�:
在 Liang et al.(2020) 中,Eq.6 和 7 可以更新多轮,我们在论文中只采用一轮。 然后使用最终的伪标签来细化选定的对。 具体来说,对于每一对,如果目标图像的伪标签与源图像的标签一致,则保留这一对用于我们的训练,否则将作为噪声丢弃。�
CDTrans:Cross-domain transformer
所提出的跨域变换器(CDTrans)的框架如图 2 所示,它由三个权重共享变换器组成。 权重共享分支存在三个数据流和约束。�
框架的输入是从我们上面提到的标记方法中选择的对。 这三个分支分别命名为源分支、目标分支、源-目标分支。 如图 2 所示,输入对中的源图像和目标图像分别被发送到源分支和目标分支。 在这两个分支中,涉及到自我注意模块来学习特定领域的表示。 并且使用softmax交叉熵损失来训练分类。 值得注意的是,由于两个图像的标签相同,所有三个分支共享相同的分类器。�
交叉注意力模块被导入到源-目标分支中。 源-目标分支的输入来自其他两个分支。 在第 N 层,交叉注意力模块的查询来自源分支第 N 层的查询,而键和值来自目标分支的查询。 然后交叉注意模块输出对齐的特征,这些特征与第(N-1)层的输出相加�。
源-目标分支的特征不仅对齐了两个域的分布,而且由于交叉注意模块,对输入对中的噪声具有鲁棒性。 因此,我们使用源-目标分支的输出来指导目标分支的训练。 具体来说,源-目标分支和目标分支分别表示为教师和学生。 我们将源-目标分支中分类器的概率分布视为一个软标签,可用于通过蒸馏损失进一步监督目标分支(Hinton et al., 2015)�
其中 qk 和 pk 分别是来自源-目标分支和目标分支的类别 k 的概率�
在推理过程中,仅使用目标分支。 输入是来自测试数据的图像,仅触发目标数据流,即图2中的蓝线。 它的分类器输出被用作最终的预测标签。�
Experiments
Datasets and implementation
所提出的方法在四个流行的 UDA 基准上得到验证,包括 VisDA-2017 (Peng et al., 2017)、Office-Home (Venkateswara et al., 2017)、Office-31 (Saenko et al., 2010) 和 DomainNet ( 彭等人,2019)。 我们实验中的输入图像大小为 224x224。 DeiT-small 和 DeiT-base (Touvron et al., 2021) 都被用作我们进行公平比较的主干。 我们使用动量为 0.9、权重衰减比为 1e-4 的随机梯度下降算法来优化训练过程。 Office-Home、Office-31 和 DomainNet 的学习率设置为 3e-3,VisDA-2017 的学习率设置为 5e-5,因为它可以轻松收敛。 批量大小设置为 64。�