上周搭建了代码环境,跑通了整个训练流程。现在记录一下voc数据集如何转为YOLOv4的训练数据格式。

    以下是几个必备文件

    • JPEGImages文件夹:图像数据—不用修改,与原始voc数据一致
    • labels文件夹:标注数据(TXT文件)—需要解析voc的xml文件后进行生成
    • object.names文件:训练数据中包含的类别名称
    • object.data文件:训练的输入输出数据位置,以及类别数

    1、JPEGImages文件夹直接使用现成的就好,这里不做过多介绍。
    不过需要注意的时,在数据准备阶段最好把图像的后缀(如将.JPEG改为.jpg)统一一下,后面会省去不少麻烦事。

    2、labels文件夹需要自己新建,其中存放的文件是TXT格式的。
    每个TXT文件与JPEGImages文件夹下到的图像进行一一对应,其内容分别对应的训练图像中所标注的Bbox信息。
    每一行对应图中的一个Bbox,按照 的顺序进行写入,各元素以空格分隔。

    • 是0~len(classes)-1的整数
    • 分别是(0, 1]之间的归一化浮点数,这里的归一化是基于图像宽高进行的。

    以下是我的labels文件夹生成python文件。它还会生成一个包含所有训练图像数据的txt文件,这个可以在后面用于拆分自己的训练和测试数据。

    1. # -*- coding: utf-8 -*-
    2. """
    3. @File : voc2darknetyolo.py
    4. @Time : 2021-4-3 8:31
    5. @Author : yinqq
    6. @Description :
    7. @usage : python voc2darknetyolo.py --input_dir voc_root/ --output_txt trainval.txt --classes_file zks.names
    8. """
    9. import os
    10. import cv2
    11. import sys
    12. import codecs
    13. import xml.etree.ElementTree as ET
    14. import argparse
    15. def get_class_names(file_path):
    16. """
    17. 读取类别名称列表
    18. :param file_path:
    19. :return class_names: list of class_name
    20. """
    21. with codecs.open(file_path, mode='r', encoding='utf-8') as f:
    22. lines = f.readlines()
    23. return [l.strip() for l in lines]
    24. def xml_parse(image_id, input_anno, input_jpg, output_label, classes):
    25. """
    26. 解析给定imageID的xml文件,并将结果写入对应的txt文件中
    27. :param image_id: 不带.jpg后缀的图像名称
    28. :param input_anno: xml文件存放路径
    29. :param input_jpg: jpg文件存放路径
    30. :param output_label: labels文件夹下的txt输出路径
    31. :param classes:
    32. :return:
    33. """
    34. xml_file = open(os.path.join(input_anno, '{}.xml'.format(image_id)))
    35. tree = ET.parse(xml_file)
    36. root = tree.getroot()
    37. # 获取图像宽高信息,用于将Bbox转化到[0,1]区间内
    38. img_path = os.path.join(input_jpg, "{}.jpg".format(image_id))
    39. img = cv2.imread(img_path)
    40. dh, dw = 1. / img.shape[0], 1. / img.shape[1]
    41. # 将标注的Bbox转化为[0,1]区间内的值,并存入label文件夹下、与imageID同名TXT文件内
    42. output_label_txt = os.path.join(output_label, "{}.txt".format(image_id))
    43. with codecs.open(output_label_txt, mode='a', encoding='utf-8') as lt:
    44. if not len(root.findall('object')):
    45. print("{}.jpg not tagged!".format(image_id))
    46. for obj in root.iter('object'):
    47. cls = obj.find('name').text
    48. if cls not in classes:
    49. continue
    50. cls_id = classes.index(cls)
    51. xmlbox = obj.find('bndbox')
    52. # if dw * float(xmlbox.find('xmin').text) < 0. or dw * float(xmlbox.find('xmax').text) < 0. or dh * float(
    53. # xmlbox.find('ymin').text) < 0. or dh * float(xmlbox.find('ymax').text) < 0.:
    54. # print("{}.jpg tagged error!".format(image_id))
    55. b = (dw * float(xmlbox.find('xmin').text), dw * float(xmlbox.find('xmax').text),
    56. dh * float(xmlbox.find('ymin').text), dh * float(xmlbox.find('ymax').text))
    57. lt.write(str(cls_id) + " " + str((b[0] + b[1]) / 2) + " " + str((b[2] + b[3]) / 2) + " " + str(
    58. b[1] - b[0]) + " " + str(b[3] - b[2]))
    59. lt.write("\n")
    60. return img_path
    61. def arg_parse():
    62. parse = argparse.ArgumentParser()
    63. parse.add_argument("--input_dir", type=str, default="", help="存放图像JPEGImages和标注Annotations文件夹的文件夹路径")
    64. parse.add_argument("--output_txt", type=str, default="", help="输出标注图像名称绝对路径的TXT文件")
    65. parse.add_argument("--classes_file", type=str, default="", help="obj.names文件,用于生成类别信息")
    66. arguments = parse.parse_args(sys.argv[1:])
    67. input_anno = os.path.join(arguments.input_dir, "Annotations")
    68. input_jpge = os.path.join(arguments.input_dir, "JPEGImages")
    69. output_label = os.path.join(arguments.input_dir, "label")
    70. # output_txt = os.path.join(arguments.output_txt, "txt")
    71. classes = get_class_names(arguments.classes_file)
    72. if not os.path.exists(output_label):
    73. os.makedirs(output_label)
    74. return input_anno, input_jpge, output_label, arguments.output_txt, classes
    75. def main():
    76. input_anno, input_jpge, output_label, output_txt, classes = arg_parse()
    77. print(input_anno)
    78. print(input_jpge)
    79. print(output_label)
    80. print(output_txt)
    81. print(classes)
    82. # input_anno标注文件夹, input_jpge图像存放文件夹, input_txt所有图像ID不带后准, output_label与图像对应的TXT标注文件路进, output_txt图像保存的绝对路径
    83. with codecs.open(output_txt, mode='w', encoding='utf-8') as o_txt:
    84. for image in os.listdir(input_jpge):
    85. image_id = image.split(".")[0]
    86. img_path = xml_parse(image_id, input_anno, input_jpge, output_label, classes)
    87. o_txt.write(img_path + "\n")
    88. if __name__ == '__main__':
    89. main()

    3、object.names文件:训练数据中包含的类别名称,每行一个类名

    1. class1
    2. class2

    4、object.data文件:训练的输入输出数据位置,以及类别数

    1. classes= 2 # 训练数据包含的类别个数
    2. train = data/train.txt # 存储训练图像数据的路径
    3. valid = data/test.txt # 存储测试图像数据的路径
    4. names = data/object.names # 每行一个类别的名称
    5. backup = backup/ # 存放模型的weight文件

    参考文档:

    1. 官网:https://github.com/AlexeyAB/darknet
    2. 官网文档的中文翻译:https://zhuanlan.zhihu.com/p/102628373