# Compute gradient penalty: (L2_norm(dy/dx))**2.
class R1_reg(nn.Module):
    """R1 gradient-penalty regularizer.

    Computes lambda_r1 * (L2_norm(d d_out / d d_in))**2, averaged over the
    batch, as used to regularize GAN discriminators on real samples.
    """

    def __init__(self, lambda_r1=10.0):
        super().__init__()
        # Scalar weight applied to the penalty term.
        self.lambda_r1 = lambda_r1

    def __call__(self, d_out, d_in):
        """Compute gradient penalty: (L2_norm(dy/dx))**2.

        Args:
            d_out: discriminator output; reduced with ``.mean()`` before
                differentiation.
            d_in: discriminator input tensor; must be part of the autograd
                graph that produced ``d_out`` (``requires_grad=True``).

        Returns:
            Scalar tensor: ``lambda_r1 * sum(grad**2) / batch_size``.

        Raises:
            RuntimeError: if the gradient shape does not match ``d_in``.
        """
        b = d_in.shape[0]
        # create_graph=True keeps the penalty differentiable so it can be
        # backpropagated through during discriminator updates.
        dydx = torch.autograd.grad(
            outputs=d_out.mean(),
            inputs=d_in,
            retain_graph=True,
            create_graph=True,
            only_inputs=True,
        )[0]
        dydx_sq = dydx.pow(2)
        # Explicit check instead of `assert`, which is stripped under -O.
        if dydx_sq.size() != d_in.size():
            raise RuntimeError("gradient shape does not match input shape")
        r1_reg = dydx_sq.sum() / b
        return r1_reg * self.lambda_r1
# SelectiveClassesNonSatGANLoss: non-saturating GAN loss restricted to each sample's target class.
class SelectiveClassesNonSatGANLoss(nn.Module):
    """Non-saturating GAN loss computed only on each sample's target-class map.

    For real targets the loss is softplus(-x).mean() (= -log(sigmoid(x))),
    for fake targets softplus(x).mean(), where x is the per-sample slice of
    the input selected by ``target_classes``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): renamed from the misspelled `sofplus`; only used
        # internally by __call__ below.
        self.softplus = nn.Softplus()

    def __call__(self, input, target_classes, target_is_real, is_gen=False):
        """Compute the selective non-saturating loss.

        Args:
            input: discriminator logits; indexed as (batch, class, H, W) —
                assumed 4-D from the indexing pattern, TODO confirm with caller.
            target_classes: per-sample class indices, shape (batch,).
            target_is_real: if True, penalize low logits; else high logits.
            is_gen: accepted for interface compatibility; unused here.

        Returns:
            Scalar tensor: mean softplus loss over the selected slices.
        """
        bSize = input.shape[0]
        b_ind = torch.arange(bSize).long()
        # Pick, per sample, the logit map of its target class -> (batch, H, W).
        relevant_inputs = input[b_ind, target_classes, :, :]
        if target_is_real:
            # -log(sigmoid(x)): push selected logits up.
            loss = self.softplus(-relevant_inputs).mean()
        else:
            # -log(1 - sigmoid(x)): push selected logits down.
            loss = self.softplus(relevant_inputs).mean()
        return loss