1. 自定义一个pipeline

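新建文件 mmdet/datasets/pipelines/middleway.py,并在 mmdet/datasets/pipelines/__init__.py 中导入其中的 MiddleDebug(或在配置文件中通过 custom_imports 导入该模块),否则 @PIPELINES.register_module() 不会被执行,配置里也就找不到这个类型。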

pipeline 的运作逻辑:数据从 `__call__(self, results)` 的参数 `results` 传入,处理后再通过 `return results` 传出,交给下一个步骤。

  # Copyright (c) OpenMMLab. All rights reserved.
  from ..builder import PIPELINES

  try:
      from panopticapi.utils import rgb2id
  except ImportError:
      rgb2id = None


  @PIPELINES.register_module()
  class MiddleDebug:
      """Debugging transform that prints the results dict and passes it on.

      Place it between any two steps of a data pipeline to inspect the
      intermediate ``results`` dict. The dict is returned unchanged, so
      adding or removing this step does not affect the other steps.

      Args:
          mode (int): Value stored on the instance as ``self.mode``. It is
              not read by ``__call__``. Defaults to 1.
      """

      def __init__(self, mode=1):
          # Store the caller's value; assigning the literal 1 here would
          # silently ignore the ``mode`` argument.
          self.mode = mode

      def __call__(self, results):
          """Print the incoming results dict and return it unmodified.

          Args:
              results (dict): The dict handed to this pipeline step.

          Returns:
              dict: The same ``results`` object that was passed in.
          """
          print(results)
          return results

2. 使用自定义的pipeline

  1. pipeline=[
  2. dict(type='LoadImageFromFile'),
  3. dict(type='MiddleDebug'),  # print results right after LoadImageFromFile
  4. dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
  5. dict(type='MiddleDebug'),  # print results right after LoadAnnotations
  6. dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
  7. dict(type='RandomFlip', flip_ratio=0.5),
  8. dict(type='MiddleDebug'),  # print results after Resize and RandomFlip
  9. dict(
  10. type='Normalize',
  11. mean=[123.675, 116.28, 103.53],
  12. std=[58.395, 57.12, 57.375],
  13. to_rgb=True),
  14. dict(type='Pad', size_divisor=32),
  15. dict(type='DefaultFormatBundle'),
  16. dict(
  17. type='Collect',
  18. keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
  19. ]),  # NOTE(review): the trailing ')' closes an enclosing call not shown in this snippet