下面是使用快速剪枝模式对YOLOv5进行剪枝的过程,相较常规网络的剪枝,YOLO剪枝的特殊性主要包括:

  1. YOLO中存在的SiLU等算子无法直接基于pytorch自带函数进行onnx转换:

——需要用YOLO自带工具(包含了算子替换等操作)转换onnx后,传入剪枝函数。

  1. YOLO中onnx转换工具自动合并了BN:

——需要改为不合并,因为剪枝工具需要BN。

步骤一. 常规训练。

步骤二. 修改load模型的代码,去掉合并BN。

进入models/experimental.py, 复制attempt_load函数为attempt_load_without_fuse函数,去掉加载模型时的fuse()调用:

  1. def attempt_load_without_fuse(weights, map_location=None):
  2. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  3. model = Ensemble()
  4. for w in weights if isinstance(weights, list) else [weights]:
  5. attempt_download(w)
  6. model.append(torch.load(w, map_location=map_location)['model'].float().eval()) # load FP32 model ; 此处无 fuse()!!!!!!!!!!!!!!!!!!!
  7. # Compatibility updates
  8. for m in model.modules():
  9. if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  10. m.inplace = True # pytorch 1.7.0 compatibility
  11. elif type(m) is Conv:
  12. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  13. if len(model) == 1:
  14. return model[-1] # return model
  15. else:
  16. print('Ensemble created with %s\n' % weights)
  17. for k in ['names', 'stride']:
  18. setattr(model, k, getattr(model[-1], k))
  19. return model # return ensemble

步骤三.转换带BN的onnx,用于网络图分析。

复制export.py 为export_with_bn.py,并修改

  1. from models.experimental import attempt_load

from models.experimental import attempt_load_without_fuse as attempt_load

并将常规训练好的pt转onnx,注意加上—train:

python export_with_bn.py --train --weights runs/train/exp/weights/best.pt

如果YOLO版本比较低,可能会报错没有—train的选项,则直接去export_with_bn中修改export函数,在参数中增加training和do_constant_folding两项:

torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
                          output_names=['classes', 'boxes'] if y is None else ['output'],
                          dynamic_axes={'images': {0: 'batch_size'}, 'output': {0: 'batch_size'}},training=torch.onnx.TrainingMode.TRAINING ,
                              do_constant_folding=False)

步骤四. 复制train.py 为 train_prune_finetune.py,并添加代码。

在网络初始化完成后,添加:

from easypruner import  fastpruner
model.cpu()
fastpruner.fastpruner(model, prune_factor = 0.5, method="Uniform", input_dim=[3,416,416] )
model.to(device)

找到网络初始化的位置,仅仅增加##中代码即可。

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors'))#.to(device)  # create
        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors'))#.to(device)  # create
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check

    '''
    以上为原代码
    '''
    ########################
    from easypruner import  fastpruner
    model.cpu()
    fastpruner.fastpruner(model, flops_saved = 0.5, method="Uniform", input_dim=[3,416,416], onnx_file="runs/train/exp10000/weights/best.onnx" ) #参数依次为:模型对象、剪枝后模型保留比率、剪枝方法、输入大小、步骤三获得的没合并BN的onnx文件
    model.to(device)
    #######################
    '''
    以下为原代码
    '''

    train_path = data_dict['train']
    test_path = data_dict['val']



    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

步骤五. 基于步骤一的预训练权重,执行微调。

例如:

CUDA_VISIBLE_DEVICES=2,3 python train_prune.py --cfg models/yolov5m.yaml  --img-size 400  --batch-size 64 --data data/person_1classes.yaml --weights runs/train/exp2/weights/best.pt

步骤六,重复步骤三即可获得剪枝后的onnx文件。

注意:YOLOV5模型不需要rebuild,直接attempt_load即可,直接finetune/test。