Model topology variables

  1. self.norm_layer_names
  2. self.prec_layers
  3. self.succ_layers

Purpose:

These variables record the structural relationships of the network so that they can be used during pruning.

Contents:

1) self.norm_layer_names, self.prec_layers and self.succ_layers are three lists. Each element of these lists is itself a list, and the elements of those inner lists are strings, i.e. module names in the model.
2) self.norm_layer_names holds the names of the BN layers, self.prec_layers the names of the layers immediately preceding each BN layer, and self.succ_layers the names of the layers immediately following it.
3) During pruning, the pruning mask is derived from the BN layer; the same mask is then used to prune the output channels of the corresponding layers in self.prec_layers and the input channels of the corresponding layers in self.succ_layers (a minimal sketch follows the examples below).
4) Examples
Case 1
(figure: Case 1 network topology)
self.norm_layer_names = [['bn1'], …]
self.prec_layers = [['conv1'], …]
self.succ_layers = [['layers.0.conv1'], …]
Case 2
(figure: Case 2 network topology)
self.norm_layer_names = [['bn1', 'layer1.0.bn2', …], ['layer1.0.bn1'], …]
self.prec_layers = [['conv1', 'layer1.0.conv2', …], ['layer1.0.conv1'], …]
self.succ_layers = [['layer1.0.conv1', 'layer1.1.conv1', …], ['layer1.1.conv1'], …]
BN layers connected through an add operator must share a single mask.
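
As a concrete illustration of points 2) and 3), below is a minimal sketch (not the actual implementation; `shared_mask_for_group`, `apply_mask` and the 1e-2 threshold are illustrative assumptions) of how one group of BN layers yields a single mask that is applied to the BN parameters, to the output channels of the matching self.prec_layers entry, and to the input channels of the matching self.succ_layers entry. Computing the mask over the whole group is exactly why BN layers joined by an add end up sharing one mask.

    import torch

    def shared_mask_for_group(bn_layers, threshold=1e-2):
        # BN layers joined by an add must be pruned with one mask, so combine
        # their |gamma| before thresholding (max over the group is an assumed policy).
        gammas = torch.stack([bn.weight.detach().abs() for bn in bn_layers])
        return gammas.max(dim=0).values > threshold           # BoolTensor, one entry per channel

    def apply_mask(bn_layers, prec_convs, succ_convs, mask):
        keep = mask.nonzero(as_tuple=True)[0]
        for bn in bn_layers:                                  # source BN layers: slice every per-channel tensor
            bn.weight.data = bn.weight.data[keep]
            bn.bias.data = bn.bias.data[keep]
            bn.running_mean = bn.running_mean[keep]
            bn.running_var = bn.running_var[keep]
        for conv in prec_convs:                               # preceding convs lose output channels (dim 0)
            conv.weight.data = conv.weight.data[keep, :, :, :]
            if conv.bias is not None:
                conv.bias.data = conv.bias.data[keep]
        for conv in succ_convs:                               # succeeding convs lose input channels (dim 1)
            conv.weight.data = conv.weight.data[:, keep, :, :]
        # note: attributes such as num_features / in_channels / out_channels are not updated in this sketch

For Case 2 above, this would be called once with the whole group ['bn1', 'layer1.0.bn2', …] so that every branch feeding the residual add keeps the same channels.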

Modifications:

1. Convolutions with stride=2 are not pruned

        # stride = 2: optionally keep groups whose preceding conv has stride 2 unpruned
        prec_layers_ = []
        norm_layer_names_ = []
        succ_layers_ = []
        if not self.prune_stride_2:
            for index, convlist in enumerate(self.prec_layers):
                flag = False
                for conv_name in convlist:
                    # resolve the dotted layer name to the actual module
                    container_names = conv_name.split('.')
                    if hasattr(self.model, 'module'):
                        container = self.model.module
                    else:
                        container = self.model
                    for container_name in container_names:
                        container = container._modules[container_name]
                    assert hasattr(container, 'stride')
                    if container.stride == (2, 2):
                        flag = True
                        break
                if flag:
                    continue            # a stride-2 conv precedes this BN group: leave it unpruned
                prec_layers_.append(self.prec_layers[index])
                norm_layer_names_.append(self.norm_layer_names[index])
                succ_layers_.append(self.succ_layers[index])
            self.prec_layers = prec_layers_
            self.norm_layer_names = norm_layer_names_
            self.succ_layers = succ_layers_
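
A side note on the dotted-name lookup above (split on '.' and walk `_modules`): if your PyTorch version provides `nn.Module.get_submodule` (1.9+), the same lookup can be written as a small helper. `resolve_module` is our name, not part of the original code.

    def resolve_module(model, name):
        # unwrap nn.DataParallel / DistributedDataParallel, as the filter above does
        root = model.module if hasattr(model, 'module') else model
        return root.get_submodule(name)   # e.g. resolve_module(model, 'layer1.0.conv1')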

2. Drop layers that cannot be handled yet, such as group convolutions that are not depthwise (see the note after the snippet)

        prec_layers_ = []
        norm_layer_names_ = []
        succ_layers_ = []
        # drop groups containing a conv with groups != 1 and groups != out_channels
        for index, conv_preclist in enumerate(self.prec_layers):
            flag = True
            conv_succlist = self.succ_layers[index]
            for conv_name in conv_succlist + conv_preclist:
                # resolve the dotted layer name to the actual module
                container_names = conv_name.split('.')
                if hasattr(self.model, 'module'):
                    container = self.model.module
                else:
                    container = self.model
                for container_name in container_names:
                    container = container._modules[container_name]
                if not hasattr(container, 'groups'):
                    # not a convolution (e.g. nn.Linear): nothing to check
                    continue
                assert hasattr(container, 'out_channels')
                if container.groups != 1 and container.groups != container.out_channels:
                    flag = False
                    print("Warning: groups != out_channels is not supported, skipping", conv_name)
                    break
            if flag:
                prec_layers_.append(self.prec_layers[index])
                norm_layer_names_.append(self.norm_layer_names[index])
                succ_layers_.append(self.succ_layers[index])
        # write the filtered lists back (mirrors the stride-2 filter above)
        self.prec_layers = prec_layers_
        self.norm_layer_names = norm_layer_names_
        self.succ_layers = succ_layers_
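
Why only groups == 1 or groups == out_channels (depthwise) are kept: presumably because, for a general grouped convolution, an arbitrary channel mask would break the group structure, since out_channels must stay divisible by groups and the kept channels would have to remain balanced across the groups. A rough sketch of that check (the function name is illustrative, not from the code):

    def prunable_with_mask(conv, mask):
        kept = int(mask.sum())
        if conv.groups == 1:
            return True                      # ordinary conv: any number of kept output channels is fine
        if conv.groups == conv.out_channels:
            return True                      # depthwise: simply set groups = kept after pruning
        # general grouped conv: even when kept % groups == 0, the kept channels must be
        # spread evenly over the groups, so the filter above simply skips these layers
        return False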

Pruning loop (skeleton; a fleshed-out sketch of the mask application follows):

for id, norm_layer_name in enumerate(norm_layer_names):

  # 1. prune the source normalization (BN) layers first
  for name in norm_layer_name:
    for weight_name in grouped_weight_names[name]:
      pass

  # 2. prune the succeeding conv/linear/... layers (their input channels)
  for prune_layer_name in succ_layers[id]:                        # iterate over the successors
    for weight_name in grouped_weight_names[prune_layer_name]:    # iterate over that layer's parameters (for a BN layer: weight, bias, running_mean, running_var)
      pass

  # 3. prune the preceding conv/linear/... layers (their output channels), i.e. the convolutions in this module
  for prune_layer_name in prec_layers[id]:                        # same as above
    for weight_name in grouped_weight_names[prune_layer_name]:    # same as above
      pass
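
The `pass` bodies above are where the mask is actually applied. A hedged sketch of what each step typically does inside the loop, assuming `prune_mask` is the BoolTensor for the current BN group and `weights` maps parameter names to their tensors (both names are ours, not from the code):

    keep = prune_mask.nonzero(as_tuple=True)[0]

    # 1. source BN layers: weight, bias, running_mean, running_var are all per-channel -> slice dim 0
    for name in norm_layer_name:
        for weight_name in grouped_weight_names[name]:
            weights[weight_name] = weights[weight_name][keep]

    # 2. succeeding conv/linear layers: prune input channels -> slice dim 1 of the weight matrix
    #    (biases are per output channel and stay untouched; depthwise convs need separate handling)
    for prune_layer_name in succ_layers[id]:
        for weight_name in grouped_weight_names[prune_layer_name]:
            if weights[weight_name].dim() > 1:
                weights[weight_name] = weights[weight_name][:, keep]

    # 3. preceding conv/linear layers: prune output channels -> slice dim 0 (weight and bias)
    for prune_layer_name in prec_layers[id]:
        for weight_name in grouped_weight_names[prune_layer_name]:
            weights[weight_name] = weights[weight_name][keep]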