import functools
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Function
from torch.autograd import grad as Grad


###############################################################################
# Weight initialization
###############################################################################


def weights_init(init_type='gaussian'):
    """Build an ``fn`` for ``nn.Module.apply`` that re-initializes Conv/Linear layers.

    The returned callable only touches modules whose class name starts with
    ``'Conv'`` or ``'Linear'`` (e.g. ``Conv2d``, ``ConvTranspose2d``,
    ``Linear``) and that have a non-None ``weight``. For those, the weight is
    re-drawn according to ``init_type`` and the bias, if present, is set to
    zero. Every other module is left unchanged.

    Args:
        init_type: name of the weight initializer:
            ``'gaussian'``   -- normal with mean 0 and std 0.02;
            ``'xavier'``     -- Xavier/Glorot normal with gain sqrt(2);
            ``'kaiming'``    -- Kaiming/He normal with ``a=0, mode='fan_in'``;
            ``'orthogonal'`` -- orthogonal with gain sqrt(2);
            ``'default'``    -- keep the layer's existing weight (the bias is
                                still zeroed).

    Returns:
        A function ``init_fun(m)`` to pass to ``model.apply(init_fun)``.

    Raises:
        ValueError: if ``init_type`` is not one of the names above. The check
            runs here, when the initializer is built, and (unlike ``assert``)
            is not removed when Python runs with ``-O``.
    """
    # Each entry re-initializes a weight tensor in place. ``nn.init``
    # functions already run under ``torch.no_grad()``, so passing the
    # Parameter itself (rather than ``.data``) is safe.
    initializers = {
        'gaussian': lambda w: init.normal_(w, 0.0, 0.02),
        'xavier': lambda w: init.xavier_normal_(w, gain=sqrt(2)),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=sqrt(2)),
        'default': lambda w: None,
    }
    if init_type not in initializers:
        raise ValueError("Unsupported initialization: {}".format(init_type))
    init_weight = initializers[init_type]

    def init_fun(m):
        # Prefix match on the class name: Conv1d/2d/3d, ConvTranspose*, Linear...
        if not m.__class__.__name__.startswith(('Conv', 'Linear')):
            return
        weight = getattr(m, 'weight', None)
        if weight is None:
            return
        init_weight(weight)
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias, 0.0)

    return init_fun


@torch.no_grad()
def init_weights(m):
    """Example ``fn`` from the ``nn.Module.apply`` documentation.

    Prints each module it is called on and fills the weight of every
    ``nn.Linear`` with 1.0. The ``no_grad`` decorator is required because
    in-place ``fill_`` on a leaf tensor that requires grad is an error.
    """
    print(m)
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
        print(m.weight)


if __name__ == "__main__":
    # Both Linear layers get weights ~ N(0, 0.02) and zero biases.
    netD = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    netD.apply(weights_init('gaussian'))

    # The nn.Module.apply documentation example.
    net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    net.apply(init_weights)
# APPLY -- methods named ``apply`` in the PyTorch API reference:
# - torch.jit.ScriptModule.apply (Python method, in ScriptModule)
# - torch.nn.Flatten.apply (Python method, in Flatten)
# - torch.nn.Module.apply (Python method, in Module)
# - torch.nn.Unflatten.apply (Python method, in Unflatten)
# - torch.nn.utils.prune.BasePruningMethod.apply (Python method, in BasePruningMethod)
# - torch.nn.utils.prune.CustomFromMask.apply (Python method, in CustomFromMask)
# - torch.nn.utils.prune.Identity.apply (Python method, in Identity)
# - torch.nn.utils.prune.L1Unstructured.apply (Python method, in L1Unstructured)
# - torch.nn.utils.prune.LnStructured.apply (Python method, in LnStructured)
# - torch.nn.utils.prune.PruningContainer.apply (Python method, in PruningContainer)
# - torch.nn.utils.prune.RandomStructured.apply (Python method, in RandomStructured)
# - torch.nn.utils.prune.RandomUnstructured.apply (Python method, in RandomUnstructured)
# Only the Module.apply entries (including those inherited by ScriptModule, Flatten
# and Unflatten) are the recursive helper described below; the torch.nn.utils.prune
# entries are unrelated classmethods that attach a pruning reparametrization to a
# module parameter.
# nn.Module.apply(fn): applies ``fn`` recursively to every submodule (as returned by
# ``.children()``) as well as to ``self``, and returns ``self``. Typical use includes
# initializing the parameters of a model (see also ``torch.nn.init``).
# 权重初始化 (weight initialization)
