The train_net.py script contains three functions: convert_pretrained, get_lr_scheduler, and train_net. The most important of these is train_net, which is also the function the train.py script calls when training a model, so it is the recommended place to start reading.
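Before going through the listing, here is a minimal sketch of how a driver script such as train.py might invoke train_net. Every path, hyper-parameter value and context below is an illustrative assumption for a VOC-style setup (and the import path assumes the script lives in the train package); only the parameter names come from the function signature shown further down.

import mxnet as mx
from train.train_net import train_net

# Hypothetical invocation: all literal values are illustrative placeholders.
train_net(net='vgg16_reduced',               # symbol name resolved by symbol_factory
          train_path='data/train.rec',       # RecordIO file with training images and labels
          num_classes=20, batch_size=32,
          data_shape=300,                    # expanded to (3, 300, 300) inside train_net
          mean_pixels=(123, 117, 104),
          resume=-1, finetune=-1,            # <= 0 means neither resume nor finetune
          pretrained='model/vgg16_reduced', epoch=1,
          prefix='model/ssd', ctx=[mx.gpu(0)],
          begin_epoch=0, end_epoch=240, frequent=20,
          learning_rate=0.004, momentum=0.9, weight_decay=0.0005,
          lr_refactor_step='80,160', lr_refactor_ratio=0.1,
          num_example=16551, voc07_metric=True,
          val_path='data/val.rec')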

import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train
def convert_pretrained(name, args):
    """
    Special operations need to be made due to name inconsistency, etc.

    Parameters:
    ---------
    name : str
        pretrained model name
    args : dict
        loaded arguments

    Returns:
    ---------
    processed arguments as dict
    """
    return args

def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute learning rate and refactor scheduler

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate
    lr_refactor_ratio : float
        lr *= ratio at certain steps
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)

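# A quick worked example of get_lr_scheduler with hypothetical values:
#   learning_rate=0.004, lr_refactor_step='80,160', lr_refactor_ratio=0.1,
#   num_example=16551, batch_size=32, begin_epoch=0
# gives epoch_size = 16551 // 32 = 517 iterations per epoch, so
# steps = [517 * 80, 517 * 160] = [41360, 82720]. The returned MultiFactorScheduler
# multiplies the learning rate by 0.1 each time the global iteration count passes
# one of those steps, and the initial learning rate is returned unchanged.
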
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5), grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    logging.info("Removed %s" % k)
                else:
                    if not 'pred' in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                                                   lr_refactor_ratio, num_example, batch_size, begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
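One parameter worth a concrete illustration is freeze_layer_pattern, which train_net compiles into a regex and matches against the symbol's argument names to decide which weights stay fixed. The snippet below is a small, self-contained sketch of that mechanism; the pattern and the argument names are hypothetical examples, not output taken from the script:

import re

# Hypothetical pattern: freeze the first two VGG convolution stages.
freeze_layer_pattern = '^(conv1_|conv2_).*'
re_prog = re.compile(freeze_layer_pattern)
arguments = ['conv1_1_weight', 'conv1_1_bias', 'conv3_1_weight', 'multibox_loc_pred_conv_weight']
fixed_param_names = [name for name in arguments if re_prog.match(name)]
print(fixed_param_names)   # -> ['conv1_1_weight', 'conv1_1_bias']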

This post has introduced the overall architecture of the SSD algorithm, with the aim of deepening the understanding of the algorithm at a macro level. As the code above shows, the network structure used inside train_net is built by the get_symbol_train function in the symbol_factory.py script. Since constructing the network structure is the core of the SSD algorithm, the next post will cover the parameter configuration involved in building the network structure.
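As a pointer ahead, a minimal call to get_symbol_train in isolation mirrors the one inside train_net. The network name and class count below are placeholder assumptions; the keyword arguments match the listing above:

from symbol.symbol_factory import get_symbol_train

# Hypothetical example: build the training symbol for a 300x300 input.
ssd_train_symbol = get_symbol_train('vgg16_reduced', 300, num_classes=20,
                                    nms_thresh=0.45, force_suppress=False,
                                    nms_topk=400)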
