上周搭建了代码环境,跑通了整个训练流程。现在记录一下voc数据集如何转为YOLOv4的训练数据格式。
以下是几个必备文件
- JPEGImages文件夹:图像数据—不用修改,与原始voc数据一致
- labels文件夹:标注数据(TXT文件)—需要解析voc的xml文件后进行生成
- object.names文件:训练数据中包含的类别名称
- object.data文件:训练的输入输出数据位置,以及类别数
1、JPEGImages文件夹直接使用现成的就好,这里不做过多介绍。
不过需要注意的时,在数据准备阶段最好把图像的后缀(如将.JPEG改为.jpg)统一一下,后面会省去不少麻烦事。
2、labels文件夹需要自己新建,其中存放的文件是TXT格式的。
每个TXT文件与JPEGImages文件夹下到的图像进行一一对应,其内容分别对应的训练图像中所标注的Bbox信息。
每一行对应图中的一个Bbox,按照
- 是0~len(classes)-1的整数
分别是(0, 1]之间的归一化浮点数,这里的归一化是基于图像宽高进行的。
以下是我的labels文件夹生成python文件。它还会生成一个包含所有训练图像数据的txt文件,这个可以在后面用于拆分自己的训练和测试数据。
# -*- coding: utf-8 -*-
"""
@File : voc2darknetyolo.py
@Time : 2021-4-3 8:31
@Author : yinqq
@Description :
@usage : python voc2darknetyolo.py --input_dir voc_root/ --output_txt trainval.txt --classes_file zks.names
"""
import os
import cv2
import sys
import codecs
import xml.etree.ElementTree as ET
import argparse
def get_class_names(file_path):
"""
读取类别名称列表
:param file_path:
:return class_names: list of class_name
"""
with codecs.open(file_path, mode='r', encoding='utf-8') as f:
lines = f.readlines()
return [l.strip() for l in lines]
def xml_parse(image_id, input_anno, input_jpg, output_label, classes):
"""
解析给定imageID的xml文件,并将结果写入对应的txt文件中
:param image_id: 不带.jpg后缀的图像名称
:param input_anno: xml文件存放路径
:param input_jpg: jpg文件存放路径
:param output_label: labels文件夹下的txt输出路径
:param classes:
:return:
"""
xml_file = open(os.path.join(input_anno, '{}.xml'.format(image_id)))
tree = ET.parse(xml_file)
root = tree.getroot()
# 获取图像宽高信息,用于将Bbox转化到[0,1]区间内
img_path = os.path.join(input_jpg, "{}.jpg".format(image_id))
img = cv2.imread(img_path)
dh, dw = 1. / img.shape[0], 1. / img.shape[1]
# 将标注的Bbox转化为[0,1]区间内的值,并存入label文件夹下、与imageID同名TXT文件内
output_label_txt = os.path.join(output_label, "{}.txt".format(image_id))
with codecs.open(output_label_txt, mode='a', encoding='utf-8') as lt:
if not len(root.findall('object')):
print("{}.jpg not tagged!".format(image_id))
for obj in root.iter('object'):
cls = obj.find('name').text
if cls not in classes:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
# if dw * float(xmlbox.find('xmin').text) < 0. or dw * float(xmlbox.find('xmax').text) < 0. or dh * float(
# xmlbox.find('ymin').text) < 0. or dh * float(xmlbox.find('ymax').text) < 0.:
# print("{}.jpg tagged error!".format(image_id))
b = (dw * float(xmlbox.find('xmin').text), dw * float(xmlbox.find('xmax').text),
dh * float(xmlbox.find('ymin').text), dh * float(xmlbox.find('ymax').text))
lt.write(str(cls_id) + " " + str((b[0] + b[1]) / 2) + " " + str((b[2] + b[3]) / 2) + " " + str(
b[1] - b[0]) + " " + str(b[3] - b[2]))
lt.write("\n")
return img_path
def arg_parse():
parse = argparse.ArgumentParser()
parse.add_argument("--input_dir", type=str, default="", help="存放图像JPEGImages和标注Annotations文件夹的文件夹路径")
parse.add_argument("--output_txt", type=str, default="", help="输出标注图像名称绝对路径的TXT文件")
parse.add_argument("--classes_file", type=str, default="", help="obj.names文件,用于生成类别信息")
arguments = parse.parse_args(sys.argv[1:])
input_anno = os.path.join(arguments.input_dir, "Annotations")
input_jpge = os.path.join(arguments.input_dir, "JPEGImages")
output_label = os.path.join(arguments.input_dir, "label")
# output_txt = os.path.join(arguments.output_txt, "txt")
classes = get_class_names(arguments.classes_file)
if not os.path.exists(output_label):
os.makedirs(output_label)
return input_anno, input_jpge, output_label, arguments.output_txt, classes
def main():
input_anno, input_jpge, output_label, output_txt, classes = arg_parse()
print(input_anno)
print(input_jpge)
print(output_label)
print(output_txt)
print(classes)
# input_anno标注文件夹, input_jpge图像存放文件夹, input_txt所有图像ID不带后准, output_label与图像对应的TXT标注文件路进, output_txt图像保存的绝对路径
with codecs.open(output_txt, mode='w', encoding='utf-8') as o_txt:
for image in os.listdir(input_jpge):
image_id = image.split(".")[0]
img_path = xml_parse(image_id, input_anno, input_jpge, output_label, classes)
o_txt.write(img_path + "\n")
if __name__ == '__main__':
main()
3、object.names文件:训练数据中包含的类别名称,每行一个类名
class1
class2
4、object.data文件:训练的输入输出数据位置,以及类别数
classes= 2 # 训练数据包含的类别个数
train = data/train.txt # 存储训练图像数据的路径
valid = data/test.txt # 存储测试图像数据的路径
names = data/object.names # 每行一个类别的名称
backup = backup/ # 存放模型的weight文件
参考文档: