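在利用官方代码生成用于训练的 tfrecord 数据集时,需要先指定训练集和验证集的样本名称。以下脚本会随机划分 `<root_path>/images` 下的图片,并在 `<root_path>/index` 目录中生成 tfrecord 所需的索引文件 train.txt 和 val.txt(每行一个去掉 .jpg/.png 扩展名的图片名)。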
注意:如果做了数据增广,应先在原始标注数据上划分训练集和验证集,再只对训练集做增广。否则同一张原图的增广样本可能同时出现在训练集和验证集中,造成数据泄露。
# -*- coding: utf-8 -*-
"""
@File : voc2tfrecord_index.py
@Time : 2022-5-26 21:37
@Author : qian733
@Description :
"""
import os
import sys
import argparse
from glob import glob
from random import shuffle
def parse_arguments(argv):
    """Parse the command-line options of the index generator.

    Args:
        argv: Argument list without the program name, e.g. ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``root_path`` (str) and ``train_percent``
        (float) attributes.

    Raises:
        SystemExit: via ``parser.error`` when ``train_percent`` is not in [0, 1].
    """
    parser = argparse.ArgumentParser(
        description='Randomly split <root_path>/images into train/val index '
                    'files under <root_path>/index.')
    parser.add_argument('--root_path', type=str, default='',
                        help='dataset root dir: images are read from '
                             '<root_path>/images and index files are written '
                             'to <root_path>/index')
    # A real float default instead of the string '0.9' (argparse would convert
    # the string anyway, but the intent is clearer this way).
    parser.add_argument('--train_percent', type=float, default=0.9,
                        help='the proportion of training data in the sample database')
    args = parser.parse_args(argv)
    # Out-of-range ratios would silently yield an empty or all-train split;
    # the negated form also rejects NaN.
    if not 0.0 <= args.train_percent <= 1.0:
        parser.error('--train_percent must be between 0 and 1, got %s'
                     % args.train_percent)
    return args
# Parse options and collect the candidate images from <root_path>/images.
args = parse_arguments(sys.argv[1:])
root_path = args.root_path
image_path = os.path.join(root_path, "images")
# Keep regular files only: a sub-directory inside images/ would otherwise be
# written to the index as a bogus sample id. Sorting gives a deterministic
# starting order, so the split is reproducible whenever the global `random`
# generator has been seeded.
images = sorted(p for p in glob(os.path.join(image_path, "*"))
                if os.path.isfile(p))
if not images:
    # Most likely a wrong --root_path; warn instead of silently writing empty
    # index files.
    print("警告: 在 %s 下没有找到图片文件" % image_path, file=sys.stderr)
shuffle(images)
total_num = len(images)
print("总数量:", total_num)
train_num = int(total_num * args.train_percent)
print("训练集数量:", train_num)
print("验证集数量:", total_num - train_num)
# Output location: <root_path>/index/{train,val}.txt
index_path = os.path.join(root_path, "index")
# makedirs(exist_ok=True) replaces the racy exists()-then-mkdir pair and also
# creates any missing parent directories.
os.makedirs(index_path, exist_ok=True)
train_path = os.path.join(index_path, "train.txt")
val_path = os.path.join(index_path, "val.txt")
def _image_id(image_file):
    """Return the sample id of an image path: its file name minus the extension.

    os.path.basename honours the platform separator (a hard-coded '/' split
    fails on Windows paths), and os.path.splitext strips only the final
    extension rather than every ".jpg"/".png" substring in the name.
    """
    return os.path.splitext(os.path.basename(image_file))[0]


def _write_index(index_file, image_files):
    """Write one sample id per line to ``index_file``, overwriting it."""
    # The context manager closes the file even if a write fails; an explicit
    # encoding keeps the output independent of the platform's locale default.
    with open(index_file, 'w', encoding='utf-8') as f:
        for image_file in image_files:
            f.write('%s\n' % _image_id(image_file))


# The first train_num shuffled images form the training split, the rest
# the validation split.
_write_index(train_path, images[:train_num])
_write_index(val_path, images[train_num:])