- 分享主题:Transfer Learning, Domain Adaptation, CV, Explicit Feature Distribution Alignment, Adversarial
- 论文标题:Discriminative Adversarial Domain Adaptation
- 论文链接:https://arxiv.org/pdf/1911.12036.pdf
1.Summary
This is a paper about picture classification. Suppose there is a source domain and a target domain. The pictures of the source domain are labeled and the pictures of the target domain are not labeled, which is a problem of domain adaptation. In order to transfer the knowledge from the source domain to the target domain, this paper uses the adversarial method to realize the transfer learning. However, different from the previous methods in the structure of GAN, this paper uses a method called DADA (Discriminative Adversarial Domain Adaptation). DADA combines the label predictor and domain discriminator to align the joint distribution of source domain and target domain (feature and picture labels). DADA makes up for the disadvantage that the Gan structure can not align the joint distribution. In order to deepen my understanding of this paper, I can read some papers on domain adversarial learning.2.你对于论文的思考
这是一篇关于图片分类的文章,背景是一个源域和一个目标域,源域图片带标签,目标域不带标签,使用的迁移学习方法是域对抗的方法,但是与以往的GAN结构的域对抗方法(最经典的就是DANN)不同,这篇文章的DADA方法是把原先的标签预测器和域判别器结合在了一起,利用了巧妙的损失函数来使得源域和目标域的联合分布对齐,而这一点是原先的GAN结构的域对抗方法做不到的。3. 其他
3.1 解决的问题
这篇文章解决的图片分类的问题,把带标签的源于数据迁移到不带标签的目标域数据上,也就是一个域自适应的问题。3.2 DADA(Discriminative Adversarial Domain Adaptation)
DADA模型主题分为两个部分,特征提取器和分类器,特征提取器相当于DANN中的生成器,分类器相当于DANN中的域判别器和标签预测器的结合。3.2.1 分类器
DADA的分类器结合了域判别器和标签预测器,假设图片标签一共有K类,那么这个分类器最终输出一个K+1维的向量,其中前K维是源域图片分别属于K类标签的概率,也就是如果被分类到了前K维那么这个图片是属于源域的,最后一维表示属于目标域图片的概率。3.2.2 损失函数
下面是初步设计的损失函数,第一个式子的目的是让判别器尽可能让源于数据准确分类,同时尽可能被分到前K类,并且让目标域数据被分到第K+1类,也就是域标签也能准确分类;第二个式子的目的是让源于数据准确分类,同时让目标域数据尽可能分到前K类。这样一来,下面的损失函数的最终目的是让源于数据的图片准确分类,并且让源域和目标域的域标签相互混淆,因为在混淆时,源域数据的特征是和图片标签挂钩的,比如说源域中由张图片属于第i类,那么最终训练的目标是让这张图片属于第i类合第K+1类的概率变高并且接近。
因为上面的损失函数让对抗的过程变得十分含蓄,所以这篇文章又设计了一套损失函数,如下所示:
(1)下面的式子是和源域数据相关的损失函数:
如下图所示,只要保证源域分类输出大于0.5,在分类器上最小化这个式子可以提高图片分类能力,并且会让域标签也分类准确,在特征提取器上最大化这个式子虽然会降低图片分类能力,但是能够尽可能的混淆两个域,提取两个域的共同特征。
(2)下面的式子是和目标域数据相关的损失函数,最小化第一个式子,目标是让域标签分类准确;最大化第二个式子,目标是让域标签分类不准确。
(3)为了提高效果,加入了如下损失函数,H(.)是一个熵函数,值越大,熵就越小。目的是让判别器的判断更加明确,尽量避免某些样本在某两类(前K类中的某两类)上的概率很接近,以至于难以判别图片的标签;同时也是让生成器生成的特征的熵尽量大,有助于让生成的特征更加接近,有助于两个域的混淆。
(4)总体损失函数:3.3 实验
(1)与最基础的baseline对比以及消融实验。
数据集:Office-31
(2)数据集:Office-31
(3) 数据集:Syn2Real-C