• BaseRunner
  • Runner
  • EpochBasedRunner
  • IterBasedRunner
  • HOOKS
  • Hook
  • CheckpointHook
  • ClosureHook

    BaseRunner

    The Runner base class is defined as an abstract class using ABCMeta and abstractmethod, imported via from abc import ABCMeta, abstractmethod.

Methods that subclasses must override

All subclasses should implement the following APIs:

  • run()
  • train()
  • val()
  • save_checkpoint()
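
The effect of ABCMeta and abstractmethod on these four methods can be sketched as follows. This is a minimal illustration, not the MMCV source: ToyBaseRunner, IncompleteRunner, CompleteRunner, and try_instantiate are hypothetical names, and the method signatures are assumptions chosen for the example.

```python
from abc import ABCMeta, abstractmethod


class ToyBaseRunner(metaclass=ABCMeta):
    """Hypothetical stand-in for BaseRunner; not the MMCV class."""

    def __init__(self, model, max_epochs=None):
        self.model = model
        self.max_epochs = max_epochs

    @abstractmethod
    def run(self, data_loaders, workflow):
        """Drive the whole training/validation schedule."""

    @abstractmethod
    def train(self, data_loader):
        """Run one training epoch."""

    @abstractmethod
    def val(self, data_loader):
        """Run one validation epoch."""

    @abstractmethod
    def save_checkpoint(self, out_dir):
        """Persist the current state."""


class IncompleteRunner(ToyBaseRunner):
    # Overrides only one of the four abstract methods.
    def run(self, data_loaders, workflow):
        return "running"


class CompleteRunner(ToyBaseRunner):
    # Overrides all four abstract methods, so it can be instantiated.
    def run(self, data_loaders, workflow):
        return "running"

    def train(self, data_loader):
        return "training"

    def val(self, data_loader):
        return "validating"

    def save_checkpoint(self, out_dir):
        return "saved"


def try_instantiate(cls):
    """Return True if cls can be instantiated, False if ABCMeta blocks it."""
    try:
        cls(model=None)
    except TypeError:
        return False
    return True
```

Only CompleteRunner can be instantiated here. The base class and the partially implemented subclass both raise TypeError at construction time, which is how an abstract base forces every subclass to provide all four methods.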

Args (constructor parameters)

  • model (:obj:torch.nn.Module): The model to be run.
  • batch_processor (callable): A callable that processes a data
    batch. The interface of this method should be
    batch_processor(model, data, train_mode) -> dict
  • optimizer (dict or :obj:torch.optim.Optimizer): It can be either an
    optimizer (in most cases) or a dict of optimizers (in models that
    require more than one optimizer, e.g., GAN).
  • work_dir (str, optional): The working directory to save checkpoints
    and logs. Defaults to None.
  • logger (:obj:logging.Logger): Logger used during training.
    Defaults to None. (The default value is just for backward
    compatibility)
  • meta (dict | None): A dict that records important information such as
    environment info and seed, which will be logged by the logger hook.
    Defaults to None.
  • max_epochs (int, optional): Total training epochs.
  • max_iters (int, optional): Total training iterations.
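
To make the batch_processor(model, data, train_mode) -> dict interface concrete, here is a small pure-Python sketch. The toy_model function, the "x"/"y" keys of the batch, and the "loss", "log_vars", and "num_samples" keys of the returned dict are assumptions made for this example; the parameter list above only fixes the call signature and the dict return type.

```python
def toy_model(x, weight=2.0):
    """Hypothetical 'model': scales each input by a fixed weight."""
    return [weight * v for v in x]


def batch_processor(model, data, train_mode):
    """Process one batch and return a dict, matching the documented interface.

    The keys of `data` and of the returned dict are assumptions for this sketch.
    """
    preds = model(data["x"])
    # Mean squared error between predictions and targets.
    errors = [(p - t) ** 2 for p, t in zip(preds, data["y"])]
    loss = sum(errors) / len(errors)
    return {
        "loss": loss,
        "log_vars": {"loss": loss, "mode": "train" if train_mode else "val"},
        "num_samples": len(data["x"]),
    }


batch = {"x": [1.0, 2.0], "y": [2.0, 5.0]}
outputs = batch_processor(toy_model, batch, train_mode=True)
print(outputs["loss"], outputs["num_samples"])  # 0.5 2
```

A runner would call such a function once per batch and pass train_mode=False during validation, so a single callable can serve both phases.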

EpochBasedRunner

run_iter()

train()

val()

run()

save_checkpoint()
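
The five methods listed above fit together roughly as in the sketch below. It is a simplified illustration of an epoch-based control flow, not the MMCV implementation: ToyEpochRunner, count_processor, the saved list, and the list-based data loaders are all hypothetical, and hooks are left out, so save_checkpoint() is called directly from run().

```python
class ToyEpochRunner:
    """Simplified illustration; not the MMCV EpochBasedRunner."""

    def __init__(self, model, batch_processor, max_epochs):
        self.model = model
        self.batch_processor = batch_processor
        self.max_epochs = max_epochs
        self.epoch = 0    # completed training epochs
        self.iter = 0     # completed training iterations
        self.outputs = None
        self.saved = []   # stands in for checkpoint files on disk

    def run_iter(self, data_batch, train_mode):
        # Process a single batch through the user-supplied callable.
        self.outputs = self.batch_processor(self.model, data_batch, train_mode)

    def train(self, data_loader):
        for data_batch in data_loader:
            self.run_iter(data_batch, train_mode=True)
            self.iter += 1
        self.epoch += 1

    def val(self, data_loader):
        for data_batch in data_loader:
            self.run_iter(data_batch, train_mode=False)

    def save_checkpoint(self, filename_tmpl="epoch_{}.pth"):
        # A real runner would serialize model and optimizer state here.
        self.saved.append(filename_tmpl.format(self.epoch))

    def run(self, data_loaders, workflow):
        # workflow pairs a mode name with how many epochs to run it for,
        # e.g. [("train", 1), ("val", 1)].
        while self.epoch < self.max_epochs:
            for loader, (mode, epochs) in zip(data_loaders, workflow):
                for _ in range(epochs):
                    if mode == "train" and self.epoch >= self.max_epochs:
                        break
                    getattr(self, mode)(loader)
                    if mode == "train":
                        self.save_checkpoint()


def count_processor(model, data, train_mode):
    """Hypothetical batch processor that only reports batch size and mode."""
    return {"num_samples": len(data), "train_mode": train_mode}


runner = ToyEpochRunner(model=None, batch_processor=count_processor, max_epochs=2)
train_loader = [[1, 2], [3, 4], [5, 6]]  # three batches
val_loader = [[7, 8]]                    # one batch
runner.run([train_loader, val_loader], [("train", 1), ("val", 1)])
print(runner.epoch, runner.iter, runner.saved)
# 2 6 ['epoch_1.pth', 'epoch_2.pth']
```

In this sketch only train() advances the epoch and iteration counters, so validation passes do not count toward max_epochs.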

References

https://zhuanlan.zhihu.com/p/268571921