ArcFace, like AM-Softmax, uses an additive margin; the difference is that ArcFace adds the margin *inside* the cosine, as cos(theta + m), while AM-Softmax adds it *outside* the cosine, as cos(theta) - m. Pseudocode for the ArcFace loss:
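Written out, the per-sample target logits of the two losses differ only in where the margin m enters (standard formulations; s is the scale factor and theta_j the angle between the feature and the j-th class weight):

```latex
\text{AM-Softmax (additive cosine margin):}\quad
L_i = -\log \frac{e^{s(\cos\theta_{y_i} - m)}}
               {e^{s(\cos\theta_{y_i} - m)} + \sum_{j \ne y_i} e^{s\cos\theta_j}}

\text{ArcFace (additive angular margin):}\quad
L_i = -\log \frac{e^{s\cos(\theta_{y_i} + m)}}
               {e^{s\cos(\theta_{y_i} + m)} + \sum_{j \ne y_i} e^{s\cos\theta_j}}
```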

    • Normalize the input x
    • Normalize the weight W
    • Compute Wx to obtain the prediction vector y
    • Pick out the entry of y corresponding to the ground-truth class
    • Take its arccosine to recover the angle theta
    • Add the margin m to the angle
    • Build a one-hot code marking the ground-truth position in y
    • Use the one-hot code to write cos(theta + m) back into that position
    • Multiply all values by the scale factor s
    ArcFace implementation (PyTorch):

    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn import Parameter

    class ArcMarginProduct(nn.Module):
        r"""Implementation of the large-margin arc distance:
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of the input feature (scale factor)
            m: margin, applied as cos(theta + m)
        """
        def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
            super(ArcMarginProduct, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.s = s
            self.m = m
            # Purpose of Parameter:
            # it wraps a plain (non-trainable) Tensor as a trainable parameter
            # and registers it on this module, so it shows up in
            # net.parameters() and gets updated during optimization.
            # https://www.jianshu.com/p/d8b77cc02410
            # Initialize the weight.
            self.weight = Parameter(torch.FloatTensor(out_features, in_features))
            nn.init.xavier_uniform_(self.weight)
            self.easy_margin = easy_margin
            self.cos_m = math.cos(m)
            self.sin_m = math.sin(m)
            self.th = math.cos(math.pi - m)
            self.mm = math.sin(math.pi - m) * m

        def forward(self, input, label):
            # --------------------------- cos(theta) & phi(theta) ---------------------------
            # torch.nn.functional.linear(input, weight, bias=None)
            # y = x @ W^T + b
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
            sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
            # cos(a + b) = cos(a)*cos(b) - sin(a)*sin(b)
            phi = cosine * self.cos_m - sine * self.sin_m
            if self.easy_margin:
                # torch.where(condition, x, y) -> Tensor
                # condition (BoolTensor) - where True, yield x, otherwise yield y
                # x (Tensor) - values selected where condition is True
                # y (Tensor) - values selected where condition is False
                # returns a tensor with the broadcasted shape of condition, x, y.
                # Only apply the margin where cosine > 0 (theta < pi/2);
                # elsewhere keep the plain cosine ("easy margin").
                phi = torch.where(cosine > 0, phi, cosine)
            else:
                phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            # --------------------------- convert label to one-hot ---------------------------
            # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
            # Write cos(theta + m) back into the ground-truth positions of the tensor.
            one_hot = torch.zeros(cosine.size(), device=cosine.device)
            # scatter_(dim, index, src)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
            output *= self.s
            return output
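The pseudocode steps above can also be sketched functionally. This is a minimal illustration of the margin arithmetic only (acos-based, on CPU, without the easy_margin / numerical-stability handling of the class version); `arcface_logits` is a hypothetical helper name, not part of any library:

```python
import torch
import torch.nn.functional as F

def arcface_logits(features, weight, labels, s=30.0, m=0.50):
    # Step 1-3: normalize x and W, then Wx gives the cosine of each angle.
    cosine = F.linear(F.normalize(features), F.normalize(weight))
    # Step 5-6: recover the angle (clamp for acos stability) and add m.
    theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
    # Step 7-8: one-hot code selects where cos(theta + m) replaces cos(theta).
    one_hot = F.one_hot(labels, num_classes=weight.size(0)).bool()
    logits = torch.where(one_hot, torch.cos(theta + m), cosine)
    # Step 9: scale everything by s.
    return logits * s
```

The returned logits can be fed directly to `nn.CrossEntropyLoss`, which applies the softmax and log-likelihood.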