SSD 算法是目标检测(object detection)领域比较经典的算法。GitHub 上有一个写得比较好的 MXNet 版本实现,项目地址:https://github.com/zhreshold/mxnet-ssd ,目前该项目代码也已经并入 MXNet 官方 GitHub 仓库。想要在本地运行,可以参考项目地址中 README.md 的介绍,或者参考博客《SSD 算法的 MXNet 实现》。

    接下来这一系列博客想介绍该代码中关于实现 SSD 算法的一些细节,也会涉及部分 Python 语言的巧妙代码,以训练模型为切入口展开介绍,最好按顺序阅读,详细注释已经在代码中给出。

    这一篇博客介绍训练模型的入口代码:train.py 脚本,该脚本主要包含一些参数设置和主函数。

    1. import argparse
    2. import tools.find_mxnet
    3. import mxnet as mx
    4. import os
    5. import sys
    6. from train.train_net import train_net
    def _str2bool(value):
        """Convert a command-line token into a real boolean.

        ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
        every non-empty value (including ``'False'`` and ``'0'``) becomes
        ``True``.  This converter interprets the usual spellings instead.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string is kept falsy, matching what ``bool('')`` returned.
        if token in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % (value,))


    def parse_args(argv=None):
        """Build the training command-line parser and parse the arguments.

        Args:
            argv: optional list of argument strings.  When ``None`` (the
                default) ``argparse`` reads ``sys.argv[1:]``, which is the
                original behaviour.

        Returns:
            argparse.Namespace holding every training option under its
            ``dest`` name.
        """
        parser = argparse.ArgumentParser(
            description='Train a Single-shot detection network')

        # Dataset locations.
        parser.add_argument('--train-path', dest='train_path', help='train record to use',
                            default=os.path.join(os.getcwd(), 'data', 'train.rec'), type=str)
        parser.add_argument('--train-list', dest='train_list', help='train list to use',
                            default="", type=str)
        parser.add_argument('--val-path', dest='val_path', help='validation record to use',
                            default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
        parser.add_argument('--val-list', dest='val_list', help='validation list to use',
                            default="", type=str)

        # Network, checkpoints and devices.
        parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
                            help='which network to use')
        parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
                            help='training batch size')
        parser.add_argument('--resume', dest='resume', type=int, default=-1,
                            help='resume training from epoch n')
        parser.add_argument('--finetune', dest='finetune', type=int, default=-1,
                            help='finetune from epoch n, rename the model before doing this')
        parser.add_argument('--pretrained', dest='pretrained', help='pretrained model prefix',
                            default=os.path.join(os.getcwd(), 'model', 'vgg16_reduced'), type=str)
        parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model',
                            default=1, type=int)
        parser.add_argument('--prefix', dest='prefix', help='new model prefix',
                            default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str)
        parser.add_argument('--gpus', dest='gpus', help='GPU devices to train with',
                            default='0', type=str)

        # Training schedule and input geometry.
        parser.add_argument('--begin-epoch', dest='begin_epoch', help='begin epoch of training',
                            default=0, type=int)
        parser.add_argument('--end-epoch', dest='end_epoch', help='end epoch of training',
                            default=240, type=int)
        parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
                            default=20, type=int)
        parser.add_argument('--data-shape', dest='data_shape', type=int, default=300,
                            help='set image shape')
        parser.add_argument('--label-width', dest='label_width', type=int, default=350,
                            help='force padding label width to sync across train and validation')

        # Optimizer hyper-parameters.
        parser.add_argument('--lr', dest='learning_rate', type=float, default=0.004,
                            help='learning rate')
        parser.add_argument('--momentum', dest='momentum', type=float, default=0.9,
                            help='momentum')
        parser.add_argument('--wd', dest='weight_decay', type=float, default=0.0005,
                            help='weight decay')

        # Per-channel mean values.
        parser.add_argument('--mean-r', dest='mean_r', type=float, default=123,
                            help='red mean value')
        parser.add_argument('--mean-g', dest='mean_g', type=float, default=117,
                            help='green mean value')
        parser.add_argument('--mean-b', dest='mean_b', type=float, default=104,
                            help='blue mean value')

        # Learning-rate schedule.  The step list stays a comma separated string.
        parser.add_argument('--lr-steps', dest='lr_refactor_step', type=str, default='80, 160',
                            help='refactor learning rate at specified epochs')
        # Was type=str: a user-supplied ratio arrived as a string while the
        # default was a float.  Parse it as a float so both agree.
        parser.add_argument('--lr-factor', dest='lr_refactor_ratio', type=float, default=0.1,
                            help='ratio to refactor learning rate')

        # Layer freezing, logging and monitoring.
        parser.add_argument('--freeze', dest='freeze_pattern', type=str, default="^(conv1_|conv2_).*",
                            help='freeze layer pattern')
        parser.add_argument('--log', dest='log_file', type=str, default="train.log",
                            help='save training log to file')
        parser.add_argument('--monitor', dest='monitor', type=int, default=0,
                            help='log network parameters every N iters if larger than 0')
        parser.add_argument('--pattern', dest='monitor_pattern', type=str, default=".*",
                            help='monitor parameter pattern, as regex')

        # Dataset description.
        parser.add_argument('--num-class', dest='num_class', type=int, default=20,
                            help='number of classes')
        parser.add_argument('--num-example', dest='num_example', type=int, default=16551,
                            help='number of image examples')
        parser.add_argument('--class-names', dest='class_names', type=str,
                            default='aeroplane, bicycle, bird, boat, bottle, bus, '
                                    'car, cat, chair, cow, diningtable, dog, horse, motorbike, '
                                    'person, pottedplant, sheep, sofa, train, tvmonitor',
                            help='string of comma separated names, or text filename')

        # Evaluation options.
        parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.45,
                            help='non-maximum suppression threshold')
        parser.add_argument('--overlap', dest='overlap_thresh', type=float, default=0.5,
                            help='evaluation overlap threshold')
        # The three flags below used type=bool, which made "--flag False"
        # evaluate to True; _str2bool parses the text properly.
        parser.add_argument('--force', dest='force_nms', type=_str2bool, default=False,
                            help='force non-maximum suppression on different class')
        parser.add_argument('--use-difficult', dest='use_difficult', type=_str2bool, default=False,
                            help='use difficult ground-truths in evaluation')
        parser.add_argument('--voc07', dest='use_voc07_metric', type=_str2bool, default=True,
                            help='use PASCAL VOC 07 11-point metric')

        return parser.parse_args(argv)
    def parse_class_names(args):
        """Resolve the class-name list described by the parsed arguments.

        ``args.class_names`` is either a comma separated string of names or
        the path of a text file with one name per line.

        Args:
            args: object exposing ``num_class`` (int) and ``class_names`` (str).

        Returns:
            list of stripped class-name strings, or ``None`` when
            ``args.class_names`` is empty.

        Raises:
            ValueError: if the number of names differs from ``args.num_class``
                or any name is empty.  Explicit exceptions are used instead of
                ``assert`` so the checks still run under ``python -O``.
        """
        spec = args.class_names
        if not spec:
            return None

        if os.path.isfile(spec):
            # The value names an existing file: read one class name per line.
            with open(spec, 'r') as f:
                class_names = [line.strip() for line in f]
        else:
            class_names = [c.strip() for c in spec.split(',')]

        if len(class_names) != args.num_class:
            raise ValueError(
                'expected %d class names but got %d'
                % (args.num_class, len(class_names)))
        if not all(class_names):
            raise ValueError('class names must not be empty')
        return class_names
    if __name__ == '__main__':
        # Script entry point: read the options, pick devices, start training.
        args = parse_args()

        # One GPU context per id listed in --gpus; blank entries are skipped.
        gpu_ids = [int(token) for token in args.gpus.split(',') if token.strip()]
        ctx = [mx.gpu(gpu_id) for gpu_id in gpu_ids]
        if not ctx:
            # No GPU id was given, so fall back to a single CPU context.
            ctx = [mx.cpu()]

        class_names = parse_class_names(args)
        mean_pixels = [args.mean_r, args.mean_g, args.mean_b]

        # Positional arguments are passed in the order train_net is called with
        # in the original script; the remaining options go by keyword.
        train_net(
            args.network,
            args.train_path,
            args.num_class,
            args.batch_size,
            args.data_shape,
            mean_pixels,
            args.resume,
            args.finetune,
            args.pretrained,
            args.epoch,
            args.prefix,
            ctx,
            args.begin_epoch,
            args.end_epoch,
            args.frequent,
            args.learning_rate,
            args.momentum,
            args.weight_decay,
            args.lr_refactor_step,
            args.lr_refactor_ratio,
            val_path=args.val_path,
            num_example=args.num_example,
            class_names=class_names,
            label_pad_width=args.label_width,
            freeze_layer_pattern=args.freeze_pattern,
            iter_monitor=args.monitor,
            monitor_pattern=args.monitor_pattern,
            log_file=args.log_file,
            nms_thresh=args.nms_thresh,
            force_nms=args.force_nms,
            ovp_thresh=args.overlap_thresh,
            use_difficult=args.use_difficult,
            voc07_metric=args.use_voc07_metric,
        )