The train_net.py script contains three functions: convert_pretrained, get_lr_scheduler, and train_net. The most important of these is train_net, which is the function the train.py script calls when training a model, so it is recommended to start reading from train_net; a sketch of such a call is shown below.
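Before the full listing, here is a minimal, hedged sketch of how train.py might wire its parsed command-line options into train_net. The network name, file paths, and hyperparameter values below are illustrative assumptions, not the repo's exact defaults:

import mxnet as mx
from train.train_net import train_net

# Hypothetical values standing in for what train.py parses from the command line.
train_net('vgg16_reduced',              # symbol name handed to get_symbol_train
          'data/train.rec',             # training record file
          num_classes=20, batch_size=32,
          data_shape=300,               # expanded to (3, 300, 300) inside train_net
          mean_pixels=(123, 117, 104),  # RGB means subtracted by the iterator
          resume=-1, finetune=-1,       # <= 0 means: start from the pretrained model
          pretrained='model/vgg16_reduced', epoch=0,
          prefix='model/ssd', ctx=[mx.gpu(0)],
          begin_epoch=0, end_epoch=240, frequent=20,
          learning_rate=0.002, momentum=0.9, weight_decay=0.0005,
          lr_refactor_step='80,160', lr_refactor_ratio=0.1)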
import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train

def convert_pretrained(name, args):
    """
    Special operations need to be made due to name inconsistency, etc.

    Parameters:
    ---------
    name : str
        pretrained model name
    args : dict
        loaded arguments

    Returns:
    ---------
    processed arguments as dict
    """
    return args

def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute learning rate and refactor scheduler

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate
    lr_refactor_ratio : float
        lr *= ratio at certain steps
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}"
                                     .format(lr, begin_epoch))
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)

def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    ovp_thresh : float
        overlap threshold used by the validation mAP metric
    use_difficult : boolean
        use difficult ground-truths in evaluation
    class_names : list of str
        class names used by the validation metric
    voc07_metric : boolean
        use 11-point metric as in the VOC07 competition
    nms_topk : int
        apply NMS to the top k detections only
    force_suppress : boolean
        force suppression of overlapped objects regardless of class
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width, path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width, path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk)

    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5), grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    logging.info("Removed %s" % k)
                else:
                    if not 'pred' in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                                                   lr_refactor_ratio, num_example,
                                                   batch_size, begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
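To make the freeze_layer_pattern logic concrete: it is an ordinary regex matched against the symbol's argument names, and every match is handed to mx.mod.Module as a fixed parameter. A small sketch, assuming net is the training symbol returned by get_symbol_train above, with a hypothetical pattern that would freeze the first two VGG convolution stages:

import re

pattern = '^(conv1_|conv2_)'   # hypothetical pattern, not a repo default
re_prog = re.compile(pattern)
fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
# e.g. ['conv1_1_weight', 'conv1_1_bias', 'conv1_2_weight', ...]
# Passing these to mx.mod.Module(..., fixed_param_names=fixed_param_names)
# keeps their weights at the loaded pretrained values for the whole run.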
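The interplay between get_lr_scheduler and begin_epoch is also worth a worked example. The numbers below are assumptions (16551 is the usual VOC07+12 trainval size), chosen only to show how epoch-based steps are converted into iteration-based steps:

# Fresh run: epoch_size = 16551 // 32 = 517 iterations per epoch. No step has
# passed yet, so lr stays at 0.002 and the scheduler fires at iterations
# [517 * 80, 517 * 160] = [41360, 82720].
lr, scheduler = get_lr_scheduler(learning_rate=0.002, lr_refactor_step='80,160',
                                 lr_refactor_ratio=0.1, num_example=16551,
                                 batch_size=32, begin_epoch=0)

# Resumed run from epoch 100: the epoch-80 step has already happened, so the
# returned lr is pre-scaled to 0.002 * 0.1 = 0.0002, and only the epoch-160
# step remains, now (160 - 100) * 517 = 31020 iterations away.
lr, scheduler = get_lr_scheduler(learning_rate=0.002, lr_refactor_step='80,160',
                                 lr_refactor_ratio=0.1, num_example=16551,
                                 batch_size=32, begin_epoch=100)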
This post has introduced the overall training architecture of the SSD algorithm, with the goal of deepening understanding at a macro level. As the code walkthrough above shows, the network structure inside train_net is built by the get_symbol_train function from the symbol_factory.py script. Since constructing the network is the core of the SSD algorithm, the next post will cover the parameter configuration involved in building that network structure.
