import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import functools
from torch.autograd import grad as Grad
from torch.autograd import Function
import numpy as np
from math import sqrt
###############################################################################
def weights_init(init_type='gaussian'):
    """Build a per-module initializer for use with ``nn.Module.apply``.

    The returned callable is invoked by ``module.apply`` on every submodule
    (and on the module itself). It only touches modules whose class name
    starts with ``'Conv'`` or ``'Linear'`` and that have a ``weight``
    attribute. For those, the weight is re-initialized according to
    ``init_type`` and the bias, if present and not ``None``, is set to zero.
    All other modules are left untouched.

    Args:
        init_type: Name of the weight initialization scheme:
            ``'gaussian'``   -- ``init.normal_`` with mean 0.0, std 0.02;
            ``'xavier'``     -- ``init.xavier_normal_`` with gain sqrt(2);
            ``'kaiming'``    -- ``init.kaiming_normal_`` with a=0, mode='fan_in';
            ``'orthogonal'`` -- ``init.orthogonal_`` with gain sqrt(2);
            ``'default'``    -- keep the weight as constructed (bias is still
            zeroed).

    Returns:
        A callable ``init_fun(m)`` suitable for ``module.apply(init_fun)``.

    Raises:
        ValueError: raised by the returned callable, not by this factory,
            when ``init_type`` is unsupported and a matching Conv/Linear
            module is visited.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        # Prefix match, equivalent to find(...) == 0: covers Conv1d/Conv2d/
        # ConvTranspose2d/Linear, but not e.g. 'LazyLinear' or 'Bilinear'.
        if not (classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight')):
            return

        # torch.nn.init functions already run under torch.no_grad(), so the
        # parameters can be passed directly instead of via `.data`.
        if init_type == 'gaussian':
            init.normal_(m.weight, 0.0, 0.02)
        elif init_type == 'xavier':
            # `sqrt` comes from `from math import sqrt`; the module `math`
            # itself is not imported in this file.
            init.xavier_normal_(m.weight, gain=sqrt(2))
        elif init_type == 'kaiming':
            init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            init.orthogonal_(m.weight, gain=sqrt(2))
        elif init_type == 'default':
            pass
        else:
            # Explicit raise instead of `assert 0`, which is stripped under -O.
            raise ValueError("Unsupported initialization: {}".format(init_type))

        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias, 0.0)

    return init_fun
# Demo: a two-layer Linear stack, re-initialized with the Gaussian scheme.
# Module.apply returns the module itself, so the call can be chained.
netD = nn.Sequential(*(nn.Linear(2, 2) for _ in range(2))).apply(
    weights_init('gaussian')
)
>>> @torch.no_grad()
>>> def init_weights(m):
>>> print(m)
>>> if type(m) == nn.Linear:
>>> m.weight.fill_(1.0)
>>> print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
APPLY
- torch.jit.ScriptModule.apply (Python method, in ScriptModule)
- torch.nn.Flatten.apply (Python method, in Flatten)
- torch.nn.Module.apply (Python method, in Module)
- torch.nn.Unflatten.apply (Python method, in Unflatten)
- torch.nn.utils.prune.BasePruningMethod.apply (Python method, in BasePruningMethod)
- torch.nn.utils.prune.CustomFromMask.apply (Python method, in CustomFromMask)
- torch.nn.utils.prune.Identity.apply (Python method, in Identity)
- torch.nn.utils.prune.L1Unstructured.apply (Python method, in L1Unstructured)
- torch.nn.utils.prune.LnStructured.apply (Python method, in LnStructured)
- torch.nn.utils.prune.PruningContainer.apply (Python method, in PruningContainer)
- torch.nn.utils.prune.RandomStructured.apply (Python method, in RandomStructured)
- torch.nn.utils.prune.RandomUnstructured.apply (Python method, in RandomUnstructured)
Applies `fn` recursively to every submodule (as returned by `.children()`) as well as `self`.
Typical use includes initializing the parameters of a model (see also `torch.nn.init`).
权重初始化 (weight initialization)