深度解析博客,有些地方难理解可以看看博客。
retinafece采用的损失函数和SSD的相似,在SSD基础上添加了人脸关键点分支

一、先验框生成代码

在三个不同尺度的特征图上生成先验框,

  1. def get_anchors(self):
  2. anchors = []
  3. # 输出三个尺度:
  4. # [80,80],[40,40],[20,20]
  5. # 遍历每个特征图,并在每个特征图的每个网格上生成两个先验框
  6. # 先验框大小为:
  7. # [16,32],[64,128],[256,512]此时的尺寸为在原图[640,640]上
  8. # 在[80,80]特征图的每个像素点上生成两个先验框,两个先验框尺寸为[16,32]/640
  9. # 在[40,40]特征图的每个像素点上生成两个先验框,两个先验框尺寸为[64,128]/640
  10. # 在[20,20]特征图的每个像素点上生成两个先验框,两个先验框尺寸为[256,512]/640
  11. for k, f in enumerate(self.feature_maps):#遍历三个特征层
  12. min_sizes = self.min_sizes[k]
  13. #-----------------------------------------#
  14. # 对特征层的高和宽进行循环迭代
  15. #-----------------------------------------#
  16. for i, j in product(range(f[0]), range(f[1])):#遍历每个特征层的每个网格
  17. for min_size in min_sizes:#每次生成一个尺寸的先验框
  18. #生成每个网格先验框的尺寸
  19. s_kx = min_size / self.image_size[1]#16/640
  20. s_ky = min_size / self.image_size[0]#16/640
  21. #对坐标进行归一化,生成每个网格中心坐标
  22. dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
  23. dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
  24. for cy, cx in product(dense_cy, dense_cx):
  25. anchors += [cx, cy, s_kx, s_ky]
  26. #转换为[16800, 4]的矩阵,即总共16800个先验框,每个先验框四个信息[x,y,w,h]
  27. output = torch.Tensor(anchors).view(-1, 4)
  28. if self.clip:
  29. output.clamp_(max=1, min=0)
  30. return output

二、损失函数代码

  1. def forward(self, predictions, priors, targets):
  2. #--------------------------------------------------------------------#
  3. # predictions代表网络预测得到的信息
  4. # priors代表先验框信息
  5. # targets代表真实框信息
  6. # 取出预测结果的三个值:框的回归信息,置信度,人脸关键点的回归信息
  7. #--------------------------------------------------------------------#
  8. loc_data, conf_data, landm_data = predictions
  9. #--------------------------------------------------#
  10. # 计算出batch_size和先验框的数量
  11. #--------------------------------------------------#
  12. num = loc_data.size(0)
  13. num_priors = (priors.size(0))
  14. #--------------------------------------------------#
  15. # 创建一个tensor进行处理
  16. #--------------------------------------------------#
  17. loc_t = torch.Tensor(num, num_priors, 4)
  18. landm_t = torch.Tensor(num, num_priors, 10)
  19. conf_t = torch.LongTensor(num, num_priors)
  20. for idx in range(num):#取出一张图片进行处理
  21. # 获得真实框与标签
  22. truths = targets[idx][:, :4].data
  23. labels = targets[idx][:, -1].data
  24. landms = targets[idx][:, 4:14].data
  25. # 获得先验框
  26. defaults = priors.data
  27. #--------------------------------------------------#
  28. # 利用真实框和先验框进行匹配。
  29. # 如果真实框和先验框的重合度较高,则认为匹配上了。
  30. # 该先验框用于负责检测出该真实框。
  31. #--------------------------------------------------#
  32. match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
  33. #--------------------------------------------------#
  34. # 转化成Variable
  35. # loc_t (num, num_priors, 4)
  36. # conf_t (num, num_priors)
  37. # landm_t (num, num_priors, 10)
  38. #--------------------------------------------------#
  39. zeros = torch.tensor(0)
  40. if self.cuda:
  41. loc_t = loc_t.cuda()
  42. conf_t = conf_t.cuda()
  43. landm_t = landm_t.cuda()
  44. zeros = zeros.cuda()
  45. #------------------------------------------------------------------------#
  46. # 有人脸关键点的人脸真实框的标签为1,没有人脸关键点的人脸真实框标签为-1
  47. # 所以计算人脸关键点loss的时候pos1 = conf_t > zeros
  48. # 计算人脸框的loss的时候pos = conf_t != zeros
  49. #------------------------------------------------------------------------#
  50. pos1 = conf_t > zeros
  51. pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
  52. landm_p = landm_data[pos_idx1].view(-1, 10)
  53. landm_t = landm_t[pos_idx1].view(-1, 10)
  54. loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
  55. pos = conf_t != zeros
  56. pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
  57. loc_p = loc_data[pos_idx].view(-1, 4)
  58. loc_t = loc_t[pos_idx].view(-1, 4)
  59. loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
  60. #--------------------------------------------------#
  61. # batch_conf (num * num_priors, 2)
  62. # loss_c (num, num_priors)
  63. #--------------------------------------------------#
  64. conf_t[pos] = 1
  65. batch_conf = conf_data.view(-1, self.num_classes)
  66. # 这个地方是在寻找难分类的先验框
  67. loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
  68. # 难分类的先验框不把正样本考虑进去,只考虑难分类的负样本
  69. loss_c[pos.view(-1, 1)] = 0
  70. loss_c = loss_c.view(num, -1)
  71. #--------------------------------------------------#
  72. # loss_idx (num, num_priors)
  73. # idx_rank (num, num_priors)
  74. #--------------------------------------------------#
  75. _, loss_idx = loss_c.sort(1, descending=True)
  76. _, idx_rank = loss_idx.sort(1)
  77. #--------------------------------------------------#
  78. # 求和得到每一个图片内部有多少正样本
  79. # num_pos (num, )
  80. # neg (num, num_priors)
  81. #--------------------------------------------------#
  82. num_pos = pos.long().sum(1, keepdim=True)
  83. # 限制负样本数量
  84. num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
  85. neg = idx_rank < num_neg.expand_as(idx_rank)
  86. #--------------------------------------------------#
  87. # 求和得到每一个图片内部有多少正样本
  88. # pos_idx (num, num_priors, num_classes)
  89. # neg_idx (num, num_priors, num_classes)
  90. #--------------------------------------------------#
  91. pos_idx = pos.unsqueeze(2).expand_as(conf_data)
  92. neg_idx = neg.unsqueeze(2).expand_as(conf_data)
  93. # 选取出用于训练的正样本与负样本,计算loss
  94. conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
  95. targets_weighted = conf_t[(pos+neg).gt(0)]
  96. loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
  97. N = max(num_pos.data.sum().float(), 1)
  98. loss_l /= N
  99. loss_c /= N
  100. num_pos_landm = pos1.long().sum(1, keepdim=True)
  101. N1 = max(num_pos_landm.data.sum().float(), 1)
  102. loss_landm /= N1
  103. return loss_l, loss_c, loss_landm

