下面是使用快速剪枝模式对YOLOv5进行剪枝的过程,相较常规网络的剪枝,YOLO剪枝的特殊性主要包括:
- YOLO中存在的SiLU等算子无法直接基于pytorch自带函数进行onnx转换:
——需要用YOLO自带工具(包含了算子替换等操作)转换onnx后,传入剪枝函数。
- YOLO中onnx转换工具自动合并了BN:
——需要改为不合并,因为剪枝工具需要BN。
步骤一. 常规训练。
步骤二. 修改load模型的代码,去掉合并BN。
进入models/experimental.py, 复制attempt_load函数为attempt_load_without_fuse函数,去掉加载模型时的fuse()调用:
def attempt_load_without_fuse(weights, map_location=None):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w)
model.append(torch.load(w, map_location=map_location)['model'].float().eval()) # load FP32 model ; 此处无 fuse()!!!!!!!!!!!!!!!!!!!
# Compatibility updates
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True # pytorch 1.7.0 compatibility
elif type(m) is Conv:
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if len(model) == 1:
return model[-1] # return model
else:
print('Ensemble created with %s\n' % weights)
for k in ['names', 'stride']:
setattr(model, k, getattr(model[-1], k))
return model # return ensemble
步骤三.转换带BN的onnx,用于网络图分析。
复制export.py 为export_with_bn.py,并修改
from models.experimental import attempt_load
为
from models.experimental import attempt_load_without_fuse as attempt_load
并将常规训练好的pt转onnx,注意加上—train:
python export_with_bn.py --train --weights runs/train/exp/weights/best.pt
如果YOLO版本比较低,可能会报错没有—train的选项,则直接去export_with_bn中修改export函数,在参数中增加training和do_constant_folding两项:
torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'],
dynamic_axes={'images': {0: 'batch_size'}, 'output': {0: 'batch_size'}},training=torch.onnx.TrainingMode.TRAINING ,
do_constant_folding=False)
步骤四. 复制train.py 为 train_prune_finetune.py,并添加代码。
在网络初始化完成后,添加:
from easypruner import fastpruner
model.cpu()
fastpruner.fastpruner(model, prune_factor = 0.5, method="Uniform", input_dim=[3,416,416] )
model.to(device)
找到网络初始化的位置,仅仅增加##中代码即可。
# Model
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(rank):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors'))#.to(device) # create
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
state_dict = ckpt['model'].float().state_dict() # to FP32
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(state_dict, strict=False) # load
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors'))#.to(device) # create
with torch_distributed_zero_first(rank):
check_dataset(data_dict) # check
'''
以上为原代码
'''
########################
from easypruner import fastpruner
model.cpu()
fastpruner.fastpruner(model, flops_saved = 0.5, method="Uniform", input_dim=[3,416,416], onnx_file="runs/train/exp10000/weights/best.onnx" ) #参数依次为:模型对象、剪枝后模型保留比率、剪枝方法、输入大小、步骤三获得的没合并BN的onnx文件
model.to(device)
#######################
'''
以下为原代码
'''
train_path = data_dict['train']
test_path = data_dict['val']
# Freeze
freeze = [] # parameter names to freeze (full or partial)
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
if any(x in k for x in freeze):
print('freezing %s' % k)
v.requires_grad = False
步骤五. 基于步骤一的预训练权重,执行微调。
例如:
CUDA_VISIBLE_DEVICES=2,3 python train_prune.py --cfg models/yolov5m.yaml --img-size 400 --batch-size 64 --data data/person_1classes.yaml --weights runs/train/exp2/weights/best.pt
步骤六,重复步骤三即可获得剪枝后的onnx文件。
注意:YOLOV5模型不需要rebuild,直接attempt_load即可,直接finetune/test。