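A config, such as a dictionary, is turned into an instance (or run) through `build`.
The main benefits of the Registry class are strong decoupling, good extensibility, and code that is easier to understand.
In OpenMMLab, the Registry class provides a uniform decorator interface for managing and building different components, such as backbones, heads, and necks. Internally, a Registry maintains a global key-value mapping. Through the Registry class, users can instantiate any module they want from a string.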
MMCV implements registry to manage different modules that share similar functionalities, e.g., backbones, heads, and necks, in detectors.
Most projects in OpenMMLab use registry to manage modules of datasets and models, such as MMDetection, MMDetection3D, MMClassification, MMEditing, etc.
What is registry
In MMCV, registry can be regarded as a mapping that maps a string to a class.
These classes contained by a single registry usually have similar APIs but implement different algorithms or support different datasets.
With the registry, users can find and instantiate the class through its corresponding string, and use the instantiated module as they want.
One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs.
The API reference can be found in the MMCV documentation.
To manage your modules in the codebase with `Registry`, there are three steps as below.
- Create a build method (optional, in most cases you can just use the default one).
- Create a registry.
- Use this registry to manage the modules.
The `build_func` argument of `Registry` customizes how a class instance is created. The default one is `build_from_cfg`, implemented in MMCV.
A Simple Example
Here we show a simple example of using registry to manage modules in a package.
You can find more practical examples in OpenMMLab projects.
Assume we want to implement a series of dataset converters that convert data in different formats to the expected data format.
We create a directory as a package named `converters`.
In the package, we first create a file to implement builders, named `converters/builder.py`, as below:
from mmcv.utils import Registry
# create a registry for converters
CONVERTERS = Registry('converter')
Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py`:
from .builder import CONVERTERS
# use the registry to manage the module
@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
The key step in using the registry to manage modules is to register the implemented module into the registry `CONVERTERS` through `@CONVERTERS.register_module()` when you create the module. In this way, a mapping between a string and the class is built and maintained by `CONVERTERS`, as below:
'Converter1' -> <class 'Converter1'>
If the module is successfully registered, you can use this converter through configs as
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = CONVERTERS.build(converter_cfg)
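To make the string-to-class mapping concrete without installing MMCV, the sketch below mimics the register and build steps with a plain dict. `REGISTRY`, `register`, and `build` are hypothetical stand-ins for illustration, not MMCV APIs, and the values `a=1`, `b=2` are made up. The real `Registry` does more than this (parent registries, custom build functions).

```python
# Hypothetical stand-in for a registry: a plain dict from name to class.
REGISTRY = {}


def register(cls):
    """Store the class under its own name and return it unchanged."""
    REGISTRY[cls.__name__] = cls
    return cls


def build(cfg):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    kwargs = dict(cfg)  # copy so the caller's config is left untouched
    cls = REGISTRY[kwargs.pop('type')]
    return cls(**kwargs)


@register
class Converter1:
    def __init__(self, a, b):
        self.a = a
        self.b = b


converter_cfg = dict(type='Converter1', a=1, b=2)
converter = build(converter_cfg)
print(type(converter).__name__, converter.a, converter.b)
```

Because `build` works on a copy, the same config dict can be reused to build several instances.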
Customize Build Function
Suppose we would like to customize how `converters` are built. We can implement a customized `build_func` and pass it into the registry.
from mmcv.utils import Registry
# create a build function
def build_converter(cfg, registry, *args, **kwargs):
    cfg_ = cfg.copy()
    converter_type = cfg_.pop('type')
    if converter_type not in registry:
        raise KeyError(f'Unrecognized converter type {converter_type}')
    else:
        converter_cls = registry.get(converter_type)
    converter = converter_cls(*args, **kwargs, **cfg_)
    return converter
# create a registry for converters and pass ``build_converter`` function
CONVERTERS = Registry('converter', build_func=build_converter)
In this example, we demonstrate how to use the `build_func` argument to customize the way a class instance is built. The functionality is similar to the default `build_from_cfg`. In most cases, the default one is sufficient. `build_model_from_cfg` is also implemented to build PyTorch modules in `nn.Sequential`; you may use these directly instead of implementing your own.
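As a rough illustration of how a custom build function is driven by the registry, the sketch below uses a small stand-in class so that it runs without MMCV. `MiniRegistry` and the `verbose` argument are hypothetical; the stand-in only implements what `build_converter` needs (`in`, `get`, and a `build` that forwards to the build function with `registry=self`).

```python
class MiniRegistry:
    """Hypothetical stand-in exposing only what build_converter needs."""

    def __init__(self, name, build_func):
        self.name = name
        self._module_dict = {}
        self._build_func = build_func

    def __contains__(self, key):
        return key in self._module_dict

    def get(self, key):
        return self._module_dict.get(key)

    def register_module(self):
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg, **kwargs):
        # Delegate to the custom build function, passing the registry itself.
        return self._build_func(cfg, registry=self, **kwargs)


def build_converter(cfg, registry, *args, **kwargs):
    cfg_ = cfg.copy()
    converter_type = cfg_.pop('type')
    if converter_type not in registry:
        raise KeyError(f'Unrecognized converter type {converter_type}')
    converter_cls = registry.get(converter_type)
    return converter_cls(*args, **kwargs, **cfg_)


CONVERTERS = MiniRegistry('converter', build_func=build_converter)


@CONVERTERS.register_module()
class Converter1:
    def __init__(self, a, b, verbose=False):
        self.a, self.b, self.verbose = a, b, verbose


# Extra keyword arguments are merged with the config by build_converter.
converter = CONVERTERS.build(dict(type='Converter1', a=1, b=2), verbose=True)

# An unknown type is rejected by the custom check.
try:
    CONVERTERS.build(dict(type='Missing'))
    error_message = None
except KeyError as err:
    error_message = str(err)
print(converter.verbose, error_message)
```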
Hierarchy Registry
You can also build modules from more than one OpenMMLab framework. For example, you can use all backbones in MMClassification for object detectors in MMDetection, or combine an object detection model in MMDetection with a semantic segmentation model in MMSegmentation.
All `MODELS` registries of downstream codebases are children registries of MMCV's `MODELS` registry.
Basically, there are two ways to build a module from child or sibling registries.
- Build from children registries.
For example:
In MMDetection we define:
```python
import torch.nn as nn
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetA(nn.Module):
    def forward(self, x):
        return x
```
In MMClassification we define:
```python
import torch.nn as nn
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetB(nn.Module):
    def forward(self, x):
        return x + 1
```
We could build the two nets in either MMDetection or MMClassification by:
```python
from mmdet.models import MODELS
net_a = MODELS.build(cfg=dict(type='NetA'))
net_b = MODELS.build(cfg=dict(type='mmcls.NetB'))
```
or
```python
from mmcls.models import MODELS
net_a = MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MODELS.build(cfg=dict(type='NetB'))
```
- Build from parent registry.
The shared `MODELS` registry in MMCV is the parent registry for all downstream codebases (root registry):
from mmcv.cnn import MODELS as MMCV_MODELS
net_a = MMCV_MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MMCV_MODELS.build(cfg=dict(type='mmcls.NetB'))
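One way to picture how a scoped key such as `'mmcls.NetB'` could be resolved across sibling and parent registries is the sketch below. It is not MMCV's implementation: `ScopedRegistry` and its explicit `scope` argument are hypothetical, only one level of children under a single root is assumed, and `NetA` and `NetB` are plain classes rather than `nn.Module` subclasses so the example runs without PyTorch.

```python
class ScopedRegistry:
    """Hypothetical sketch of parent/child lookup, not MMCV's implementation."""

    def __init__(self, name, scope, parent=None):
        self.name = name
        self.scope = scope
        self.parent = parent
        self.children = {}
        self._module_dict = {}
        if parent is not None:
            parent.children[scope] = self

    def register_module(self):
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def _root(self):
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def get(self, key):
        if '.' in key:
            # Scoped key: find the registry that owns the scope via the root.
            scope, name = key.split('.', 1)
            root = self._root()
            owner = root if scope == root.scope else root.children.get(scope)
            return owner._module_dict.get(name) if owner is not None else None
        # Unscoped key: look locally, then fall back to the parents.
        node = self
        while node is not None:
            if key in node._module_dict:
                return node._module_dict[key]
            node = node.parent
        return None

    def build(self, cfg):
        kwargs = dict(cfg)
        obj_type = kwargs.pop('type')
        cls = self.get(obj_type)
        if cls is None:
            raise KeyError(f'{obj_type} is not in the {self.name} registry')
        return cls(**kwargs)


ROOT = ScopedRegistry('model', scope='mmcv')
DET = ScopedRegistry('model', scope='mmdet', parent=ROOT)
CLS = ScopedRegistry('model', scope='mmcls', parent=ROOT)


@DET.register_module()
class NetA:
    def forward(self, x):
        return x


@CLS.register_module()
class NetB:
    def forward(self, x):
        return x + 1


net_a = DET.build(dict(type='NetA'))         # local lookup
net_b = DET.build(dict(type='mmcls.NetB'))   # sibling lookup through the root
root_a = ROOT.build(dict(type='mmdet.NetA'))  # child lookup from the root
print(net_a.forward(1), net_b.forward(1), type(root_a).__name__)
```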
Registry usage
import mmcv
# 0. First create a global CATS registry
CATS = mmcv.Registry('cat')
# Apply it as a decorator to each class that should be added to the registry
#===============================================================
# 1. No arguments: the key used for instantiation defaults to the class name (a str)
@CATS.register_module()
class BritishShorthair:
    pass
# Instantiate the class (args is a dict of constructor keyword arguments)
CATS.get('BritishShorthair')(**args)
#==============================================================
# 2. Pass an explicit str; instantiation then only needs the same str
@CATS.register_module(name='Siamese')
class SiameseCat:
    pass
# Instantiate the class
CATS.get('Siamese')(**args)
#===============================================================
# 3. If a registry key is registered twice, you can either raise an error or force an overwrite
# With force=True no error is raised,
# and Siamese2Cat replaces SiameseCat under the key 'Siamese'
# Otherwise an error is raised
@CATS.register_module(name='Siamese', force=True)
class Siamese2Cat:
    pass
# Instantiate the class
CATS.get('Siamese')(**args)
#==============================================================
# 4. A class can also be registered directly
class Munchkin:
    pass
CATS.register_module(module=Munchkin)
# Instantiate the class
CATS.get('Munchkin')(**args)
Minimal Registry implementation
(1) Minimal implementation
# For simplicity, this uses a global variable instead of a class
_module_dict = dict()
# Define the decorator function
def register_module(name):
    def _register(cls):
        _module_dict[name] = cls
        return cls
    return _register
# Decorator usage
@register_module('one_class')
class OneTest(object):
    pass
@register_module('two_class')
class TwoTest(object):
    pass
if __name__ == '__main__':
    # Instantiate by looking up the registered name
    one_test = _module_dict['one_class']()
    print(one_test)
(2) Make the name optional and default to the class name
_module_dict = dict()
def register_module(module_name=None):
    def _register(cls):
        name = module_name
        # If module_name is not given, derive it from the class
        if module_name is None:
            name = cls.__name__
        _module_dict[name] = cls
        return cls
    return _register
@register_module('one_class')
class OneTest(object):
    pass
@register_module()
class TwoTest(object):
    pass
if __name__ == '__main__':
    one_test = _module_dict['one_class']
    # For simplicity, only the class object is printed here, not an instance. To instantiate, just call one_test()
    print(one_test)
    two_test = _module_dict['TwoTest']
    print(two_test)
(3) Raise an error on duplicate registration
_module_dict = dict()
def register_module(module_name=None):
    def _register(cls):
        name = module_name
        if module_name is None:
            name = cls.__name__
        # Raise an error if the name is already registered
        if name in _module_dict:
            raise KeyError(f'{name} is already registered')
        _module_dict[name] = cls
        return cls
    return _register
(4) Allow a forced overwrite on duplicate registration
_module_dict = dict()
def register_module(module_name=None, force=False):
    def _register(cls):
        name = module_name
        if module_name is None:
            name = cls.__name__
        # Raise an error on a duplicate name unless force=True
        if not force and name in _module_dict:
            raise KeyError(f'{name} is already registered')
        _module_dict[name] = cls
        return cls
    return _register
@register_module('one_class')
class OneTest(object):
    pass
# force=True: TwoTest replaces OneTest under the key 'one_class'
@register_module('one_class', True)
class TwoTest(object):
    pass
if __name__ == '__main__':
    one_test = _module_dict['one_class']
    print(one_test)
(5) Register a class directly
To register a class directly, it is enough to do _module_dict['name'] = module_class.
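The above covers essentially all of the functionality in Registry. In practice, managing this with a class is more elegant and convenient, which is how MMCV implements it.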
Registry class implementation
With the above in mind, the MMCV implementation becomes easy to read. The core logic is as follows:
import inspect

class Registry:
    def __init__(self, name):
        # Name of the registry, used to tell registries for different purposes apart
        self._name = name
        # The core state: all classes registered so far
        self._module_dict = dict()
    @property
    def name(self):
        return self._name
    def get(self, key):
        # Simplified lookup: return the registered class, or None if missing
        return self._module_dict.get(key)
    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        # The key line
        self._module_dict[module_name] = module_class
    # Decorator function
    def register_module(self, name=None, force=False, module=None):
        if module is not None:
            # If a module is passed directly, just add it to the dict
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module
        # The standard usage
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register
In MMCV, class instantiation goes through the build_from_cfg function by default. What it does is simple: given a module name, it fetches the class from self._module_dict and instantiates it.
import inspect
from mmcv.utils import is_str

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    obj_type = args.pop('type')  # the registered name (str) of the class
    if is_str(obj_type):
        # Equivalent to self._module_dict[obj_type]
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    # If type is already a class, use it directly
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    # Finally instantiate the class and return the instance
    return obj_cls(**args)
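The behaviour of `default_args` and of passing a class as `type` can be checked with the self-contained sketch below. `TinyRegistry` and `build_from_cfg_sketch` are hypothetical stand-ins that follow the logic shown above, with `isinstance(..., str)` in place of `is_str`, so the example runs without MMCV. The class `Siamese` and its `name` and `age` arguments are made up for illustration.

```python
import inspect


class TinyRegistry:
    """Hypothetical stand-in with just a name, a dict, and get()."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def get(self, key):
        return self._module_dict.get(key)

    def register_module(self):
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register


def build_from_cfg_sketch(cfg, registry, default_args=None):
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            # setdefault keeps values already present in the config
            args.setdefault(name, value)
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    return obj_cls(**args)


CATS = TinyRegistry('cat')


@CATS.register_module()
class Siamese:
    def __init__(self, name, age=1):
        self.name = name
        self.age = age


cfg = dict(type='Siamese', name='Mimi')
# default_args fills in keys that the config does not set
cat_a = build_from_cfg_sketch(cfg, CATS, default_args=dict(age=3))
# values in the config take priority over default_args
cat_b = build_from_cfg_sketch(
    dict(type='Siamese', name='Lulu', age=5), CATS, default_args=dict(age=3))
# a class can be given as type instead of a string
cat_c = build_from_cfg_sketch(dict(type=Siamese, name='Coco'), CATS)
print(cat_a.age, cat_b.age, cat_c.age)
```

Since the function works on a copy of `cfg`, the original config still contains its `type` key afterwards and can be reused.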
References
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html
MMCV 核心组件分析(五): Registry