在利用官方代码生成用于训练的数据集时，需要先指定训练集和验证集的样本名称。以下代码可以直接生成 tfrecord 所需的索引文件（index/train.txt 与 index/val.txt）。
注意：如果做了数据增广，需要先在原始标注数据上划分训练集和验证集，再分别进行增广，以防止数据泄露（即同一张原图的增广样本同时出现在训练集和验证集中）。
# -*- coding: utf-8 -*-
"""
@File        : voc2tfrecord_index.py
@Time        : 2022-5-26 21:37
@Author      : qian733
@Description : Randomly split the images under ``<root_path>/images`` into a
               training and a validation subset and write their ids (file
               names without extension) to ``<root_path>/index/train.txt``
               and ``<root_path>/index/val.txt``, one id per line.
"""
import os
import sys
import argparse
from glob import escape, glob
from random import Random, shuffle


def parse_arguments(argv):
    """Parse and validate the command-line arguments.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with ``root_path``, ``train_percent`` and ``seed``.

    Exits (via ``parser.error``) when ``train_percent`` is outside [0, 1].
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str, help='output images dir', default='')
    parser.add_argument('--train_percent', type=float,
                        help='the proportion of training data in the sample database',
                        default=0.9)
    # Optional seed so the exact same split can be regenerated later.
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed for a reproducible split (default: random)')
    args = parser.parse_args(argv)
    # Written as a negated range check so NaN is rejected as well.
    if not 0.0 <= args.train_percent <= 1.0:
        parser.error('--train_percent must be between 0 and 1, got %s' % args.train_percent)
    return args


def image_id_from_path(path):
    """Return the image id for *path*: its base name without the final extension.

    ``os.path.basename`` honours the platform path separator, and
    ``os.path.splitext`` strips only the trailing extension, so a name such as
    ``img.png_crop.jpg`` yields ``img.png_crop`` and ``x.jpeg`` yields ``x``.
    """
    return os.path.splitext(os.path.basename(path))[0]


def list_images(image_path):
    """Return the regular files directly inside *image_path*; sub-directories are skipped."""
    # escape() keeps characters like '[' in the directory name from acting as wildcards.
    pattern = os.path.join(escape(image_path), '*')
    return [p for p in glob(pattern) if os.path.isfile(p)]


def split_images(images, train_percent, seed=None):
    """Shuffle *images* and split them into a ``(train, val)`` pair of lists.

    The first ``int(len(images) * train_percent)`` shuffled items go to train
    and the rest go to val. The input is sorted before shuffling, so for a
    fixed *seed* the result does not depend on the order in which ``glob``
    happened to list the files. With ``seed=None``, the module-level RNG is
    used, so each run produces a new random split.
    """
    ordered = sorted(images)
    if seed is None:
        shuffle(ordered)
    else:
        Random(seed).shuffle(ordered)
    train_num = int(len(ordered) * train_percent)
    return ordered[:train_num], ordered[train_num:]


def write_index(path, image_ids):
    """Write *image_ids* to *path*, one per line (UTF-8, overwriting any existing file)."""
    with open(path, 'w', encoding='utf-8') as f:
        for image_id in image_ids:
            f.write('%s\n' % image_id)


def main(argv=None):
    """Build ``index/train.txt`` and ``index/val.txt`` under ``--root_path``.

    Args:
        argv: argument list without the program name; defaults to ``sys.argv[1:]``.
    """
    args = parse_arguments(sys.argv[1:] if argv is None else argv)
    root_path = args.root_path
    image_path = os.path.join(root_path, "images")
    images = list_images(image_path)
    if not images:
        print("warning: no images found in %s" % image_path, file=sys.stderr)

    train_images, val_images = split_images(images, args.train_percent, args.seed)
    print("总数量:", len(images))
    print("训练集数量:", len(train_images))
    print("验证集数量:", len(val_images))

    index_path = os.path.join(root_path, "index")
    # makedirs also creates any missing parents, and exist_ok avoids a
    # separate existence check.
    os.makedirs(index_path, exist_ok=True)
    write_index(os.path.join(index_path, "train.txt"),
                [image_id_from_path(p) for p in train_images])
    write_index(os.path.join(index_path, "val.txt"),
                [image_id_from_path(p) for p in val_images])


if __name__ == "__main__":
    main()
