SphereFace主要工作是归一化了网络中的权值W。

    权值W归一化之后,训练时优化会更加集中在深度特征映射和特征向量角度上。这样做的主要好处是可以缓解样本数量不均衡带来的影响。

    SphereFace当中的loss被称作A-softmax,除了归一化权值之外还采用了乘性margin,但是乘性margin效果并不理想,主要是其训练时难以收敛。

    Screenshot from 2020-07-21 19-04-45.png

    1. # SphereFace
    2. class SphereProduct(nn.Module):
    3. r"""Implement of large margin cosine distance: :
    4. Args:
    5. in_features: size of each input sample
    6. out_features: size of each output sample
    7. m: margin
    8. cos(m*theta)
    9. """
    10. def __init__(self, in_features, out_features, m=4):
    11. super(SphereProduct, self).__init__()
    12. self.in_features = in_features
    13. self.out_features = out_features
    14. self.m = m
    15. self.base = 1000.0
    16. self.gamma = 0.12
    17. self.power = 1
    18. self.LambdaMin = 5.0
    19. self.iter = 0
    20. self.weight = Parameter(torch.FloatTensor(out_features, in_features))
    21. nn.init.xavier_uniform(self.weight)
    22. # duplication formula
    23. # 将x\in[-1,1]范围的重复index次映射到y\[-1,1]上
    24. self.mlambda = [
    25. lambda x: x ** 0,
    26. lambda x: x ** 1,
    27. lambda x: 2 * x ** 2 - 1,
    28. lambda x: 4 * x ** 3 - 3 * x,
    29. lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
    30. lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
    31. ]
    32. """
    33. 执行以下代码直观了解mlambda
    34. import matplotlib.pyplot as plt
    35. mlambda = [
    36. lambda x: x ** 0,
    37. lambda x: x ** 1,
    38. lambda x: 2 * x ** 2 - 1,
    39. lambda x: 4 * x ** 3 - 3 * x,
    40. lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
    41. lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
    42. ]
    43. x = [0.01 * i for i in range(-100, 101)]
    44. print(x)
    45. for f in mlambda:
    46. plt.plot(x,[f(i) for i in x])
    47. plt.show()
    48. """
    49. def forward(self, input, label):
    50. # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
    51. self.iter += 1
    52. self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))
    53. # --------------------------- cos(theta) & phi(theta) ---------------------------
    54. cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
    55. cos_theta = cos_theta.clamp(-1, 1)
    56. cos_m_theta = self.mlambda[self.m](cos_theta)
    57. theta = cos_theta.data.acos()
    58. k = (self.m * theta / 3.14159265).floor()
    59. phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
    60. NormOfFeature = torch.norm(input, 2, 1)
    61. # --------------------------- convert label to one-hot ---------------------------
    62. one_hot = torch.zeros(cos_theta.size())
    63. one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot
    64. one_hot.scatter_(1, label.view(-1, 1), 1)
    65. # --------------------------- Calculate output ---------------------------
    66. output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
    67. output *= NormOfFeature.view(-1, 1)
    68. return output
    69. def __repr__(self):
    70. return self.__class__.__name__ + '(' \
    71. + 'in_features=' + str(self.in_features) \
    72. + ', out_features=' + str(self.out_features) \
    73. + ', m=' + str(self.m) + ')'