Introduction to Module

[torch/nn/modules/module.py]
class Module:
    pass

The Module class is the base class for all neural network modules. For example, nn.Linear inherits from Module directly, while nn.Conv2d inherits from it indirectly via _ConvNd:

[torch/nn/modules/linear.py]
class Linear(Module):
    pass

[torch/nn/modules/conv.py]
class _ConvNd(Module):
    pass

class Conv2d(_ConvNd):
    pass
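
A custom network is built the same way: subclass Module, assign submodules in __init__, and implement forward(). A minimal sketch (TinyNet and the layer sizes are illustrative):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # attribute assignment registers these in self._modules (see __setattr__ below)
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

net = TinyNet()
print(net(torch.randn(3, 4)).shape)  # torch.Size([3, 2])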

Module attributes

Name                         Type            Description
training                     bool            Current mode: True for training, False for evaluation
_parameters                  OrderedDict     Ordered dict of registered parameters
_buffers                     OrderedDict     Module state that is not a parameter, e.g. running_mean in BatchNorm
_non_persistent_buffers_set  Set             Names of buffers registered with persistent=False (excluded from state_dict)
_backward_hooks              OrderedDict     Backward hooks, called when gradients w.r.t. module inputs are computed
_is_full_backward_hook       Optional[bool]  Whether the registered backward hooks are "full" hooks
_forward_hooks               OrderedDict     Forward hooks, called after forward()
_forward_pre_hooks           OrderedDict     Forward pre-hooks, called before forward()
_state_dict_hooks            OrderedDict     Hooks called after state_dict() is computed
_load_state_dict_pre_hooks   OrderedDict     Hooks called before state is loaded in load_state_dict()
_modules                     OrderedDict     Ordered dict of child modules

Module methods

Tensor registration

  • register_buffer

Adds a buffer to the Module.

def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    """Adds a buffer to the module.

    self.register_buffer('running_mean', torch.zeros(num_features))
    """
    pass
  • register_parameter

Adds a parameter to the Module.

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    """Adds a parameter to the module."""
    pass
  • add_module

Adds a child module to the current Module. A sketch covering all three registration methods follows this list.

def add_module(self, name: str, module: Optional['Module']) -> None:
    """Adds a child module to the current module."""
    pass
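
A quick sketch of how these three methods populate the internal dictionaries (the names registered here are illustrative):

import torch
import torch.nn as nn

m = nn.Module()
m.register_parameter('weight', nn.Parameter(torch.randn(2, 2)))
m.register_buffer('running_mean', torch.zeros(2))           # persistent by default
m.register_buffer('tmp', torch.zeros(2), persistent=False)  # kept out of state_dict
m.add_module('child', nn.Linear(2, 2))

print(list(m._parameters))   # ['weight']
print(list(m._buffers))      # ['running_mean', 'tmp']
print(list(m._modules))      # ['child']
print(list(m.state_dict()))  # no 'tmp': non-persistent buffers are skipped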

Type and device conversion

  • _apply

Recursively applies fn to the parameters and buffers of this module and of all submodules.

def _apply(self, fn):
    for module in self.children():
        module._apply(fn)  # recurse into submodules first
    for key, param in self._parameters.items():
        ...  # apply fn to each parameter (and its .grad)
    for key, buf in self._buffers.items():
        ...  # apply fn to each buffer
  • apply

Recursively applies fn to every submodule (as returned by .children()) and then to self. A typical use, parameter initialization, is sketched after this list.

def apply(self: T, fn: Callable[['Module'], None]) -> T:
    """Applies fn recursively to every submodule (as returned by .children())
        as well as self. Typical use includes initializing the parameters of a model.
    """
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
  • cuda

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
      """Moves all model parameters and buffers to the GPU."""
      return self._apply(lambda t: t.cuda(device))
    
  • float

    def float(self: T) -> T:
      r"""Casts all floating point parameters and buffers to float datatype."""
      return self._apply(lambda t: t.float() if t.is_floating_point() else t)
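
As the docstring of apply notes, its typical use is parameter initialization; the sketch below also shows that the _apply-based casts such as float() work on the same principle (the fill value is illustrative):

import torch
import torch.nn as nn

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)  # visits each child, then net itself
print(net[0].weight)     # all ones

net.double()                # casts parameters and buffers in place via _apply
print(net[0].weight.dtype)  # torch.float64
net.float()
print(net[0].weight.dtype)  # torch.float32
# net.cuda() moves them to a GPU the same way (requires a CUDA device)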
    

Forward and backward hooks

(hooks.py deserves a closer look; the use cases for hooks are covered in the next section.)

  • register_full_backward_hook

Registers a backward hook on the module; it is called when gradients with respect to the module's inputs are computed.

def register_full_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
    """Registers a backward hook on the module.

    The hook will be called every time the gradients with respect to module
    inputs are computed.
    """
    pass
  • register_forward_pre_hook

Registers a forward pre-hook on the module; it is called before forward().

def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
    """Registers a forward pre-hook on the module."""
    pass
  • register_forward_hook

Registers a forward hook on the module; it is called after forward() has returned. A usage sketch follows the _call_impl walkthrough below.

def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
    """Registers a forward hook on the module."""
    pass
  • _call_impl

Module aliases __call__ to _call_impl (__call__ = _call_impl in the class body), so calling module(input) runs the forward pre-hooks, forward(), the forward hooks, and the backward-hook setup, in that order:

def _call_impl(self, *input, **kwargs):
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result

    bw_hook = None
    if len(full_backward_hooks) > 0:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if len(non_full_backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result
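
A common use of forward hooks is capturing intermediate activations; the sketch below also makes the ordering enforced by _call_impl visible (the activations dict and save_activation are illustrative names):

import torch
import torch.nn as nn

activations = {}

def save_activation(module, inputs, output):
    activations['fc1'] = output.detach()

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
net[0].register_forward_pre_hook(lambda mod, inp: print('pre-hook runs before forward'))
handle = net[0].register_forward_hook(save_activation)

net(torch.randn(1, 4))           # prints the pre-hook message, then runs the forward hook
print(activations['fc1'].shape)  # torch.Size([1, 8])
handle.remove()                  # the returned RemovableHandle unregisters the hook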

Attribute access, assignment, and deletion

  • __getattr__

Gets an attribute. __getattr__ is invoked only when normal attribute lookup fails; Module uses it to resolve the name from _parameters, _buffers, and _modules.

def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
    pass
  • __setattr__

Sets an attribute. Module intercepts assignment: a Parameter is routed into _parameters, a Module into _modules, and a Tensor assigned to a registered buffer name updates _buffers (see the sketch after this list).

def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    ...
    # Given an instance A, you must not write A.name = value inside __setattr__,
    # or __setattr__ would recurse into itself; see
    # https://stackoverflow.com/questions/38682318/why-favor-object-setattr-self-name-value-only-in-new-style-classes
    object.__setattr__(self, name, value)
  • __delattr__

Deletes an attribute, removing it from _parameters, _buffers, or _modules as appropriate.

def __delattr__(self, name):
    ...
    # Same reason as in __setattr__
    object.__delattr__(self, name)
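
This interception is what makes plain attribute assignment and deletion register and unregister things; a quick sketch:

import torch
import torch.nn as nn

m = nn.Module()
m.w = nn.Parameter(torch.randn(2))  # routed into m._parameters
m.sub = nn.Linear(2, 2)             # routed into m._modules
m.plain = torch.randn(2)            # a plain Tensor stays an ordinary attribute

print(list(m._parameters))    # ['w']
print(list(m._modules))       # ['sub']
print('plain' in m._buffers)  # False

del m.sub                   # __delattr__ removes it from m._modules
print(list(m._modules))     # []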

Saving and loading model state

  • _register_state_dict_hook

Registers a hook that is called after the module's state_dict has been computed.

    def _register_state_dict_hook(self, hook):
      pass
    
  • state_dict

_save_to_state_dict saves only the current module's own state, i.e. its parameters and persistent buffers, not the state of its descendants. The resulting key layout is sketched after this list.

def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary, containing a state
    of the module, but not its descendants. This is called on every
    submodule in :meth:`~torch.nn.Module.state_dict`.
    """
    pass

def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    """
    ...
    # Save the current module's own state
    self._save_to_state_dict(destination, prefix, keep_vars)

    # Recurse into child modules, extending the key prefix
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
    ...
  • _register_load_state_dict_pre_hook

Registers a hook that is called before the module's state is loaded in load_state_dict.

    def _register_load_state_dict_pre_hook(self, hook):
      pass
    
  • load_state_dict

_load_from_state_dict loads only the current module's own state, i.e. its parameters and buffers.

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
    pass

def load_state_dict(self, state_dict: Dict[str, Tensor], strict: bool = True):
    """Copies parameters and buffers from state_dict into
    this module and its descendants. If strict is True, then
    the keys of state_dict must exactly match the keys returned
    by this module's ~torch.nn.Module.state_dict function.
    """
    pass
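
The prefix mechanism gives nested modules dotted key names, and the usual checkpoint round trip is built on these two methods (the file name is illustrative):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
print(list(net.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias']

torch.save(net.state_dict(), 'checkpoint.pt')

net2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
missing, unexpected = net2.load_state_dict(torch.load('checkpoint.pt'))
print(missing, unexpected)  # both empty when the keys match exactly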

Iteration

  • _named_members

    def _named_members(self, get_members_fn, prefix='', recurse=True):
      """Helper method for yielding various names + members of modules."""
      memo = set()
      modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
      for module_prefix, module in modules:
          members = get_members_fn(module)
          for k, v in members:
              if v is None or v in memo:
                  continue
              memo.add(v)
              name = module_prefix + ('.' if module_prefix else '') + k
              yield name, v
    
  • parameters

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
      """Returns an iterator over module parameters."""
      pass
    
  • named_parameters

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
      """Returns an iterator over module parameters, yielding both the
      name of the parameter as well as the parameter itself.
      """
      pass
    
  • buffers

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
      """Returns an iterator over module buffers."""
      pass
    
  • named_buffers

    def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
      """Returns an iterator over module buffers, yielding both the
      name of the buffer as well as the buffer itself.
      """
      pass
    
  • children

    def children(self) -> Iterator['Module']:
      """Returns an iterator over immediate children modules."""
      pass
    
  • named_children

    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
      """Returns an iterator over immediate children modules, yielding both
      the name of the module as well as the module itself.
      """
      pass
    
  • modules

    def modules(self) -> Iterator['Module']:
      """Returns an iterator over all modules in the network."""
      pass
    
  • named_modules

    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
      """Returns an iterator over all modules in the network, yielding
      both the name of the module as well as the module itself."""
      pass
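
A sketch of the different iterators on a small network:

import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))

for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# 0.weight (2, 2) / 0.bias (2,) / 1.weight (2,) / 1.bias (2,)

for name, b in net.named_buffers():
    print(name)
# 1.running_mean / 1.running_var / 1.num_batches_tracked

print([name for name, _ in net.named_modules()])
# ['', '0', '1']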
    

Switching module mode

  • train

    def train(self: T, mode: bool = True) -> T:
      """Sets the module in training mode."""
      self.training = mode
      for module in self.children():
          module.train(mode)
      return self
    
  • eval

    def eval(self: T) -> T:
      """Sets the module in evaluation mode."""
      return self.train(False)
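
Because train() recurses into children, a single call flips the flag for the whole tree; this matters for layers such as Dropout and BatchNorm:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
net.eval()  # equivalent to net.train(False)
print(net.training, net[1].training)  # False False

x = torch.ones(1, 4)
# in eval mode Dropout is the identity, so repeated calls agree
assert torch.equal(net(x), net(x))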
    

Gradient operations

  • requires_grad_

    def requires_grad_(self: T, requires_grad: bool = True) -> T:
      """Change if autograd should record operations on parameters in this
      module.
      """
      pass
    
  • zero_grad

    def zero_grad(self, set_to_none: bool = False) -> None:
      r"""Sets gradients of all model parameters to zero. See similar function
      under :class:`torch.optim.Optimizer` for more context.
    
      Args:
          set_to_none (bool): instead of setting to zero, set the grads to None.
              See :meth:`torch.optim.Optimizer.zero_grad` for details.
      """
      if getattr(self, '_is_replica', False):
          warnings.warn(
              "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
              "The parameters are copied (in a differentiable manner) from the original module. "
              "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
              "If you need gradients in your forward method, consider using autograd.grad instead.")
    
      for p in self.parameters():
          if p.grad is not None:
              if set_to_none:
                  p.grad = None
              else:
                  if p.grad.grad_fn is not None:
                      # Detach the gradient from the graph that created it, making it a leaf tensor
                      p.grad.detach_()
                  else:
                      p.grad.requires_grad_(False)
                  p.grad.zero_()
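
A common use of requires_grad_ is freezing part of a model, e.g. for fine-tuning; a sketch:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net[0].requires_grad_(False)  # freeze the first layer

net(torch.randn(1, 2)).sum().backward()
print(net[0].weight.grad)              # None: no gradient accumulated
print(net[1].weight.grad is not None)  # True

net.zero_grad()  # clear the remaining gradients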
    

Miscellaneous

  • _replicate_for_data_parallel

Creates a replica of the module for data-parallel execution (used internally by nn.DataParallel).

def _replicate_for_data_parallel(self):
    replica = self.__new__(type(self))
    replica.__dict__ = self.__dict__.copy()

    # replicas do not have parameters themselves, the replicas reference the original
    # module.
    replica._parameters = OrderedDict()
    replica._buffers = replica._buffers.copy()
    replica._modules = replica._modules.copy()
    replica._is_replica = True

    return replica
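
Calling this internal method directly illustrates the comment above (normally nn.DataParallel does this for you during forward):

import torch.nn as nn

model = nn.Linear(4, 2)
replica = model._replicate_for_data_parallel()
print(len(model._parameters))    # 2: weight and bias
print(len(replica._parameters))  # 0: replicas hold no parameters of their own
print(replica._is_replica)       # True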