Introduction to Module
[torch/nn/modules/module.py]
class Module:
    pass
The Module class is the base class for all neural network modules. For example, nn.Linear inherits from Module directly, while nn.Conv2d inherits from Module indirectly (through _ConvNd).
[torch/nn/modules/linear.py]
class Linear(Module):
    pass
[torch/nn/modules/conv.py]
class _ConvNd(Module):
    pass

class Conv2d(_ConvNd):
    pass
Module attributes
Attribute | Type | Description |
---|---|---|
training | bool | The module's current mode: True means training mode, False means evaluation mode |
_parameters | OrderedDict | Ordered dict storing the module's parameters |
_buffers | OrderedDict | Stores module state that is not a parameter, e.g. running_mean in BatchNorm |
_non_persistent_buffers_set | Set | Names of buffers registered with persistent=False; these are excluded from state_dict |
_backward_hooks | OrderedDict | Backward hooks |
_is_full_backward_hook | Optional[bool] | Whether the registered backward hooks are full hooks (register_full_backward_hook) or legacy ones (register_backward_hook) |
_forward_hooks | OrderedDict | Forward hooks, called after forward() |
_forward_pre_hooks | OrderedDict | Forward pre-hooks, called before forward() |
_state_dict_hooks | OrderedDict | Hooks called after the module's state is collected in state_dict() |
_load_state_dict_pre_hooks | OrderedDict | Hooks called before the module's state is loaded in _load_from_state_dict() |
_modules | OrderedDict | Ordered dict storing child modules |
Module methods
Registration
- register_buffer
Adds a buffer to the Module.
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    """Adds a buffer to the module.

    self.register_buffer('running_mean', torch.zeros(num_features))
    """
    pass
- register_parameter
Adds a parameter to the Module.
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    """Adds a parameter to the module."""
    pass
- add_module
Adds a child module to the current Module.
def add_module(self, name: str, module: Optional['Module']) -> None:
    """Adds a child module to the current module."""
    pass
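A minimal sketch (the MyModule class is illustrative) of how the three registration paths populate _parameters, _buffers, and _modules:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # stored in self._parameters, appears in state_dict()
        self.register_parameter('weight', nn.Parameter(torch.randn(3, 3)))
        # stored in self._buffers, appears in state_dict()
        self.register_buffer('running_mean', torch.zeros(3))
        # persistent=False adds the name to _non_persistent_buffers_set, so it is left out of state_dict()
        self.register_buffer('tmp', torch.zeros(3), persistent=False)
        # stored in self._modules; equivalent to self.linear = nn.Linear(3, 3)
        self.add_module('linear', nn.Linear(3, 3))

m = MyModule()
print(list(m.state_dict().keys()))  # ['weight', 'running_mean', 'linear.weight', 'linear.bias']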
Type and device conversion
- _apply
Recursively applies fn to the module's parameters and buffers (and those of its children).
def _apply(self, fn):
    # recurse into child modules first
    for module in self.children():
        module._apply(fn)
    # apply fn to every parameter, e.g. to move or cast it
    for key, param in self._parameters.items():
        ...
    # apply fn to every buffer
    for key, buf in self._buffers.items():
        ...
- apply
Recursively applies fn to every child module and to self.
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    """Applies fn recursively to every submodule (as returned by .children())
    as well as self. Typical use includes initializing the parameters of a model.
    """
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
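As the docstring notes, a typical use of apply is parameter initialization; a small sketch (init_weights is an illustrative helper):

import torch.nn as nn

def init_weights(m):
    # apply() visits every submodule (depth-first) and finally self
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
net.apply(init_weights)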
cuda
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    """Moves all model parameters and buffers to the GPU."""
    return self._apply(lambda t: t.cuda(device))
float
def float(self: T) -> T:
    r"""Casts all floating point parameters and buffers to float datatype."""
    return self._apply(lambda t: t.float() if t.is_floating_point() else t)
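Both methods are thin wrappers over _apply and return self, so calls can be chained; a quick usage sketch:

import torch
import torch.nn as nn

model = nn.Linear(3, 3).double().float()   # cast parameters and buffers back to float32
if torch.cuda.is_available():
    model = model.cuda()                   # move parameters and buffers to the default GPU
print(next(model.parameters()).dtype)      # torch.float32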
Forward and backward hooks
hooks.py deserves a closer look; use cases for hooks will be covered in the next section.
- register_full_backward_hook
Registers a backward hook on the module; it is called whenever gradients with respect to the module's inputs are computed.
def register_full_backward_hook(
    self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle:
    """Registers a backward hook on the module.

    The hook will be called every time the gradients with respect to module
    inputs are computed.
    """
    pass
- register_forward_pre_hook
Registers a forward pre-hook on the module; it is called before forward().
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
    """Registers a forward pre-hook on the module."""
    pass
- register_forward_hook
Registers a forward hook on the module; it is called after forward() has completed.
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
    """Registers a forward hook on the module."""
    pass
- _call_impl
Module binds __call__ to _call_impl, so calling a module instance (e.g. model(x)) runs the hook machinery below around forward().
__call__ = _call_impl
def _call_impl(self, *input, **kwargs):
    # Do not call functions when jit is used
    # Collect full (new-style) and legacy backward hooks, both module-local and global
    full_backward_hooks, non_full_backward_hooks = [], []
    if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    # Forward pre-hooks run before forward(); a non-None return value replaces the input
    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result

    # Full backward hooks are attached to the inputs before forward() runs
    bw_hook = None
    if len(full_backward_hooks) > 0:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    # Run forward() (or the slow path when tracing with jit)
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)

    # Forward hooks run after forward(); a non-None return value replaces the output
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full (legacy) backward hooks by registering them on the output's grad_fn
    if len(non_full_backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
    return result
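Before that, a quick usage sketch of the three register_* methods; _call_impl above is where each of them fires (the printed shapes are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

def pre_hook(module, inputs):
    print('before forward:', [t.shape for t in inputs])

def fwd_hook(module, inputs, output):
    print('after forward:', output.shape)

def bwd_hook(module, grad_input, grad_output):
    print('during backward:', [g.shape for g in grad_output if g is not None])

handles = [
    model.register_forward_pre_hook(pre_hook),
    model.register_forward_hook(fwd_hook),
    model.register_full_backward_hook(bwd_hook),
]

model(torch.randn(8, 4)).sum().backward()

# each register_* call returns a RemovableHandle; remove() detaches the hook
for h in handles:
    h.remove()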
Attribute get/set/delete
- __getattr__
Gets an attribute; when normal lookup fails, Module looks the name up in _parameters, _buffers, and _modules.
def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
    pass
- __setattr__
Sets an attribute; Parameter values are routed into _parameters, Module values into _modules, and names already registered as buffers into _buffers (see the sketch after this list).
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    ...
    # If the instance is A, do not write A.name = value inside __setattr__,
    # otherwise __setattr__ calls itself recursively; see
    # https://stackoverflow.com/questions/38682318/why-favor-object-setattr-self-name-value-only-in-new-style-classes
    object.__setattr__(self, name, value)
- __delattr__
Deletes an attribute, removing it from _parameters, _buffers, or _modules as appropriate.
def __delattr__(self, name):
    ...
    # same reason as in __setattr__
    object.__delattr__(self, name)
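In practice, this __setattr__ override is what makes plain attribute assignment register parameters and submodules automatically; a minimal sketch (the Net class is illustrative):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(2))  # routed into self._parameters
        self.fc = nn.Linear(2, 2)              # routed into self._modules
        self.scale = 2.0                       # ordinary attribute, kept in __dict__

net = Net()
print(list(net._parameters))  # ['w']
print(list(net._modules))     # ['fc']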
Saving and loading model state
_register_state_dict_hook
def _register_state_dict_hook(self, hook):
    pass
state_dict
_save_to_state_dict saves only the current module's own state, i.e. its parameters and buffers.
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary, containing a state
    of the module, but not its descendants. This is called on every
    submodule in :meth:`~torch.nn.Module.state_dict`.
    """
    pass
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    """
    ...
    # save the current module's own state (parameters and persistent buffers)
    self._save_to_state_dict(destination, prefix, keep_vars)
    # recurse into child modules, extending the key prefix with each child's name
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
    ...
_register_load_state_dict_pre_hook
def _register_load_state_dict_pre_hook(self, hook):
    pass
load_state_dict
_load_from_state_dict loads only the current module's own state, i.e. its parameters and buffers.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    pass
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
                    strict: bool = True):
    """Copies parameters and buffers from state_dict into
    this module and its descendants. If strict is True, then
    the keys of state_dict must exactly match the keys returned
    by this module's ~torch.nn.Module.state_dict function.
    """
    pass
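A usage sketch of the save/load round trip (the file name linear.pt is illustrative); load_state_dict returns the missing and unexpected keys:

import torch
import torch.nn as nn

model = nn.Linear(3, 2)
torch.save(model.state_dict(), 'linear.pt')   # parameters and persistent buffers only

restored = nn.Linear(3, 2)
missing, unexpected = restored.load_state_dict(torch.load('linear.pt'), strict=True)
print(missing, unexpected)                    # both empty when the keys match exactly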
Iteration
_named_members
def _named_members(self, get_members_fn, prefix='', recurse=True):
    """Helper method for yielding various names + members of modules."""
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v
parameters
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    """Returns an iterator over module parameters."""
    pass
named_parameters
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    """Returns an iterator over module parameters, yielding both the name of the
    parameter as well as the parameter itself.
    """
    pass
buffers
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
    """Returns an iterator over module buffers."""
    pass
named_buffers
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    """Returns an iterator over module buffers, yielding both the name of the
    buffer as well as the buffer itself.
    """
    pass
children
def children(self) -> Iterator['Module']:
    """Returns an iterator over immediate children modules."""
    pass
named_children
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    """Returns an iterator over immediate children modules, yielding both the
    name of the module as well as the module itself.
    """
    pass
modules
def modules(self) -> Iterator['Module']:
    """Returns an iterator over all modules in the network."""
    pass
named_modules
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
    """Returns an iterator over all modules in the network, yielding both the
    name of the module as well as the module itself.
    """
    pass
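A quick sketch of the naming scheme these iterators produce (keys are built in _named_members by joining the module prefix and the member name with '.'):

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)

print([name for name, _ in net.named_modules()])
# ['', '0', '1', '2']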
Mode switching (train/eval)
train
def train(self: T, mode: bool = True) -> T:
    """Sets the module in training mode."""
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
eval
def eval(self: T) -> T:
    """Sets the module in evaluation mode."""
    return self.train(False)
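train() and eval() only flip the training flag recursively; modules such as Dropout and BatchNorm read it in their forward(). A small sketch:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(4, 4))
x = torch.randn(1, 4)

model.train()        # sets training = True on every submodule
y_train = model(x)   # dropout is active

model.eval()         # equivalent to model.train(False)
y_eval = model(x)    # dropout becomes a no-op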
Gradient operations
requires_grad_
def requires_grad_(self: T, requires_grad: bool = True) -> T:
    """Change if autograd should record operations on parameters in this module."""
    pass
zero_grad
def zero_grad(self, set_to_none: bool = False) -> None:
    r"""Sets gradients of all model parameters to zero. See similar function
    under :class:`torch.optim.Optimizer` for more context.

    Args:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            See :meth:`torch.optim.Optimizer.zero_grad` for details.
    """
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    # detach the tensor from the graph that created it, making it a leaf tensor
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()
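A sketch of the two methods together, freezing one layer and then clearing gradients (the two-layer model is illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
model[0].requires_grad_(False)        # autograd stops recording ops on the first layer's parameters

model(torch.randn(2, 4)).sum().backward()
print(model[0].weight.grad)           # None: the frozen layer accumulates no gradient
print(model[1].weight.grad.shape)     # torch.Size([2, 8])

model.zero_grad(set_to_none=True)     # clear accumulated gradients before the next step
print(model[1].weight.grad)           # None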
Miscellaneous
_replicate_for_data_parallel
Creates a replica of the model for data-parallel execution.
def _replicate_for_data_parallel(self):
    replica = self.__new__(type(self))
    replica.__dict__ = self.__dict__.copy()

    # replicas do not have parameters themselves, the replicas reference the original
    # module.
    replica._parameters = OrderedDict()
    replica._buffers = replica._buffers.copy()
    replica._modules = replica._modules.copy()
    replica._is_replica = True

    return replica
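For context, nn.DataParallel's replication step is what calls _replicate_for_data_parallel() on each module; a hedged usage sketch, assuming more than one GPU is available:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
if torch.cuda.device_count() > 1:
    # replication across devices goes through _replicate_for_data_parallel()
    model = nn.DataParallel(model.cuda())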