作为一个toolbox,EasyPruner中主要由若干个接口函数构成,以下对每个接口函数进行了介绍:

rebuild

  1. rebuild(model, state_dict) -> loaded_model

作用:

加载剪枝后的模型权重,并使原始未剪枝大模型变成剪枝后的小模型。

参数说明:

model-用原始代码构建的模型对象
state_dict-剪枝后得到的模型权重对象,其中仅包含未剪枝通道的模型权重。
loaded_model-加载了保留通道权重的剪枝后模型对象。

示例:

    model = VGG()
    state_dict = torch.load("model_pruned_uniform.pt")
    model = rebuild(model , state_dict)

fastpruner

fastpruner(model,prune_factor=,input_dim = ,method = , onnx_file = )

作用:

用于基于快捷模式的剪枝。

参数说明:

net-被剪枝的模型对象。
prune_factor-剪枝系数,在不同剪枝方法下有不同的意义。float,默认0.5。
input_dim-模型对象输入的空间大小。list,默认[3,416,416]。
method-剪枝方法。str,默认’Ratio’。
‘Ratio’,全局通道按BN层放缩因子大小进行排序剪枝。prune_factor为保留的通道数所占的比例。
‘Uniform’,按照给定的flops限制每层按一样比例的通道数进行剪枝。prune_factor为剪枝后的模型保留的flops所占的比例。
‘Order’,按照给定的剪枝阈值进行剪枝。prune_factor为剪枝阈值,剪枝掉小于整个阈值的通道。
onnx_file-被剪枝模型的onnx文件路径。str,默认 None。

示例:

    model = VGG()
    state_dict = torch.load("model_pruned_uniform.pt")
    model = rebuild(model , state_dict)

getprunelayer

getprunelayer(net,input_dim = [3,224,224],onnx_file=None):

作用:

得到用于dislpay_layer的bn层名称

参数说明:

net-被剪枝的模型对象。
input_dim-模型对象输入的空间大小。list,默认[3,416,416]。
onnx_file-被剪枝模型的onnx文件路径。str,默认 None。

update_layer_grad_decay

### update_layer和 display_layer 参数解析
update_layer_grad_decay(model, norm_layer_names, optimizer, scaler =  , mask_dict = , epoch= ,epoch_decay = , iters = )

#model 
#norm_layer_names :被稀疏层的名称 , list
        norm_layer_names:
        1) 使用以下代码
        from easypruner.fastpruner import getprunelayer
        norm_layer_names = getprunelayer(net) 
        #如果输入不是224 
        #norm_layer_names = getprunelayer(net,input_dim = [3,224,224])
        2) 工具包config的成员变量  opt.sparsity_layernames
        3) mask.keys()
        4) 用户直指定的list


# optimizer :模型训练的优化器 torch.nn.optim
# lr        :模型训练的学习率 list
        这两个参数只需要给出一个就够了,如果都给了将以lr为准
        可以通过 optimizer 获得学习率,或者直接指定学习率,指定BN层weight和bias的学习率
        1)默认None,将不进行稀疏,不产生效果
        2)optimizer直接传入优化器,将自动获取每个参数的学习率。使用方便但是运行慢些。
        3)lr传入需要,bn层weight或bias有统一的学习率,需要用户自己获得lr并指定。使用有些不便但是运行快些。
           如果无法获得lr/不会获得lr,请直接传入optimizer。

# scaler :是否使用 amp.GradScaler,如yolov5训练代码
          !!剪枝yolov5的时候请把act 层 freeze,并且将ema.update(model) pass掉
          1)默认False
             如果使用amp.GradScaler,scaler参数设置为True
             否则为False

#mask_dict : 用于控制稀疏范围,None,float,dict
    mask的选取:
        1) 不使用,退化到netslim算法
        2)float,uniform方式获得的剪枝mask来进行maskL1
        3)dict,其他任何算法获得的剪枝mask来进行maskL1

#epoch  :当前的epoch number

#epoch_decay  :希望在 epoch_decay ,将weight 降到0 ,
       1)一般使用 int(0.75 * args.apoch)意味着3/4epochs时候将weight降到0
       2)可以尝试修改0.75为0.5或者其他的参数,用户可以自己调节

#iters : 一个epoch的迭代轮数
       1)请用户填入  dataloader的长度如   len(dataloader) / len(pbar)/ ...

display_layer

### update_layer和 display_layer 参数解析
update_layer_grad_decay(model, norm_layer_names, optimizer =  , lr =  ,scaler =  , mask_dict = , epoch= ,epoch_decay = , iters = )
display_layer(model ,norm_layer_names =  )

#model 
#norm_layer_names :被稀疏层的名称 , list
        norm_layer_names:
        1) 使用以下代码
        from easypruner.fastpruner import getprunelayer
        norm_layer_names = getprunelayer(net) 
        #如果输入不是224 
        #norm_layer_names = getprunelayer(net,input_dim = [3,224,224])
        2) 工具包config的成员变量  opt.sparsity_layernames
        3) mask.keys()
        4) 用户直指定的list


# optimizer :模型训练的优化器 torch.nn.optim
# lr        :模型训练的学习率 list
        这两个参数只需要给出一个就够了,如果都给了将以lr为准
        可以通过 optimizer 获得学习率,或者直接指定学习率,指定BN层weight和bias的学习率
        1)默认None,将不进行稀疏,不产生效果
        2)optimizer直接传入优化器,将自动获取每个参数的学习率。使用方便但是运行慢些。
        3)lr传入需要,bn层weight或bias有统一的学习率,需要用户自己获得lr并指定。使用有些不便但是运行快些。
           如果无法获得lr/不会获得lr,请直接传入optimizer。

# scaler :是否使用 amp.GradScaler,如yolov5训练代码
          !!剪枝yolov5的时候请把act 层 freeze,并且将ema.update(model) pass掉
          1)默认False
             如果使用amp.GradScaler,scaler参数设置为True
             否则为False

#mask_dict : 用于控制稀疏范围,None,float,dict
    mask的选取:
        1) 不使用,退化到netslim算法
        2)float,uniform方式获得的剪枝mask来进行maskL1
        3)dict,其他任何算法获得的剪枝mask来进行maskL1

#epoch  :当前的epoch number

#epoch_decay  :希望在 epoch_decay ,将weight 降到0 ,
       1)一般使用 int(0.75 * args.apoch)意味着3/4epochs时候将weight降到0
       2)可以尝试修改0.75为0.5或者其他的参数,用户可以自己调节

#iters : 一个epoch的迭代轮数
       1)请用户填入  dataloader的长度如   len(dataloader) / len(pbar)/ ...