MaskL1原理简介
综合评价
优势点:
1) 效果好,大部分情况下相较其他方法有显著优势。
权衡点:
1)需要在训练框架中找到梯度更新部分,并增加一些代码。<br /> 2)需要增加一个阶段进行稀疏。<br /> 2)操作步骤多一些。
操作步骤
步骤1. 常规训练
得到收敛后的模型,作为用于剪枝的 pretrained 模型。
步骤2. 修改训练代码,插入稀疏约束的代码
建议拷贝一份原训练代码train.py文件为train_sparsity.py 进行修改。 代码修改如下:
在导言区,
#导言区
from easypruner.regularize.sparsity import update_layer, display_layer
from easypruner.fastpruner import getprunelayer
在构建网络后获得可剪枝网络层中的BN算子列表:
#在构建网络net对象后
#norm_layer_names = getprunelayer(net,onnx_file="")#转onnx失败请自己指定一下文件路劲
norm_layer_names = getprunelayer(net)
在训练代码中的 loss.backward() 与 optimizer.step() 之间增加 update_layer 和 display_layer, 用于对所获得的BN算子进行稀疏约束,同时观察稀疏化情况:
scaler.scale(loss).backward()
####
prune_ratio = 0.5 # 需要设置为你要剪枝的通道数目比率(如,剪枝率为50%,则此处设置0.5)
update_layer(model, norm_layer_names, factor=0.01,scaler=True, mask_dict = prune_ratio, regular_method='L1') #yolov5训练中使用了scaler.scale(loss).backward()
if iter % 100:
display_layer(model, norm_layer_names)
####
optimizer.step()
步骤3. 稀疏训练
基于步骤2中修改好的稀疏约束的训练代码,令模型在 pretrianed 基础上进行稀疏化训练,超参配置建议与 pretrianed 一致。
步骤4. 剪枝
对稀疏训练后的模型进行Order剪枝(基于阈值的剪枝方法)。
Order剪枝代码为:
from easypruner import fastpruner
model.cpu()
fastpruner.fastpruner(model, prune_factor = 0.01, method= "Order", input_dim=[3,416,416])
# prune_factor 为 剪枝阈值
# 或传入自己提前转好的onnx
#fastpruner.fastpruner(model, prune_factor = 0.01, method= "Order", input_dim=[3,416,416],onnx_file = "runs/***/my_yolo.onnx") #prune_factor 为 剪枝阈值,onnx_file为转换的onnx文件。
model.to(device)
save_path = '/your_path/model_pruned.pt' #可选
torch.save(model.state_dict(),save_path) #可选
对于prune_factor,一般设置为0.01即满足大多数情况,如果想实现更精细的效果,可以进行如下环节:
(可选)子步骤4.1 观察网络各通道重要性打分的分布情况。
方法一、步骤2 第5行会打印训练过程中的分布情况变化,用停止训练时的分布情况即可。
方法二、加载稀疏训练后的模型,利用如下代码进行测试:
#导言区
from easypruner.regularize.sparsity import display_layer
from easypruner.fastpruner import getprunelayer
#在构建网络net对象后
norm_layer_names = getprunelayer(net)
#以下代码进行展示各个阈值下的剪枝情况
display_layer(model, norm_layer_names)
(可选)子步骤4.2 根据BN放缩因子分布情况,利用选择合适剪枝阈值,为Order方法剪枝准备。
例如,如果分布情况为:
<1e-1:0.654
<1e-2:0.521
<1e-3:0.489
<1e-4:0.488
<1e-5:0.488
<1e-6:0.487
<1e-7:0.487
...
则,阈值建议为0.001,即BN放缩因子小于1e-3的通道将被剪枝掉。
步骤5. 导出onnx文件
对剪枝后的模型对象,用标注pytorch转onnx的接口即可导出onnx文件。
#以下的model对象为剪枝微调后的model对象。
input_tensor = torch.randn([1, 3, 256, 512])
print ("Exporting to ONNX: ", onnx_save_name)
torch_onnx_out = torch.onnx.export(model, input_tensor, onnx_save_name,
export_params=True,
verbose=True,
input_names=['label'],
output_names=["synthesized"],
opset_version=11)
补充:rebuild函数
由于剪枝后模型的权重矩阵与模型定义的py文件中的参数量不一致,如果保存的是model.state_dict(),不能直接加载到原始代码建立的网络对象中。此时,需要用工具包提供的rebuild函数进行构建网络:
from easypruner.utils.rebuild import rebuild
model = YourSelfModelDef()
new_state_dict = torch.load(...)
model = rebuild(model , opt.new_state_dict)