image.png

  1. @article{hou2019learning,
  2. title={Learning Lightweight Lane Detection CNNs by Self Attention Distillation},
  3. author={Hou, Yuenan and Ma, Zheng and Liu, Chunxiao and Loy, Chen Change},
  4. journal={arXiv preprint arXiv:1908.00821},
  5. year={2019}
  6. }

被 ICCV 2019接收。

主要工作

本文的学生-学生模型希望学生模型可以学习到:

  • 学生网络自己前期特征注意力图

提出了一种使用自身前期特征图来作为蒸馏监督的监督方法,即所谓的“自注意力蒸馏”——自己蒸自己。

这篇文章和Attention Transfer方法有很大的关联。

Different from Sergey et al. [Paying more attention to attention: improving the performance of convolutional neu-ral networks via attention transfer] who perform attention distillation within two networks, the proposed self attention distillation is performed within the network itself.

主要结构

image.png

要注意,关于这里的蒸馏路径实际上不知这样的一种,还可以使用类似于DenseNet那样的密集连接方式。实际上可能的蒸馏路径对于M层网络而言,可以有(ICCV 2019) Learning Lightweight Lane Detection CNNs by Self Attention Distillation - 图3种连接方式。

损失函数

最终的损失函数:

image.png

Lseg和LIoU表示分割用的标准交叉熵和IoU损失,Lexist表示二值交叉熵损失,输入为车道线存在与否的图的概率图(推测,应该是概率图,BCELoss),最后是蒸馏损失了,使用的输入是各层的特征。

交叉熵损失和二值交叉熵损失的设置代码:

  1. criterion = torch.nn.NLLLoss(ignore_index=ignore_label, weight=class_weights).cuda()
  2. criterion_exist = torch.nn.BCELoss().cuda()
  3. ...
  4. input_var = torch.autograd.Variable(input)
  5. target_var = torch.autograd.Variable(target)
  6. target_exist_var = torch.autograd.Variable(target_exist)
  7. # compute output
  8. output, output_exist = model(input_var) # output_mid
  9. loss = criterion(torch.nn.functional.log_softmax(output, dim=1), target_var)
  10. # print(output_exist.data.cpu().numpy().shape)
  11. loss_exist = criterion_exist(output_exist, target_exist_var)
  12. loss_tot = loss + loss_exist * 0.1

关于蒸馏部分没有在Pytorch代码中实现,所以这里就不放了。

参考链接