R1_reg

Computes the R1 gradient penalty on the discriminator's real inputs, (L2_norm(dy/dx))**2, scaled by lambda_r1.

  import torch
  import torch.nn as nn

  class R1_reg(nn.Module):
      def __init__(self, lambda_r1=10.0):
          super().__init__()
          self.lambda_r1 = lambda_r1

      def forward(self, d_out, d_in):
          """Compute gradient penalty: (L2_norm(dy/dx))**2."""
          b = d_in.shape[0]
          # Sum (rather than average) the outputs so each sample's gradient
          # is not pre-scaled by 1/b before the penalty is computed.
          dydx = torch.autograd.grad(outputs=d_out.sum(),
                                     inputs=d_in,
                                     retain_graph=True,
                                     create_graph=True,
                                     only_inputs=True)[0]
          dydx_sq = dydx.pow(2)
          assert dydx_sq.size() == d_in.size()
          # Mean over the batch of the per-sample squared gradient norms.
          r1_reg = dydx_sq.sum() / b
          return r1_reg * self.lambda_r1
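
A minimal usage sketch follows; the tiny convolutional discriminator and the random batch below are hypothetical stand-ins, not models or data from this code:

  import torch
  import torch.nn as nn

  # Hypothetical stand-in discriminator producing one logit per image.
  D = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                    nn.AdaptiveAvgPool2d(1),
                    nn.Flatten(),
                    nn.Linear(8, 1))
  r1 = R1_reg(lambda_r1=10.0)

  # R1 is evaluated on real samples, which must require gradients.
  real_images = torch.randn(4, 3, 32, 32).requires_grad_(True)
  penalty = r1(D(real_images), real_images)
  penalty.backward()  # in practice, added to the discriminator loss first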

SelectiveClassesNonSatGANLoss

A non-saturating GAN loss that, for each sample, uses only the discriminator's logit map for that sample's target class.

  import torch
  import torch.nn as nn

  class SelectiveClassesNonSatGANLoss(nn.Module):
      def __init__(self):
          super().__init__()
          self.softplus = nn.Softplus()

      def forward(self, input, target_classes, target_is_real, is_gen=False):
          # Keep only each sample's logit map for its target class.
          bSize = input.shape[0]
          b_ind = torch.arange(bSize).long()
          relevant_inputs = input[b_ind, target_classes, :, :]
          if target_is_real:
              # softplus(-x) == -log(sigmoid(x)): non-saturating "real" loss.
              loss = self.softplus(-relevant_inputs).mean()
          else:
              # softplus(x) == -log(1 - sigmoid(x)): "fake" loss.
              loss = self.softplus(relevant_inputs).mean()
          return loss
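
A minimal usage sketch, assuming the discriminator emits one logit map per class (output shape [B, num_classes, H, W]); the shapes and class labels below are illustrative assumptions:

  import torch

  criterion = SelectiveClassesNonSatGANLoss()

  # Hypothetical per-class logit maps: batch of 4, 2 classes, 8x8 patches.
  d_real = torch.randn(4, 2, 8, 8)
  d_fake = torch.randn(4, 2, 8, 8)
  classes = torch.tensor([0, 1, 0, 1])  # target class per sample

  # Discriminator: push the target-class logits up on real, down on fake.
  d_loss = (criterion(d_real, classes, target_is_real=True)
            + criterion(d_fake, classes, target_is_real=False))

  # Generator (non-saturating): treat its fakes as real.
  g_loss = criterion(d_fake, classes, target_is_real=True, is_gen=True)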