流程分析:在利用TensorFlow框架进行语义分割训练时,前面的数据转化过程步骤中含有将voc格式的语义分割图像标签转化为物体像素值为训练类别索引的图像标签这一步骤。
    面对的问题:在项目开发过程中发现,tfrecord图像大小会影响模型最优在val数据集上的miou精度。同时目前负责开发的项目中数据分布不均衡,也需要进行一些数据增广操作。
    解决方案:写了以下代码,兼顾标签数据格式转化及数据增广功能
    可以优化的地方:下面代码是通过np.where来判断目标物体和背景,所以只支持单类别的数据格式处理。

    1. # NOTE: input_dir和output_dir可以是同一个路径。
    2. python augment.py --dst_size="480,480" --input_dir 输入voc数据路径 --output_dir 输出数据增广后的路径
    1. # -*- coding: utf-8 -*-
    2. """
    3. @File : augment.py
    4. @Time : 2022-2-26 22:57
    5. @Author : qian0733
    6. @Description : 代替remove_gt_colormap函数进行数据格式转化,并对数据进行了增广操作
    7. @局限性 : 只支持单类别的数据格式处理!
    8. """
    9. import os
    10. import cv2
    11. import sys
    12. import argparse
    13. import numpy as np
    14. def parse_arguments(argv):
    15. parser = argparse.ArgumentParser()
    16. parser.add_argument('--dst_size', type=str, help='dst resize size', default='480,480')
    17. parser.add_argument('--input_dir', type=str, help='input jpg and label dir (voc type)', default='')
    18. parser.add_argument('--output_dir', type=str, help='output images dir', default='')
    19. return parser.parse_args(argv)
    20. def augment_data():
    21. for img in os.listdir(input_img):
    22. img_path = os.path.join(input_img, img)
    23. label_path = os.path.join(input_label, img.split(".")[0] + ".png")
    24. jpg = cv2.imread(img_path)
    25. label = cv2.imread(label_path)
    26. # resize/hist/flip/crop
    27. # resize/hist/flip/crop
    28. jpg_bak = jpg.copy()
    29. label_bak = label.copy()
    30. h, w = jpg_bak.shape[:2]
    31. # 缩放至目标尺寸
    32. jpg_resize = cv2.resize(jpg_bak, dst_size, interpolation=cv2.INTER_AREA)
    33. label_resize = cv2.resize(label_bak, dst_size, interpolation=cv2.INTER_AREA)
    34. label_resize = np.where(label_resize[:, :, 2] > 0, 1, 0) # 3通道二值图
    35. label_resize = np.asanyarray(label_resize, dtype=np.uint8)
    36. cv2.imwrite(os.path.join(output_img, img.split(".")[0] + r"_resize.jpg"), jpg_resize)
    37. cv2.imwrite(os.path.join(output_labels, img.split(".")[0] + r"_resize.png"), label_resize.astype(np.uint8))
    38. #在缩放好的图像上做直方图均值化
    39. b,g,r =cv2.split(jpg_resize)
    40. bh = cv2.equalizeHist(b)
    41. gh = cv2.equalizeHist(g)
    42. rh = cv2.equalizeHist(r)
    43. jpg_hist = cv2.merge((bh, gh, rh))
    44. label_hist = label_resize
    45. cv2.imwrite(os.path.join(output_img, img.split(".")[0] + r"_hist.png"), label_hist.astype(np.uint8))
    46. # 图像水平方向的镜面翻转
    47. jpg_flip = cv2.flip(jpg_bak, 1)
    48. label_flip = cv2.flip(label_bak, 1)
    49. jpg_resize = cv2.resize(jpg_flip, dst_size, interpolation=cv2.INTER_AREA)
    50. label_resize = cv2.resize(label_flip, dst_size, interpolation=cv2.INTER_AREA)
    51. label_resize = np.where(label_resize[:, :, 2] > 0, 1, 0) # 二值图
    52. label_resize = np.asanyarray(label_resize, dtype=np.uint8)
    53. print(label_resize.shape)
    54. cv2.imwrite(os.path.join(output_img, img.split(".")[0] + r"_flip.jpg"), jpg_resize)
    55. cv2.imwrite(os.path.join(output_labels, img.split(".")[0] + r"_flip.png"), label_resize.astype(np.uint8))
    56. # 随机裁剪,以扩大目标管的面积占比
    57. idx = np.random.randint(1, 3)
    58. if idx == 1:
    59. jpg_crop = jpg_bak[: int(h * 0.75), : int(w * 0.75)]
    60. label_crop = label_bak[: int(h * 0.75), : int(w * 0.75)]
    61. else:
    62. jpg_crop = jpg_bak[int(h * 0.75) :, int(w * 0.75) :]
    63. label_crop = label_bak[int(h * 0.75) :, int(w * 0.75) :]
    64. label_crop = np.where(label_crop[:, :, 2] > 0, 1, 0) # 二值图
    65. label_crop = np.asanyarray(label_crop, dtype=np.uint8)
    66. cv2.imwrite(os.path.join(output_img, img.split(".")[0] + r"_crop.jpg"), jpg_crop)
    67. cv2.imwrite(os.path.join(output_labels, img.split(".")[0] + r"_crop.png"), label_crop.astype(np.uint8))
    68. if __name__ == '__main__':
    69. args = parse_arguments(sys.argv[1:])
    70. dst_size = tuple([int(i) for i in args.dst_size.split(",")])
    71. input = args.input_dir
    72. output = args.output_dir
    73. input_img, input_label = os.path.join(input, "JPEGImages"), os.path.join(input, "SegmentationClassPNG")
    74. output_img, output_labels = os.path.join(output, "images"), os.path.join(output, "labels")
    75. if not os.path.exists(output_img):
    76. os.makedirs(output_img)
    77. os.makedirs(output_labels)
    78. augment_data()