三、先验框匹配算法

因为正样本相比于负样本数量太少,采用匹配的方式为真实框匹配多个先验框,更好的平衡正负样本比例
将所有先验框和真实框进行匹配,重合度高的用于预测目标,重合度低的作为背景,匹配后的信息进行了编码,编码为与先验框的偏差,网络学习如何预测偏差。

在训练过程中,首先要确定训练图片中的ground truth(真实目标)与哪个prior boxes(先验框)来进行匹配,与之匹配的先验框所对应的边界框将负责预测它。Retinaface的prior boxes与ground truth的匹配原则主要有两点:

  • 首先,对于图片中每个ground truth,找到与其IOU最大的prior boxes,该prior boxes与其匹配,这样,可以保证每个ground truth一定与某个prior boxes匹配。通常称与ground truth匹配的prior boxes为正样本(其实应该是先验框对应的预测box,不过由于是一一对应的就这样称呼了),反之,若一个prior boxes没有与任何ground truth进行匹配,那么该prior boxes只能与背景匹配,就是负样本。
  • 对于剩余的未匹配prior boxes,若某个ground truth的 IOU大于某个阈值(一般是0.5),那么该先验框也与这个ground truth进行匹配。一个图片中ground truth是非常少的, 而prior boxes却很多,如果仅按第一个原则匹配,很多prior boxes会是负样本,正负样本极其不平衡,所以需要第二个原则。

尽管一个ground truth可以与多个prior boxes匹配,但是ground truth相对prior boxes还是太少了,所以负样本相对正样本会很多。为了保证正负样本尽量平衡,Retinaface采用了hard negative mining,就是对负样本进行抽样,抽样时按照置信度误差(预测背景的置信度越小,误差越大)进行降序排列,选取误差的较大的top-k作为训练的负样本,以保证正负样本比例接近1:3。

Hard Negative Mining技术

