模型拓扑结构变量
self.norm_layer_names
self.prec_layers
self.succ_layers
作用:
内容说明:
1) self.norm_layer_names、self.prec_layers、self.succ_layers 是三个 list;每个 list 的元素也是 list(一个分组),分组里的元素是字符串,表示模型中层(module)的名称
2) self.norm_layer_names为bn层的名称, self.prec_layers 为BN层对应的前一层的名称, self.succ_layers为BN层对应的后一层的名称
3)剪枝的时候,由BN索引得到剪枝mask,同时剪掉self.prec_layers 的输出通道,以及self.succ_layers的输入通道
4)举例
案例一
self.norm_layer_names = [[bn1], … ]
self.prec_layers = [[conv1], …]
self.succ_layers = [[layer1.0.conv1], …]
案例二
self.norm_layer_names = [[bn1 , layer1.0.bn2 , … ] , [layer1.0.bn1] ]
self.prec_layers = [[conv1 , layer1.0.conv2 , ….] , [layer1.0.conv1]…]
self.succ_layers = [[layer1.0.conv1 , layer1.1.conv1 , … ], [layer1.1.conv1] …]
说明:由 add 算子连接的 BN 必须使用同一个 mask,因此它们被放在同一个分组(同一个内层 list)中
修改操作:
1.stride=2的卷积不剪枝
# Step 1: when stride-2 pruning is disabled, drop every layer group whose
# preceding layers include a layer with stride == (2, 2), so that those
# convolutions are never pruned.
prec_layers_ = []
norm_layer_names_ = []
succ_layers_ = []
if not self.prune_stride_2:
    # The real network may be exposed as `.module` on a wrapper object
    # (presumably a DataParallel-style wrapper -- confirm). This lookup does
    # not depend on the loop variables, so it is resolved once.
    root = self.model.module if hasattr(self.model, 'module') else self.model
    for index, convlist in enumerate(self.prec_layers):
        has_stride_2 = False
        for conv_name in convlist:
            # Walk the dotted layer name down the module tree.
            container = root
            for container_name in conv_name.split('.'):
                container = container._modules[container_name]
            # NOTE: a leftover `pdb.set_trace()` breakpoint used to sit here
            # and stopped execution for every inspected layer; it was removed.
            assert hasattr(container, 'stride'), (
                "preceding layer %r has no 'stride' attribute" % conv_name
            )
            if container.stride == (2, 2):
                has_stride_2 = True
                break
        if has_stride_2:
            # At least one preceding layer has stride 2: skip the whole group.
            continue
        # The three lists are index-aligned, so keep the same index in each.
        prec_layers_.append(self.prec_layers[index])
        norm_layer_names_.append(self.norm_layer_names[index])
        succ_layers_.append(self.succ_layers[index])
    # Replace the topology lists only when the filter actually ran.
    self.prec_layers = prec_layers_
    self.norm_layer_names = norm_layer_names_
    self.succ_layers = succ_layers_
2.去掉暂时无法处理的层,例如非 depthwise 的 group 卷积(groups != 1 且 groups != out_channels)
# Step 2: drop every layer group that touches a grouped convolution which is
# neither a regular conv (groups == 1) nor depthwise (groups == out_channels).
prec_layers_ = []
norm_layer_names_ = []
succ_layers_ = []
# The real network may be exposed as `.module` on a wrapper object
# (presumably a DataParallel-style wrapper -- confirm). Resolved once because
# it does not depend on the loop variables.
root = self.model.module if hasattr(self.model, 'module') else self.model
for index, conv_preclist in enumerate(self.prec_layers):
    supported = True
    conv_succlist = self.succ_layers[index]
    # Both the succeeding and the preceding layers of the group are checked.
    for conv_name in conv_succlist + conv_preclist:
        # Walk the dotted layer name down the module tree.
        container = root
        for container_name in conv_name.split('.'):
            container = container._modules[container_name]
        if not hasattr(container, 'groups'):
            # Layers without a `groups` attribute cannot be grouped convs.
            print( conv_name ,"continue" )
            continue
        assert hasattr(container, 'out_channels'), (
            "layer %r has 'groups' but no 'out_channels'" % conv_name
        )
        if container.groups != 1 and container.groups != container.out_channels:
            supported = False
            print("Warning!!!! groups != out_channels not support!!!!")
            break
    if supported:
        # The three lists are index-aligned, so keep the same index in each.
        prec_layers_.append(self.prec_layers[index])
        norm_layer_names_.append(self.norm_layer_names[index])
        succ_layers_.append(self.succ_layers[index])
# Write the filtered topology back, mirroring the stride-2 filter; without
# this the filtering above has no effect on the lists used for pruning.
self.prec_layers = prec_layers_
self.norm_layer_names = norm_layer_names_
self.succ_layers = succ_layers_
剪枝操作遍历:
# Pruning traversal skeleton: walk the groups by index so the matching
# entries of succ_layers and prec_layers can be looked up alongside each
# group of normalization layers.
for group_idx, norm_group in enumerate(norm_layer_names):
    # 1. Prune the source normalization (BN) layers first.
    for bn_name in norm_group:
        for weight_name in grouped_weight_names[bn_name]:
            pass
    # 2. Prune the layers that follow the BN group (conv/linear/...):
    #    their input channels.
    for succ_name in succ_layers[group_idx]:
        # One iteration per weight entry registered for this layer
        # (for a BN layer e.g. weight, bias, running mean, running var).
        for weight_name in grouped_weight_names[succ_name]:
            pass
    # 3. Prune the layers that precede the BN group (conv/linear/...), i.e.
    #    the convolution of this module: their output channels.
    #    NOTE(review): the original comment said "input channels" here --
    #    confirm that output channels is what is intended.
    for prec_name in prec_layers[group_idx]:
        for weight_name in grouped_weight_names[prec_name]:
            pass