在利用官方代码生成用于训练的数据集时，需要先指定训练集和验证集的样本名称。以下代码可以直接生成 tfrecord 所需的 index 文件（train.txt 和 val.txt）。
    注意：如果做了数据增广，那么需要在原始（未增广的）标注数据上划分训练集和验证集，以防止同一张原图的增广样本同时出现在两个集合中，造成数据泄露。

    1. # -*- coding: utf-8 -*-
    2. """
    3. @File : voc2tfrecord_index.py
    4. @Time : 2022-5-26 21:37
    5. @Author : qian733
    6. @Description :
    7. """
    8. import os
    9. import sys
    10. import argparse
    11. from glob import glob
    12. from random import shuffle
    def parse_arguments(argv):
        """Parse the command-line options of the index generator.

        Args:
            argv: List of argument strings, normally ``sys.argv[1:]``.

        Returns:
            argparse.Namespace with two attributes:
            ``root_path`` (str, default ``''``) and
            ``train_percent`` (float, default ``0.9``).

        Raises:
            SystemExit: Raised by argparse when an option cannot be parsed or
                when ``--train_percent`` is not within [0, 1].
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--root_path',
            type=str,
            default='',
            help='dataset root dir; images are read from its "images" '
                 'sub-directory and the "index" directory is created in it',
        )
        parser.add_argument(
            '--train_percent',
            type=float,
            # A real float default instead of the string '0.9', so the default
            # does not depend on argparse re-parsing string defaults.
            default=0.9,
            help='the proportion of training data in the sample database',
        )
        args = parser.parse_args(argv)
        # A fraction outside [0, 1] (or NaN) would yield a meaningless split,
        # e.g. a negative validation count, so reject it up front.
        if not 0.0 <= args.train_percent <= 1.0:
            parser.error('--train_percent must be within [0, 1], got %s'
                         % args.train_percent)
        return args
    def _write_index(index_file, image_paths):
        """Write one sample id per line to *index_file*.

        The sample id is the file name of each path without its directory and
        without its extension.
        """
        # `with` guarantees the file is closed even if a write fails.
        with open(index_file, 'w') as list_file:
            for image_file in image_paths:
                # basename/splitext handle both '/' and '\\' separators and
                # strip only the real extension, whatever its case or type,
                # instead of deleting ".jpg"/".png" anywhere in the name.
                image_id = os.path.splitext(os.path.basename(image_file))[0]
                list_file.write('%s\n' % image_id)


    # Collect every entry of <root_path>/images and shuffle it so that the
    # train/val split is random (unseeded, so it differs between runs).
    args = parse_arguments(sys.argv[1:])
    root_path = args.root_path
    image_path = os.path.join(root_path, "images")
    images = glob(os.path.join(image_path, "*"))
    shuffle(images)

    # The first `train_num` shuffled samples form the training set, the
    # remainder the validation set.
    total_num = len(images)
    print("总数量:", total_num)
    train_num = int(total_num * args.train_percent)
    print("训练集数量:", train_num)
    print("验证集数量:", total_num - train_num)

    # Create <root_path>/index; exist_ok avoids the check-then-create race
    # and makedirs also creates missing parent directories.
    index_path = os.path.join(root_path, "index")
    os.makedirs(index_path, exist_ok=True)

    train_path = os.path.join(index_path, "train.txt")
    val_path = os.path.join(index_path, "val.txt")
    _write_index(train_path, images[:train_num])
    _write_index(val_path, images[train_num:])