SphereFace's main contribution is normalizing the weights W of the network's final (classification) layer.
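With W normalized (and the bias dropped), the logit for class j becomes ‖x‖·cos θ_j, so training concentrates on the deep feature mapping and on the angles between feature vectors and class weights. The main benefit is that this reduces the effect of class imbalance in the number of training samples.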
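To make the effect of normalizing W concrete, here is a minimal pure-Python sketch with toy 2-D numbers (the vectors and the helper names `l2_norm`, `dot`, `normalized_logits` are illustrative, not taken from SphereFace). Without normalization, the class whose weight vector happens to be longer gets an inflated logit; after normalization each logit is exactly ‖x‖·cos θ_j, so only the angle between the feature and the class weight decides the ranking.

```python
import math


def l2_norm(v):
    """Euclidean length of a vector given as a list of floats."""
    return math.sqrt(sum(c * c for c in v))


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))


def normalized_logits(x, weights):
    """Bias-free logits with each class weight scaled to unit length: ||x|| * cos(theta_j)."""
    return [dot(x, [c / l2_norm(w) for c in w]) for w in weights]


x = [3.0, 4.0]                  # toy feature vector, ||x|| = 5
W = [[1.0, 0.0], [0.0, 10.0]]   # toy class weights; class 1's vector is 10x longer

raw = [dot(x, w) for w in W]                # ordinary FC logits: [3.0, 40.0]
normed = normalized_logits(x, W)            # logits after normalizing W: [3.0, 4.0]
cosines = [z / l2_norm(x) for z in normed]  # cos(theta_j) recovered: [0.6, 0.8]
```

One common explanation for the imbalance benefit is that, in an ordinary softmax layer, classes with many samples tend to learn larger weight norms; normalizing W removes that lever.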
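The SphereFace loss is called A-Softmax. In addition to normalizing the weights, it applies a multiplicative angular margin, replacing cos θ with cos(mθ) for the target class. In practice the multiplicative margin works poorly, mainly because training is hard to converge; to ease this, the implementation below blends the margin term with the plain cos θ through a weight λ that is annealed during training.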
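A minimal pure-Python sketch of the target-logit function that A-Softmax uses in place of cos θ_y, matching `phi_theta` in the code: ψ(θ) = (-1)^k cos(mθ) - 2k with k = floor(mθ/π). `MLAMBDA` copies the multiple-angle polynomials from the code; the name `psi` is an illustrative helper, not a library function. The checks show that ψ equals cos(mθ) on [0, π/m] and, unlike cos(mθ) itself, decreases monotonically over all of [0, π], so a larger angle always gives a smaller target logit.

```python
import math

# Same multiple-angle (Chebyshev) polynomials as the mlambda list in SphereProduct:
# MLAMBDA[m](cos(theta)) == cos(m * theta)
MLAMBDA = [
    lambda x: x ** 0,
    lambda x: x ** 1,
    lambda x: 2 * x ** 2 - 1,
    lambda x: 4 * x ** 3 - 3 * x,
    lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
    lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x,
]


def psi(theta, m=4):
    """A-Softmax target logit: (-1)^k * cos(m*theta) - 2k, with k = floor(m*theta/pi)."""
    # At theta == pi this gives k == m (one past the last piece), but psi is
    # continuous there, so the value equals the k == m - 1 result: 1 - 2m.
    k = math.floor(m * theta / math.pi)
    return (-1.0) ** k * MLAMBDA[m](math.cos(theta)) - 2 * k


thetas = [i * math.pi / 200 for i in range(201)]  # grid over [0, pi]
values = [psi(t) for t in thetas]                 # strictly decreasing from 1 to 1 - 2m
```

With m = 4 the target logit falls from ψ(0) = 1 to ψ(π) = -7, far faster than cos θ does, which gives one intuition for how strong the margin is and why optimizing it directly is difficult.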

    # SphereFace
    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn import Parameter


    class SphereProduct(nn.Module):
        r"""Implementation of the SphereFace (A-Softmax) multiplicative angular margin.

        Args:
            in_features: size of each input sample
            out_features: size of each output sample (number of classes)
            m: integer angular margin; the target logit uses cos(m * theta), 1 <= m <= 5
        """

        def __init__(self, in_features, out_features, m=4):
            super().__init__()
            if not 1 <= m <= 5:
                raise ValueError("m must be between 1 and 5 (see mlambda)")
            self.in_features = in_features
            self.out_features = out_features
            self.m = m
            # Parameters of the lambda annealing schedule (see forward)
            self.base = 1000.0
            self.gamma = 0.12
            self.power = 1
            self.LambdaMin = 5.0
            self.iter = 0
            self.weight = Parameter(torch.empty(out_features, in_features))
            nn.init.xavier_uniform_(self.weight)

            # Multiple-angle (Chebyshev) formulas: mlambda[m](cos(theta)) == cos(m * theta).
            # Each polynomial maps x in [-1, 1] onto y in [-1, 1], sweeping that range m times.
            self.mlambda = [
                lambda x: x ** 0,
                lambda x: x ** 1,
                lambda x: 2 * x ** 2 - 1,
                lambda x: 4 * x ** 3 - 3 * x,
                lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
                lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x,
            ]
            # To visualize mlambda, run:
            #   import matplotlib.pyplot as plt
            #   mlambda = SphereProduct(1, 1).mlambda
            #   x = [0.01 * i for i in range(-100, 101)]
            #   for f in mlambda:
            #       plt.plot(x, [f(i) for i in x])
            #   plt.show()

        def forward(self, input, label):
            # Annealed blending weight (advanced only while training):
            # lambda = max(lambda_min, base * (1 + gamma * iteration) ** (-power))
            if self.training:
                self.iter += 1
            self.lamb = max(self.LambdaMin,
                            self.base * (1 + self.gamma * self.iter) ** (-self.power))

            # --------------------------- cos(theta) & phi(theta) ---------------------------
            cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
            cos_theta = cos_theta.clamp(-1, 1)
            cos_m_theta = self.mlambda[self.m](cos_theta)
            # k = floor(m * theta / pi) selects the piece of phi(theta); no gradient flows through k
            theta = cos_theta.detach().acos()
            k = (self.m * theta / math.pi).floor()
            phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
            NormOfFeature = torch.norm(input, 2, 1)

            # --------------------------- convert label to one-hot ---------------------------
            one_hot = torch.zeros_like(cos_theta)  # same device and dtype as cos_theta
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)

            # --------------------------- Calculate output ---------------------------
            # Target class: (phi + lambda * cos) / (1 + lambda); other classes keep cos(theta)
            output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
            # Scale by the feature norm so the logits are ||x|| * (...)
            output = output * NormOfFeature.view(-1, 1)
            return output

        def __repr__(self):
            return (self.__class__.__name__ + '('
                    + 'in_features=' + str(self.in_features)
                    + ', out_features=' + str(self.out_features)
                    + ', m=' + str(self.m) + ')')
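The annealing in `forward` can be checked by hand. The sketch below is pure Python (no PyTorch), and the helper names `lambda_schedule` and `blended_target` are hypothetical. It reproduces the λ schedule with the constructor defaults `base=1000.0`, `gamma=0.12`, `power=1`, `LambdaMin=5.0`, and the target-class logit, which simplifies algebraically from `one_hot * (phi_theta - cos_theta) / (1 + lamb) + cos_theta` to (ψ + λ·cos θ) / (1 + λ) before the scaling by ‖x‖.

```python
def lambda_schedule(it, base=1000.0, gamma=0.12, power=1, lambda_min=5.0):
    """lambda = max(lambda_min, base * (1 + gamma * it) ** (-power)), as in SphereProduct.forward."""
    return max(lambda_min, base * (1 + gamma * it) ** (-power))


def blended_target(cos_t, psi_t, lamb):
    """Target-class logit before the ||x|| scaling: (psi + lambda * cos) / (1 + lambda).

    Algebraically equal to (psi - cos) / (1 + lambda) + cos, i.e. the code's
    output for the entry where one_hot == 1.
    """
    return (psi_t + lamb * cos_t) / (1 + lamb)


# lambda decays from roughly 1000 toward the floor LambdaMin = 5.0
schedule = {it: lambda_schedule(it) for it in (1, 100, 1000, 1658, 1659, 10000)}
```

Early in training λ is large, so the target logit is almost plain cos θ and the layer behaves like an ordinary normalized softmax; from iteration 1659 on, λ is clamped at 5 and the margin term ψ carries weight 1/6. For example, with cos θ = 0.5 and m = 4, ψ = -1.5, so the blended target logit is (-1.5 + 5 × 0.5) / 6 ≈ 0.167 instead of 0.5.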
