- BaseRunner
- Runner
- EpochBasedRunner
- IterBasedRunner
- HOOKS
- Hook
- CheckpointHook
- ClosureHook
BaseRunner
The Runner base class is defined as an abstract class using ABCMeta and abstractmethod from the abc module (from abc import ABCMeta, abstractmethod).
Methods that must be overridden
All subclasses should implement the following APIs:
- run()
- train()
- val()
- save_checkpoint()
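The following is a minimal, self-contained sketch of how ABCMeta and abstractmethod enforce the four required APIs. The class names ToyRunner and IncompleteRunner are hypothetical. The method signatures are simplified assumptions and are not the real mmcv signatures. A subclass that leaves any abstract method unimplemented cannot be instantiated.

```python
from abc import ABCMeta, abstractmethod


class BaseRunner(metaclass=ABCMeta):
    """Sketch of an abstract runner (simplified, hypothetical signatures)."""

    def __init__(self, model, max_epochs=None, max_iters=None):
        self.model = model
        self._max_epochs = max_epochs
        self._max_iters = max_iters

    # Every method below must be overridden before a subclass can be created.
    @abstractmethod
    def run(self, data_loader):
        pass

    @abstractmethod
    def train(self, data_loader):
        pass

    @abstractmethod
    def val(self, data_loader):
        pass

    @abstractmethod
    def save_checkpoint(self, out_dir):
        pass


class ToyRunner(BaseRunner):
    """Hypothetical concrete runner that implements all four abstract methods."""

    def __init__(self, model, max_epochs=1):
        super().__init__(model, max_epochs=max_epochs)
        self.history = []  # records each step, for illustration only

    def run(self, data_loader):
        # Alternate one training pass and one validation pass per epoch.
        for _ in range(self._max_epochs):
            self.train(data_loader)
            self.val(data_loader)

    def train(self, data_loader):
        outputs = [self.model(x) for x in data_loader]
        self.history.append(("train", sum(outputs)))

    def val(self, data_loader):
        outputs = [self.model(x) for x in data_loader]
        self.history.append(("val", sum(outputs)))

    def save_checkpoint(self, out_dir):
        # A real runner would serialize model/optimizer state here.
        self.history.append(("save", out_dir))


class IncompleteRunner(BaseRunner):
    """Hypothetical subclass that only implements run(); instantiating it fails."""

    def run(self, data_loader):
        pass
```

Calling IncompleteRunner(model=None) raises TypeError, because train, val and save_checkpoint are still abstract. ToyRunner(model=lambda x: x * 2, max_epochs=2) can be created and run normally.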
Args
- model (:obj:`torch.nn.Module`): The model to be run.
- batch_processor (callable): A callable method that processes a data batch. The interface of this method should be `batch_processor(model, data, train_mode) -> dict`.
- optimizer (dict or :obj:`torch.optim.Optimizer`): Either an optimizer (in most cases) or a dict of optimizers (for models that require more than one optimizer, e.g., GANs).
- work_dir (str, optional): The working directory to save checkpoints and logs. Defaults to None.
- logger (:obj:`logging.Logger`): Logger used during training. Defaults to None. (The default value exists only for backward compatibility.)
- meta (dict | None): A dict that records important information, such as environment info and seed, which will be logged by the logger hook. Defaults to None.
- max_epochs (int, optional): Total training epochs.
- max_iters (int, optional): Total training iterations.
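The next sketch shows how these arguments might fit together. It uses plain Python instead of torch so it stays self-contained. RunnerArgsSketch and the toy batch_processor are hypothetical.

The batch_processor follows the documented interface batch_processor(model, data, train_mode) -> dict. The returned keys (loss, log_vars, num_samples) are an assumption chosen for illustration. The check that max_epochs and max_iters are not both set is also an assumption of this sketch.

```python
import logging


def batch_processor(model, data, train_mode):
    """Hypothetical batch_processor: (model, data, train_mode) -> dict."""
    inputs, targets = data
    # This toy model behaves the same in train and eval mode, so train_mode is unused.
    preds = [model(x) for x in inputs]
    # Mean squared error as a stand-in loss.
    loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(inputs)
    return dict(loss=loss, log_vars={"loss": loss}, num_samples=len(inputs))


class RunnerArgsSketch:
    """Hypothetical holder that validates and stores the documented Args."""

    def __init__(self, model, batch_processor=None, optimizer=None,
                 work_dir=None, logger=None, meta=None,
                 max_epochs=None, max_iters=None):
        if batch_processor is not None and not callable(batch_processor):
            raise TypeError("batch_processor must be callable")
        if meta is not None and not isinstance(meta, dict):
            raise TypeError("meta must be a dict or None")
        # Assumption: a run is limited either by epochs or by iterations, not both.
        if max_epochs is not None and max_iters is not None:
            raise ValueError("Only one of max_epochs or max_iters can be set")
        self.model = model
        self.batch_processor = batch_processor
        # Either a single optimizer object or a dict of them (e.g. for a GAN).
        self.optimizer = optimizer
        self.work_dir = work_dir
        # Fallback logger when none is given (a choice made for this sketch).
        self.logger = logger or logging.getLogger(__name__)
        self.meta = meta
        self.max_epochs = max_epochs
        self.max_iters = max_iters

    def step(self, data, train_mode=True):
        """Run one batch through batch_processor and check the result type."""
        if self.batch_processor is None:
            raise RuntimeError("no batch_processor was given")
        outputs = self.batch_processor(self.model, data, train_mode=train_mode)
        if not isinstance(outputs, dict):
            raise TypeError("batch_processor must return a dict")
        return outputs
```

For example, take model = lambda x: 2 * x with the batch ([1, 2], [2, 5]). The predictions are [2, 4], so runner.step(...) returns a loss of 0.5 over 2 samples. A GAN-style setup would pass optimizer={"generator": ..., "discriminator": ...} instead of a single optimizer.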