快捷模式在mmdetection中部署:

注意事项

mmdetection对模型的forward函数进行了重载:与一般只需要输入Tensor的forward不同,它还需要img_metas等额外参数,这与EasyPruner所依赖的torch.jit.trace对forward函数的要求(只接受Tensor输入)不兼容。本工具包参考了mmdetection自带的转onnx代码的写法,用functools.partial预先固定img_metas和return_loss参数,使forward变为单Tensor输入,从而解决该问题。

示例

以FCOS目标检测方法或者SCRFD人脸检测方法为例:

步骤1 常规训练

步骤2 准备微调代码,在微调前,加载预训练模型,并加入剪枝代码

在构建模型之后,加载常规训练得到的模型权重,并进行剪枝。SCRFD不支持默认加载预训练模型,需要手动加载,例如在mmdet/apis/train.py的train_detector中做如下修改:

  1. data_loaders = [
  2. build_dataloader(
  3. ds,
  4. cfg.data.samples_per_gpu,
  5. cfg.data.workers_per_gpu,
  6. # cfg.gpus will be ignored if distributed
  7. len(cfg.gpu_ids),
  8. dist=distributed,
  9. seed=cfg.seed) for ds in dataset
  10. ]
  11. '''
  12. 以上为原始mmdetection代码
  13. '''
  14. # 对网络进行修改,使其改为单变量输入
  15. from easypruner import fastpruner
  16. from easypruner.utils.rebuild import rebuild
  17. from functools import partial
  18. import copy
  19. from mmdet.core.export import pytorch2onnx
  20. imgpath = "data/retinaface/val/images/0--Parade/0_Parade_Parade_0_194.jpg" # 需指定任意一个图片
  21. normalize_cfg = {'mean': [127,127,127], 'std': [55,55,55]} #格式需要,随便指定即可
  22. input_shape = [1,3,640,640] # 格式需要,按实际情况指定即可
  23. input_config = {
  24. 'input_shape': input_shape,
  25. 'input_path': imgpath,
  26. 'normalize_cfg': normalize_cfg
  27. }
  28. one_img, one_meta = pytorch2onnx.preprocess_example_input(input_config)
  29. model.forward = partial(
  30. model.forward, img_metas=[[one_meta]], return_loss=False)
  31. #加载模型并剪枝
  32. state_dict = torch.load("34Gmodel.pth")
  33. device = next(model.parameters()).device
  34. model.load_state_dict(state_dict['state_dict'],strict=True)
  35. model.cpu()
  36. fastpruner.fastpruner(model, prune_factor = 0.4, method="Ratio", input_dim=[3,640,640])##Ratio uniform两种方式都可以试试,注意大小写
  37. #fastpruner.fastpruner(model, prune_factor = 0.5, method="Uniform", input_dim=[3,640,640])##Ratio 和uniform两种方式都可以试试,注意大小写
  38. #fastpruner.fastpruner(model, prune_factor = 0.01, method="Order", input_dim=[3,640,640])##Ratio 和uniform两种方式都可以试试,注意大小写
  39. model.to(device)
  40. #保存剪枝后的模型权重
  41. state_dict = state_dict['state_dict']
  42. for k,v in model.state_dict().items():
  43. state_dict[k] = v
  44. save_path = './model_pruned_34_0.4uniform.pt' #
  45. torch.save(model.state_dict(),save_path)
  46. #torch.export()
  47. #复原模型
  48. model.forward = model.forward.func
  49. exit(0) #剪枝完可以直接训练finetune/或者exit退出,后续通过rebuild重构网络再finetune
  50. '''
  51. 以下为原始mmdetection代码
  52. '''
  53. # put model on gpus
  54. if distributed:
  55. find_unused_parameters = cfg.get('find_unused_parameters', False)
  56. # Sets the `find_unused_parameters` parameter in
  57. # torch.nn.parallel.DistributedDataParallel
  58. model = MMDistributedDataParallel(
  59. model.cuda(),
  60. device_ids=[torch.cuda.current_device()],
  61. broadcast_buffers=False,
  62. find_unused_parameters=find_unused_parameters)
  63. else:
  64. model = MMDataParallel(
  65. model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

使用Uniform方式剪枝或者测试模型flops压缩率时,需要将模型定义、模型前向函数(以及模型转onnx函数)中的neck和head部分去掉。因为每次产生的bbox个数不一样,对应的flops也不一样,所以我们仅考虑backbone的flops大小。以faster_rcnn为例,需要修改mmdet/models/detectors/two_stage.py。下面以2.18.0版本的mmdet为例:

  1. @DETECTORS.register_module()
  2. class TwoStageDetector(BaseDetector):
  3. """Base class for two-stage detectors.
  4. Two-stage detectors typically consisting of a region proposal network and a
  5. task-specific regression head.
  6. """
  7. def __init__(self,
  8. backbone,
  9. neck=None,
  10. rpn_head=None,
  11. roi_head=None,
  12. train_cfg=None,
  13. test_cfg=None,
  14. pretrained=None,
  15. init_cfg=None):
  16. super(TwoStageDetector, self).__init__(init_cfg)
  17. if pretrained:
  18. warnings.warn('DeprecationWarning: pretrained is deprecated, '
  19. 'please use "init_cfg" instead')
  20. backbone.pretrained = pretrained
  21. self.backbone = build_backbone(backbone)
  22. ####注释掉这里
  23. '''
  24. if neck is not None:
  25. self.neck = build_neck(neck)
  26. if rpn_head is not None:
  27. rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
  28. rpn_head_ = rpn_head.copy()
  29. rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
  30. self.rpn_head = build_head(rpn_head_)
  31. if roi_head is not None:
  32. # update train and test cfg here for now
  33. # TODO: refactor assigner & sampler
  34. rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
  35. roi_head.update(train_cfg=rcnn_train_cfg)
  36. roi_head.update(test_cfg=test_cfg.rcnn)
  37. roi_head.pretrained = pretrained
  38. self.roi_head = build_head(roi_head)
  39. '''
  40. self.train_cfg = train_cfg
  41. self.test_cfg = test_cfg
  42. @property
  43. def with_rpn(self):
  44. """bool: whether the detector has RPN"""
  45. return hasattr(self, 'rpn_head') and self.rpn_head is not None
  46. ...
  47. def forward_dummy(self, img):
  48. """Used for computing network flops.
  49. See `mmdetection/tools/analysis_tools/get_flops.py`
  50. """
  51. outs = ()
  52. # backbone
  53. x = self.extract_feat(img)
  54. return x #增加这行
  55. #注释这里
  56. '''
  57. # rpn
  58. if self.with_rpn:
  59. rpn_outs = self.rpn_head(x)
  60. outs = outs + (rpn_outs, )
  61. proposals = torch.randn(1000, 4).to(img.device)
  62. # roi_head
  63. roi_outs = self.roi_head.forward_dummy(x, proposals)
  64. outs = outs + (roi_outs, )
  65. return outs
  66. '''
  67. def forward_train(self,
  68. img,
  69. img_metas,
  70. gt_bboxes,
  71. gt_labels,
  72. gt_bboxes_ignore=None,
  73. gt_masks=None,
  74. proposals=None,
  75. **kwargs):
  76. """
  77. Args:
  78. img (Tensor): of shape (N, C, H, W) encoding input images.
  79. Typically these should be mean centered and std scaled.
  80. img_metas (list[dict]): list of image info dict where each dict
  81. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  82. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  83. For details on the values of these keys see
  84. `mmdet/datasets/pipelines/formatting.py:Collect`.
  85. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  86. shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  87. gt_labels (list[Tensor]): class indices corresponding to each box
  88. gt_bboxes_ignore (None | list[Tensor]): specify which bounding
  89. boxes can be ignored when computing the loss.
  90. gt_masks (None | Tensor) : true segmentation masks for each box
  91. used if the architecture supports a segmentation task.
  92. proposals : override rpn proposals with custom proposals. Use when
  93. `with_rpn` is False.
  94. Returns:
  95. dict[str, Tensor]: a dictionary of loss components
  96. """
  97. x = self.extract_feat(img)
  98. return x #增加这行
  99. #注释这里
  100. '''
  101. losses = dict()
  102. # RPN forward and loss
  103. if self.with_rpn:
  104. proposal_cfg = self.train_cfg.get('rpn_proposal',
  105. self.test_cfg.rpn)
  106. rpn_losses, proposal_list = self.rpn_head.forward_train(
  107. x,
  108. img_metas,
  109. gt_bboxes,
  110. gt_labels=None,
  111. gt_bboxes_ignore=gt_bboxes_ignore,
  112. proposal_cfg=proposal_cfg,
  113. **kwargs)
  114. losses.update(rpn_losses)
  115. else:
  116. proposal_list = proposals
  117. roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
  118. gt_bboxes, gt_labels,
  119. gt_bboxes_ignore, gt_masks,
  120. **kwargs)
  121. losses.update(roi_losses)
  122. return losses
  123. '''
  124. ...
  125. def simple_test(self, img, img_metas, proposals=None, rescale=False):
  126. """Test without augmentation."""
  127. #注释这里
  128. #assert self.with_bbox, 'Bbox head must be implemented.'
  129. x = self.extract_feat(img)
  130. return x #增加这行
  131. #注释这里
  132. '''
  133. if proposals is None:
  134. proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
  135. else:
  136. proposal_list = proposals
  137. return self.roi_head.simple_test(
  138. x, proposal_list, img_metas, rescale=rescale)
  139. '''
  140. ...
  141. def onnx_export(self, img, img_metas):
  142. img_shape = torch._shape_as_tensor(img)[2:]
  143. img_metas[0]['img_shape_for_onnx'] = img_shape
  144. x = self.extract_feat(img)
  145. return x #增加这行
  146. #注释这里
  147. '''
  148. proposals = self.rpn_head.onnx_export(x, img_metas)
  149. if hasattr(self.roi_head, 'onnx_export'):
  150. return self.roi_head.onnx_export(x, proposals, img_metas)
  151. else:
  152. raise NotImplementedError(
  153. f'{self.__class__.__name__} can not '
  154. f'be exported to ONNX. Please refer to the '
  155. f'list of supported models,'
  156. f'https://mmdetection.readthedocs.io/en/latest/tutorials/pytorch2onnx.html#list-of-supported-models-exportable-to-onnx' # noqa E501
  157. )
  158. '''
  159. ...

步骤 3 基于常规训练的模型微调

步骤 4 单独测试准确率或部署转onnx时,需要将权重加载代码改为使用rebuild函数加载,如以下代码所示:

依然在mmdet/apis/train.py 的train_detector中进行修改

            data_loaders = [
          build_dataloader(
              ds,
              cfg.data.samples_per_gpu,
              cfg.data.workers_per_gpu,
              # cfg.gpus will be ignored if distributed
              len(cfg.gpu_ids),
              dist=distributed,
              seed=cfg.seed) for ds in dataset
      ]
  '''
  以上为原始mmdetection代码
  '''        
    from easypruner.utils.rebuild import rebuild 
    state_dict = torch.load("model_pruned_34_0.4uniform_finetuned.pt")#剪枝后的模型权重文件
    #import pdb;pdb.set_trace() 
    model = rebuild(model , state_dict)
  '''
  以下为原始mmdetection代码
  '''
     # put model on gpus
      if distributed:
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
         model = MMDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
      else:
         model = MMDataParallel(
             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)