一般情况下negative default boxes数量是远大于positive default boxes数量,如果随机选取样本训练会导致网络过于重视负样本(因为抽取到负样本的概率值更大一些),这会使得loss不稳定。因此需要平衡正负样本的个数,我们常用的方法就是Hard Ngative Mining,即依据confidience score对default box进行排序,挑选其中confidience高的box进行训练,将正负样本的比例控制在positive:negative=1:3,这样会取得更好的效果。如果我们不加控制的话,很可能会出现Sample到的所有样本都是负样本(即让网络从这些负样本中找正确目标,这显然是不可以的),这样就会使得网络的性能变差

  1. def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
  2. #----------------------------------------------#
  3. # 计算所有的先验框和真实框的重合程度
  4. #----------------------------------------------#
  5. overlaps = jaccard(
  6. truths,
  7. point_form(priors)
  8. )
  9. #----------------------------------------------#
  10. # 所有真实框和先验框的最好重合程度,为每个真实框找出IOU最大的先验框
  11. # best_prior_overlap [truth_box,1]
  12. # best_prior_idx [truth_box,1]
  13. #----------------------------------------------#
  14. best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
  15. best_prior_idx.squeeze_(1)
  16. best_prior_overlap.squeeze_(1)
  17. #----------------------------------------------#
  18. # 所有先验框和真实框的最好重合程度,为每个先验框找出IOU最大的真实框
  19. # best_truth_overlap [1,prior]
  20. # best_truth_idx [1,prior]
  21. #----------------------------------------------#
  22. best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
  23. best_truth_idx.squeeze_(0)
  24. best_truth_overlap.squeeze_(0)
  25. #----------------------------------------------#
  26. # 用于保证每个真实框都至少有对应的一个先验框
  27. # 因为有可能存在没有先验框和真实框匹配,或IOU太低被过滤掉
  28. #----------------------------------------------#
  29. #将与真实框匹配程度最高的先验框IOU置为2,确保它能被保留下来
  30. best_truth_overlap.index_fill_(0, best_prior_idx, 2)
  31. # 对best_truth_idx内容进行设置
  32. # 更改上面置为2的先验框索引
  33. for j in range(best_prior_idx.size(0)):
  34. best_truth_idx[best_prior_idx[j]] = j
  35. #----------------------------------------------#
  36. # 获取每一个先验框对应的真实框[num_priors,4]
  37. #----------------------------------------------#
  38. matches = truths[best_truth_idx]
  39. # Shape: [num_priors] 此处为将每一个anchor对应的label取出来
  40. conf = labels[best_truth_idx]
  41. matches_landm = landms[best_truth_idx]
  42. #----------------------------------------------#
  43. # 如果重合程度小于threhold则认为是背景
  44. #----------------------------------------------#
  45. conf[best_truth_overlap < threshold] = 0
  46. #----------------------------------------------#
  47. # 利用真实框和先验框进行编码
  48. # 编码后的结果就是网络应该有的预测结果
  49. #----------------------------------------------#
  50. loc = encode(matches, priors, variances)
  51. landm = encode_landm(matches_landm, priors, variances)
  52. #----------------------------------------------#
  53. # [num_priors, 4]
  54. #----------------------------------------------#
  55. loc_t[idx] = loc
  56. #----------------------------------------------#
  57. # [num_priors]
  58. #----------------------------------------------#
  59. conf_t[idx] = conf
  60. #----------------------------------------------#
  61. # [num_priors, 10]
  62. #----------------------------------------------#
  63. landm_t[idx] = landm

3.1 point_from()

转换坐标

  1. #将[中心点x,中心点y,w,h]形式转化为[左上角x,左上角y,右下角x,右下角y]
  2. def point_form(boxes):
  3. return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,
  4. boxes[:, :2] + boxes[:, 2:]/2), 1)

3.2 jaccard()

计算IOU

  1. #计算IOU
  2. def jaccard(box_a, box_b):
  3. #-------------------------------------#
  4. # 返回的inter的shape为[A,B]
  5. # 代表每一个真实框和先验框的交矩形
  6. #-------------------------------------#
  7. inter = intersect(box_a, box_b)
  8. #-------------------------------------#
  9. # 计算先验框和真实框各自的面积
  10. #-------------------------------------#
  11. area_a = ((box_a[:, 2]-box_a[:, 0]) *
  12. (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
  13. area_b = ((box_b[:, 2]-box_b[:, 0]) *
  14. (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
  15. union = area_a + area_b - inter
  16. #-------------------------------------#
  17. # 每一个真实框和先验框的交并比[A,B]
  18. #-------------------------------------#
  19. return inter / union # [A,B]

3.3 先验框编码

坐标编码公式如下
Retinafece预处理代码 - 图1
g_cxcy求的是gt bx 和prior box的偏差

Retinafece预处理代码 - 图2 求的是gt box和prior box的宽高比例值,通过 Retinafece预处理代码 - 图3 函数进行函数映射
Retinafece预处理代码 - 图4

  1. def encode(matched, priors, variances):
  2. # 进行编码的操作
  3. g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
  4. # 中心编码
  5. g_cxcy /= (variances[0] * priors[:, 2:])
  6. # 宽高编码
  7. g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
  8. g_wh = torch.log(g_wh) / variances[1]
  9. return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]

3.4 人脸关键点编码

求得人脸关键点和先验框的偏差。

  1. def encode_landm(matched, priors, variances):
  2. # 将人脸关键点信息分开,总共五个关键点,每个关键点两个坐标信息[x,y]
  3. matched = torch.reshape(matched, (matched.size(0), 5, 2))
  4. priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
  5. priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
  6. priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
  7. priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
  8. priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
  9. # 减去中心后除上宽高
  10. g_cxcy = matched[:, :, :2] - priors[:, :, :2]
  11. g_cxcy /= (variances[0] * priors[:, :, 2:])
  12. g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
  13. return g_cxcy

四、smooth_l1_loss

Retinafece预处理代码 - 图5
PyTorch smooth_l1_loss计算公式:
Retinafece预处理代码 - 图6

Localization Loss
Retinafece预处理代码 - 图7