import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import functools
from torch.autograd import grad as Grad
from torch.autograd import Function
import numpy as np
from math import sqrt
###############################################################################
def weights_init(init_type='gaussian'):
    # Returns a closure suitable for Module.apply(): it initializes the
    # weights of every Conv*/Linear submodule according to init_type.
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
    return init_fun
# Example: initialize a small discriminator with the gaussian scheme.
netD = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
netD.apply(weights_init('gaussian'))
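
A quick sanity check (my addition, not part of the original snippet): after applying the gaussian initializer, every weight matrix should be drawn from N(0, 0.02) and every bias should be exactly zero.

for name, p in netD.named_parameters():
    # Weights come from N(0, 0.02); biases are constant 0.
    print(name, p.mean().item(), p.std().item())

For comparison, the example from the torch.nn.Module.apply documentation: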
>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
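
Since apply returns the module itself, the initializer can also be chained directly onto the constructor. A one-line variant (my sketch, equivalent to the two-step form above):

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)).apply(init_weights)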

APPLY

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use
includes initializing the parameters of a model (see also torch.nn.init).
Weight initialization.
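
A minimal sketch (mine, not from the docs) that makes the traversal order visible: the submodules from .children() are visited first, then the module itself.

import torch.nn as nn

visited = []
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # expected: ['Linear', 'Linear', 'Sequential'] -- children before self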

TORCH.NN.INIT

https://pytorch.org/docs/stable/nn.init.html#
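
The functions in torch.nn.init operate in place on tensors (note the trailing underscore), so they can also be called directly on a layer's parameters without going through apply(). An illustrative sketch, not tied to the snippet above:

import torch.nn as nn
import torch.nn.init as init

layer = nn.Linear(2, 2)
init.kaiming_normal_(layer.weight, a=0, mode='fan_in')  # in-place on the weight tensor
init.constant_(layer.bias, 0.0)                         # zero the bias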