Like AM-Softmax, ArcFace uses an additive margin; the only difference is where the margin is added: ArcFace puts it inside the cosine, cos(theta + m), while AM-Softmax applies it outside, as cos(theta) - m. The pseudocode for the ArcFace loss is:
- Normalize the input features x
- Normalize the weights W
- Compute Wx to get the prediction vector y
- Pick out the entry of y corresponding to the ground-truth class
- Take its arccosine to recover the angle theta
- Add the margin m to the angle
- Build the one-hot code marking the ground-truth position in y
- Use the one-hot code to write cos(theta + m) back into that position
- Multiply all values by the scale factor s
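
The steps above can be sketched literally with raw tensor ops (shapes, the seed, and the margin value here are illustrative). Note that the implementation below avoids the explicit `acos` by using the angle-addition identity instead:

```python
import torch
import torch.nn.functional as F

# A literal, step-by-step sketch of the pseudocode above.
torch.manual_seed(0)
batch, in_features, num_classes = 4, 8, 10
s, m = 30.0, 0.50

x = torch.randn(batch, in_features)          # input embeddings
W = torch.randn(num_classes, in_features)    # class weight vectors
label = torch.tensor([1, 3, 0, 7])           # ground-truth classes

x_n = F.normalize(x)                         # 1. normalize x
W_n = F.normalize(W)                         # 2. normalize W
y = x_n @ W_n.t()                            # 3. y = Wx, i.e. cos(theta)
target = y.gather(1, label.view(-1, 1))      # 4. pick the ground-truth entries
theta = torch.acos(target.clamp(-1 + 1e-7, 1 - 1e-7))  # 5. arccos -> angle
margined = torch.cos(theta + m)              # 6. add the margin m to the angle
one_hot = torch.zeros_like(y).scatter_(1, label.view(-1, 1), 1)  # 7. one-hot code
logits = torch.where(one_hot.bool(), margined.expand_as(y), y)   # 8. write back
logits = logits * s                          # 9. scale by s
```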
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: cos(theta + m)

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # Parameter wraps a plain (non-trainable) Tensor as a trainable parameter
        # and registers it on this module, so it appears in net.parameters() and
        # can be optimized. https://www.jianshu.com/p/d8b77cc02410
        # Initialize the weights
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        # F.linear(input, weight) computes y = x W^T; with both sides L2-normalized,
        # this yields the cosine similarity between each sample and each class weight.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # cos(a + b) = cos(a) * cos(b) - sin(a) * sin(b)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            # torch.where(condition, x, y): where condition is True yield x,
            # otherwise yield y. Only apply the margin where cos(theta) > 0,
            # i.e. the sample is already on the correct side.
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # Once theta + m would exceed pi, fall back to a linear penalty so the
            # target logit stays monotonically decreasing in theta.
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        # scatter_(dim, index, src)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Write cos(theta + m) into the ground-truth positions, keep cos(theta) elsewhere;
        # equivalently torch.where(one_hot.bool(), phi, cosine) on torch >= 0.4.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
```
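
In training, this layer replaces the final fully connected layer and its scaled logits feed a standard cross-entropy loss. A minimal self-contained sketch of that pattern, using a hypothetical functional restatement `arc_margin_logits` of the layer above (names and sizes are illustrative):

```python
import math

import torch
import torch.nn.functional as F


def arc_margin_logits(x, weight, label, s=30.0, m=0.50):
    """Compact functional form of ArcMarginProduct's forward pass."""
    cosine = F.linear(F.normalize(x), F.normalize(weight))
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=0))
    phi = cosine * math.cos(m) - sine * math.sin(m)       # cos(theta + m)
    phi = torch.where(cosine > math.cos(math.pi - m),
                      phi, cosine - math.sin(math.pi - m) * m)
    one_hot = torch.zeros_like(cosine).scatter_(1, label.view(-1, 1), 1)
    return s * (one_hot * phi + (1.0 - one_hot) * cosine)


embedding_dim, num_classes = 16, 10
weight = torch.randn(num_classes, embedding_dim, requires_grad=True)
features = torch.randn(8, embedding_dim)      # embeddings from a backbone network
labels = torch.randint(0, num_classes, (8,))

logits = arc_margin_logits(features, weight, labels)  # margin only at label positions
loss = F.cross_entropy(logits, labels)
loss.backward()                               # gradients flow into weight
```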
