模型剪裁教程

原文:https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

作者Michela Paganini

最新的深度学习技术依赖于难以部署的过度参数化模型。 相反,已知生物神经网络使用有效的稀疏连通性。 为了减少内存,电池和硬件消耗,同时又不牺牲精度,在设备上部署轻量级模型并通过私有设备上计算来确保私密性,确定通过减少模型中的参数数量来压缩模型的最佳技术很重要。 在研究方面,剪裁用于研究参数过度配置和参数不足网络在学习动态方面的差异,以研究幸运的稀疏子网络的作用(“彩票”),以及初始化,作为破坏性的神经结构搜索技术等等。

在本教程中,您将学习如何使用torch.nn.utils.prune稀疏神经网络,以及如何扩展它以实现自己的自定义剪裁技术。

要求

"torch>=1.4.0a0+8e8a5e0"

  1. import torch
  2. from torch import nn
  3. import torch.nn.utils.prune as prune
  4. import torch.nn.functional as F

创建模型

在本教程中,我们使用 LeCun 等人,1998 年的 LeNet 架构。

  1. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  2. class LeNet(nn.Module):
  3. def __init__(self):
  4. super(LeNet, self).__init__()
  5. # 1 input image channel, 6 output channels, 3x3 square conv kernel
  6. self.conv1 = nn.Conv2d(1, 6, 3)
  7. self.conv2 = nn.Conv2d(6, 16, 3)
  8. self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
  9. self.fc2 = nn.Linear(120, 84)
  10. self.fc3 = nn.Linear(84, 10)
  11. def forward(self, x):
  12. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  13. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  14. x = x.view(-1, int(x.nelement() / x.shape[0]))
  15. x = F.relu(self.fc1(x))
  16. x = F.relu(self.fc2(x))
  17. x = self.fc3(x)
  18. return x
  19. model = LeNet().to(device=device)

检查模块

让我们检查一下 LeNet 模型中的(未剪裁)conv1层。 现在它将包含两个参数weightbias,并且没有缓冲区。

  1. module = model.conv1
  2. print(list(module.named_parameters()))

出:

  1. [('weight', Parameter containing:
  2. tensor([[[[ 0.1552, 0.0102, -0.1944],
  3. [ 0.0263, 0.1374, -0.3139],
  4. [ 0.2838, 0.1943, 0.0948]]],
  5. [[[-0.0296, -0.2514, 0.1300],
  6. [ 0.0756, -0.3155, -0.2900],
  7. [-0.1840, 0.1143, -0.0120]]],
  8. [[[-0.2383, -0.3022, 0.2295],
  9. [-0.0050, 0.2485, -0.3230],
  10. [-0.1317, -0.0054, 0.2659]]],
  11. [[[-0.0932, 0.1316, 0.0670],
  12. [ 0.0572, -0.1845, 0.0870],
  13. [ 0.1372, 0.1080, 0.0324]]],
  14. [[[ 0.0908, -0.3280, 0.0365],
  15. [-0.3108, 0.2317, -0.2271],
  16. [ 0.1171, 0.2113, -0.2259]]],
  17. [[[ 0.0407, 0.0512, 0.0954],
  18. [-0.0437, 0.0302, -0.1317],
  19. [ 0.2573, 0.0626, 0.0883]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
  20. tensor([-0.1803, 0.1331, -0.3267, 0.3173, -0.0349, 0.1828], device='cuda:0',
  21. requires_grad=True))]
  1. print(list(module.named_buffers()))

出:

  1. []

剪裁模块

要剪裁模块(在此示例中,为 LeNet 架构的conv1层),请首先从torch.nn.utils.prune中可用的那些技术中选择一种剪裁技术(或通过子类化BasePruningMethod实现您自己的东西)。 然后,指定模块和该模块中要剪裁的参数的名称。 最后,使用所选剪裁技术所需的适当关键字参数,指定剪裁参数。

在此示例中,我们将在conv1层中名为weight的参数中随机剪裁 30% 的连接。 模块作为第一个参数传递给函数; name使用其字符串标识符在该模块中标识参数; amount表示与剪裁的连接百分比(如果是介于 0 和 1 之间的浮点数),或表示与剪裁的连接的绝对数量(如果它是非负整数)。

  1. prune.random_unstructured(module, name="weight", amount=0.3)

剪裁是通过从参数中删除weight并将其替换为名为weight_orig的新参数(即,将"_orig"附加到初始参数name)来进行的。 weight_orig存储未剪裁的张量版本。 bias未剪裁,因此它将保持完整。

  1. print(list(module.named_parameters()))

出:

  1. [('bias', Parameter containing:
  2. tensor([-0.1803, 0.1331, -0.3267, 0.3173, -0.0349, 0.1828], device='cuda:0',
  3. requires_grad=True)), ('weight_orig', Parameter containing:
  4. tensor([[[[ 0.1552, 0.0102, -0.1944],
  5. [ 0.0263, 0.1374, -0.3139],
  6. [ 0.2838, 0.1943, 0.0948]]],
  7. [[[-0.0296, -0.2514, 0.1300],
  8. [ 0.0756, -0.3155, -0.2900],
  9. [-0.1840, 0.1143, -0.0120]]],
  10. [[[-0.2383, -0.3022, 0.2295],
  11. [-0.0050, 0.2485, -0.3230],
  12. [-0.1317, -0.0054, 0.2659]]],
  13. [[[-0.0932, 0.1316, 0.0670],
  14. [ 0.0572, -0.1845, 0.0870],
  15. [ 0.1372, 0.1080, 0.0324]]],
  16. [[[ 0.0908, -0.3280, 0.0365],
  17. [-0.3108, 0.2317, -0.2271],
  18. [ 0.1171, 0.2113, -0.2259]]],
  19. [[[ 0.0407, 0.0512, 0.0954],
  20. [-0.0437, 0.0302, -0.1317],
  21. [ 0.2573, 0.0626, 0.0883]]]], device='cuda:0', requires_grad=True))]

通过以上选择的剪裁技术生成的剪裁掩码将保存为名为weight_mask的模块缓冲区(即,将"_mask"附加到初始参数name)。

  1. print(list(module.named_buffers()))

出:

  1. [('weight_mask', tensor([[[[1., 1., 0.],
  2. [0., 0., 1.],
  3. [1., 0., 1.]]],
  4. [[[1., 1., 1.],
  5. [1., 1., 1.],
  6. [1., 1., 1.]]],
  7. [[[1., 1., 0.],
  8. [1., 0., 0.],
  9. [1., 0., 1.]]],
  10. [[[1., 1., 1.],
  11. [1., 0., 1.],
  12. [1., 1., 1.]]],
  13. [[[1., 1., 1.],
  14. [0., 0., 1.],
  15. [1., 1., 1.]]],
  16. [[[1., 0., 0.],
  17. [1., 0., 1.],
  18. [1., 0., 0.]]]], device='cuda:0'))]

为了使正向传播不更改即可工作,需要存在weight属性。 在torch.nn.utils.prune中实现的剪裁技术计算权重的剪裁版本(通过将掩码与原始参数组合)并将它们存储在属性weight中。 注意,这不再是module的参数,现在只是一个属性。

  1. print(module.weight)

出:

  1. tensor([[[[ 0.1552, 0.0102, -0.0000],
  2. [ 0.0000, 0.0000, -0.3139],
  3. [ 0.2838, 0.0000, 0.0948]]],
  4. [[[-0.0296, -0.2514, 0.1300],
  5. [ 0.0756, -0.3155, -0.2900],
  6. [-0.1840, 0.1143, -0.0120]]],
  7. [[[-0.2383, -0.3022, 0.0000],
  8. [-0.0050, 0.0000, -0.0000],
  9. [-0.1317, -0.0000, 0.2659]]],
  10. [[[-0.0932, 0.1316, 0.0670],
  11. [ 0.0572, -0.0000, 0.0870],
  12. [ 0.1372, 0.1080, 0.0324]]],
  13. [[[ 0.0908, -0.3280, 0.0365],
  14. [-0.0000, 0.0000, -0.2271],
  15. [ 0.1171, 0.2113, -0.2259]]],
  16. [[[ 0.0407, 0.0000, 0.0000],
  17. [-0.0437, 0.0000, -0.1317],
  18. [ 0.2573, 0.0000, 0.0000]]]], device='cuda:0',
  19. grad_fn=<MulBackward0>)

最后,使用 PyTorch 的forward_pre_hooks在每次向前传递之前应用剪裁。 具体来说,当剪裁module时(如我们在此处所做的那样),它将为与之关联的每个参数获取forward_pre_hook进行剪裁。 在这种情况下,由于到目前为止我们只剪裁了名称为weight的原始参数,因此只会出现一个钩子。

  1. print(module._forward_pre_hooks)

出:

  1. OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fda78275e48>)])

为了完整起见,我们现在也可以剪裁bias,以查看module的参数,缓冲区,挂钩和属性如何变化。 仅出于尝试另一种剪裁技术的目的,在此我们按 L1 范数剪裁偏差中的 3 个最小条目,如l1_unstructured剪裁函数中所实现的。

  1. prune.l1_unstructured(module, name="bias", amount=3)

现在,我们希望命名参数同时包含weight_orig(从前)和bias_orig。 缓冲区将包括weight_maskbias_mask。 两个张量的剪裁后的版本将作为模块属性存在,并且该模块现在将具有两个forward_pre_hooks

  1. print(list(module.named_parameters()))

出:

  1. [('weight_orig', Parameter containing:
  2. tensor([[[[ 0.1552, 0.0102, -0.1944],
  3. [ 0.0263, 0.1374, -0.3139],
  4. [ 0.2838, 0.1943, 0.0948]]],
  5. [[[-0.0296, -0.2514, 0.1300],
  6. [ 0.0756, -0.3155, -0.2900],
  7. [-0.1840, 0.1143, -0.0120]]],
  8. [[[-0.2383, -0.3022, 0.2295],
  9. [-0.0050, 0.2485, -0.3230],
  10. [-0.1317, -0.0054, 0.2659]]],
  11. [[[-0.0932, 0.1316, 0.0670],
  12. [ 0.0572, -0.1845, 0.0870],
  13. [ 0.1372, 0.1080, 0.0324]]],
  14. [[[ 0.0908, -0.3280, 0.0365],
  15. [-0.3108, 0.2317, -0.2271],
  16. [ 0.1171, 0.2113, -0.2259]]],
  17. [[[ 0.0407, 0.0512, 0.0954],
  18. [-0.0437, 0.0302, -0.1317],
  19. [ 0.2573, 0.0626, 0.0883]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
  20. tensor([-0.1803, 0.1331, -0.3267, 0.3173, -0.0349, 0.1828], device='cuda:0',
  21. requires_grad=True))]
  1. print(list(module.named_buffers()))

出:

  1. [('weight_mask', tensor([[[[1., 1., 0.],
  2. [0., 0., 1.],
  3. [1., 0., 1.]]],
  4. [[[1., 1., 1.],
  5. [1., 1., 1.],
  6. [1., 1., 1.]]],
  7. [[[1., 1., 0.],
  8. [1., 0., 0.],
  9. [1., 0., 1.]]],
  10. [[[1., 1., 1.],
  11. [1., 0., 1.],
  12. [1., 1., 1.]]],
  13. [[[1., 1., 1.],
  14. [0., 0., 1.],
  15. [1., 1., 1.]]],
  16. [[[1., 0., 0.],
  17. [1., 0., 1.],
  18. [1., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
  1. print(module.bias)

出:

  1. tensor([-0.0000, 0.0000, -0.3267, 0.3173, -0.0000, 0.1828], device='cuda:0',
  2. grad_fn=<MulBackward0>)
  1. print(module._forward_pre_hooks)

出:

  1. OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fda78275e48>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fda80bbe470>)])

迭代式剪裁

一个模块中的同一参数可以被多次剪裁,各种剪裁调用的效果等于连接应用的各种蒙版的组合。 PruningContainercompute_mask方法可处理新遮罩与旧遮罩的组合。

例如,假设我们现在想进一步剪裁module.weight,这一次是使用沿着张量的第 0 轴的结构化剪裁(第 0 轴对应于卷积层的输出通道,并且对于conv1具有 6 维) ,基于渠道的 L2 规范。 这可以通过ln_structuredn=2dim=0函数来实现。

  1. prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
  2. # As we can verify, this will zero out all the connections corresponding to
  3. # 50% (3 out of 6) of the channels, while preserving the action of the
  4. # previous mask.
  5. print(module.weight)

出:

  1. tensor([[[[ 0.0000, 0.0000, -0.0000],
  2. [ 0.0000, 0.0000, -0.0000],
  3. [ 0.0000, 0.0000, 0.0000]]],
  4. [[[-0.0296, -0.2514, 0.1300],
  5. [ 0.0756, -0.3155, -0.2900],
  6. [-0.1840, 0.1143, -0.0120]]],
  7. [[[-0.2383, -0.3022, 0.0000],
  8. [-0.0050, 0.0000, -0.0000],
  9. [-0.1317, -0.0000, 0.2659]]],
  10. [[[-0.0000, 0.0000, 0.0000],
  11. [ 0.0000, -0.0000, 0.0000],
  12. [ 0.0000, 0.0000, 0.0000]]],
  13. [[[ 0.0908, -0.3280, 0.0365],
  14. [-0.0000, 0.0000, -0.2271],
  15. [ 0.1171, 0.2113, -0.2259]]],
  16. [[[ 0.0000, 0.0000, 0.0000],
  17. [-0.0000, 0.0000, -0.0000],
  18. [ 0.0000, 0.0000, 0.0000]]]], device='cuda:0',
  19. grad_fn=<MulBackward0>)

现在,对应的钩子将为torch.nn.utils.prune.PruningContainer类型,并将存储应用于weight参数的剪裁历史。

  1. for hook in module._forward_pre_hooks.values():
  2. if hook._tensor_name == "weight": # select out the correct hook
  3. break
  4. print(list(hook)) # pruning history in the container

出:

  1. [<torch.nn.utils.prune.RandomUnstructured object at 0x7fda78275e48>, <torch.nn.utils.prune.LnStructured object at 0x7fda80071828>]

序列化剪裁的模型

所有相关的张量,包括掩码缓冲区和用于计算剪裁的张量的原始参数,都存储在模型的state_dict中,因此可以根据需要轻松地序列化和保存。

  1. print(model.state_dict().keys())

出:

  1. odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

删除剪裁重新参数化

要使剪裁永久化,请删除weight_origweight_mask的重新参数化,然后删除forward_pre_hook,我们可以使用torch.nn.utils.pruneremove函数。 请注意,这不会撤消剪裁,好像从未发生过。 而是通过将参数weight重新分配给模型参数(剪裁后的版本)来使其永久不变。

删除重新参数化之前:

  1. print(list(module.named_parameters()))

出:

  1. [('weight_orig', Parameter containing:
  2. tensor([[[[ 0.1552, 0.0102, -0.1944],
  3. [ 0.0263, 0.1374, -0.3139],
  4. [ 0.2838, 0.1943, 0.0948]]],
  5. [[[-0.0296, -0.2514, 0.1300],
  6. [ 0.0756, -0.3155, -0.2900],
  7. [-0.1840, 0.1143, -0.0120]]],
  8. [[[-0.2383, -0.3022, 0.2295],
  9. [-0.0050, 0.2485, -0.3230],
  10. [-0.1317, -0.0054, 0.2659]]],
  11. [[[-0.0932, 0.1316, 0.0670],
  12. [ 0.0572, -0.1845, 0.0870],
  13. [ 0.1372, 0.1080, 0.0324]]],
  14. [[[ 0.0908, -0.3280, 0.0365],
  15. [-0.3108, 0.2317, -0.2271],
  16. [ 0.1171, 0.2113, -0.2259]]],
  17. [[[ 0.0407, 0.0512, 0.0954],
  18. [-0.0437, 0.0302, -0.1317],
  19. [ 0.2573, 0.0626, 0.0883]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
  20. tensor([-0.1803, 0.1331, -0.3267, 0.3173, -0.0349, 0.1828], device='cuda:0',
  21. requires_grad=True))]
  1. print(list(module.named_buffers()))

出:

  1. [('weight_mask', tensor([[[[0., 0., 0.],
  2. [0., 0., 0.],
  3. [0., 0., 0.]]],
  4. [[[1., 1., 1.],
  5. [1., 1., 1.],
  6. [1., 1., 1.]]],
  7. [[[1., 1., 0.],
  8. [1., 0., 0.],
  9. [1., 0., 1.]]],
  10. [[[0., 0., 0.],
  11. [0., 0., 0.],
  12. [0., 0., 0.]]],
  13. [[[1., 1., 1.],
  14. [0., 0., 1.],
  15. [1., 1., 1.]]],
  16. [[[0., 0., 0.],
  17. [0., 0., 0.],
  18. [0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
  1. print(module.weight)

出:

  1. tensor([[[[ 0.0000, 0.0000, -0.0000],
  2. [ 0.0000, 0.0000, -0.0000],
  3. [ 0.0000, 0.0000, 0.0000]]],
  4. [[[-0.0296, -0.2514, 0.1300],
  5. [ 0.0756, -0.3155, -0.2900],
  6. [-0.1840, 0.1143, -0.0120]]],
  7. [[[-0.2383, -0.3022, 0.0000],
  8. [-0.0050, 0.0000, -0.0000],
  9. [-0.1317, -0.0000, 0.2659]]],
  10. [[[-0.0000, 0.0000, 0.0000],
  11. [ 0.0000, -0.0000, 0.0000],
  12. [ 0.0000, 0.0000, 0.0000]]],
  13. [[[ 0.0908, -0.3280, 0.0365],
  14. [-0.0000, 0.0000, -0.2271],
  15. [ 0.1171, 0.2113, -0.2259]]],
  16. [[[ 0.0000, 0.0000, 0.0000],
  17. [-0.0000, 0.0000, -0.0000],
  18. [ 0.0000, 0.0000, 0.0000]]]], device='cuda:0',
  19. grad_fn=<MulBackward0>)

删除重新参数化后:

  1. prune.remove(module, 'weight')
  2. print(list(module.named_parameters()))

出:

  1. [('bias_orig', Parameter containing:
  2. tensor([-0.1803, 0.1331, -0.3267, 0.3173, -0.0349, 0.1828], device='cuda:0',
  3. requires_grad=True)), ('weight', Parameter containing:
  4. tensor([[[[ 0.0000, 0.0000, -0.0000],
  5. [ 0.0000, 0.0000, -0.0000],
  6. [ 0.0000, 0.0000, 0.0000]]],
  7. [[[-0.0296, -0.2514, 0.1300],
  8. [ 0.0756, -0.3155, -0.2900],
  9. [-0.1840, 0.1143, -0.0120]]],
  10. [[[-0.2383, -0.3022, 0.0000],
  11. [-0.0050, 0.0000, -0.0000],
  12. [-0.1317, -0.0000, 0.2659]]],
  13. [[[-0.0000, 0.0000, 0.0000],
  14. [ 0.0000, -0.0000, 0.0000],
  15. [ 0.0000, 0.0000, 0.0000]]],
  16. [[[ 0.0908, -0.3280, 0.0365],
  17. [-0.0000, 0.0000, -0.2271],
  18. [ 0.1171, 0.2113, -0.2259]]],
  19. [[[ 0.0000, 0.0000, 0.0000],
  20. [-0.0000, 0.0000, -0.0000],
  21. [ 0.0000, 0.0000, 0.0000]]]], device='cuda:0', requires_grad=True))]
  1. print(list(module.named_buffers()))

出:

  1. [('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

剪裁模型中的多个参数

通过指定所需的剪裁技术和参数,我们可以轻松地剪裁网络中的多个张量,也许根据它们的类型,如在本示例中将看到的那样。

  1. new_model = LeNet()
  2. for name, module in new_model.named_modules():
  3. # prune 20% of connections in all 2D-conv layers
  4. if isinstance(module, torch.nn.Conv2d):
  5. prune.l1_unstructured(module, name='weight', amount=0.2)
  6. # prune 40% of connections in all linear layers
  7. elif isinstance(module, torch.nn.Linear):
  8. prune.l1_unstructured(module, name='weight', amount=0.4)
  9. print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist

出:

  1. dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全局剪裁

到目前为止,我们仅查看了通常称为“局部”剪裁的情况,即通过比较每个条目的统计信息(权重,激活度,梯度等)来逐个剪裁模型中的张量的做法。 到该张量中的其他条目。 但是,一种通用且可能更强大的技术是通过删除(例如)删除整个模型中最低的 20% 的连接,而不是删除每一层中最低的 20% 的连接来一次剪裁模型。 这很可能导致每个层的剪裁百分比不同。 让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作。

  1. model = LeNet()
  2. parameters_to_prune = (
  3. (model.conv1, 'weight'),
  4. (model.conv2, 'weight'),
  5. (model.fc1, 'weight'),
  6. (model.fc2, 'weight'),
  7. (model.fc3, 'weight'),
  8. )
  9. prune.global_unstructured(
  10. parameters_to_prune,
  11. pruning_method=prune.L1Unstructured,
  12. amount=0.2,
  13. )

现在,我们可以检查每个剪裁参数中引起的稀疏性,该稀疏性将不等于每层中的 20%。 但是,全局稀疏度将(大约)为 20%。

  1. print(
  2. "Sparsity in conv1.weight: {:.2f}%".format(
  3. 100\. * float(torch.sum(model.conv1.weight == 0))
  4. / float(model.conv1.weight.nelement())
  5. )
  6. )
  7. print(
  8. "Sparsity in conv2.weight: {:.2f}%".format(
  9. 100\. * float(torch.sum(model.conv2.weight == 0))
  10. / float(model.conv2.weight.nelement())
  11. )
  12. )
  13. print(
  14. "Sparsity in fc1.weight: {:.2f}%".format(
  15. 100\. * float(torch.sum(model.fc1.weight == 0))
  16. / float(model.fc1.weight.nelement())
  17. )
  18. )
  19. print(
  20. "Sparsity in fc2.weight: {:.2f}%".format(
  21. 100\. * float(torch.sum(model.fc2.weight == 0))
  22. / float(model.fc2.weight.nelement())
  23. )
  24. )
  25. print(
  26. "Sparsity in fc3.weight: {:.2f}%".format(
  27. 100\. * float(torch.sum(model.fc3.weight == 0))
  28. / float(model.fc3.weight.nelement())
  29. )
  30. )
  31. print(
  32. "Global sparsity: {:.2f}%".format(
  33. 100\. * float(
  34. torch.sum(model.conv1.weight == 0)
  35. + torch.sum(model.conv2.weight == 0)
  36. + torch.sum(model.fc1.weight == 0)
  37. + torch.sum(model.fc2.weight == 0)
  38. + torch.sum(model.fc3.weight == 0)
  39. )
  40. / float(
  41. model.conv1.weight.nelement()
  42. + model.conv2.weight.nelement()
  43. + model.fc1.weight.nelement()
  44. + model.fc2.weight.nelement()
  45. + model.fc3.weight.nelement()
  46. )
  47. )
  48. )

出:

  1. Sparsity in conv1.weight: 3.70%
  2. Sparsity in conv2.weight: 8.10%
  3. Sparsity in fc1.weight: 22.05%
  4. Sparsity in fc2.weight: 12.29%
  5. Sparsity in fc3.weight: 8.45%
  6. Global sparsity: 20.00%

使用自定义剪裁函数扩展torch.nn.utils.prune

要实现自己的剪裁函数,可以通过继承BasePruningMethod基类的子类来扩展nn.utils.prune模块,这与所有其他剪裁方法一样。 基类为您实现以下方法:__call__apply_maskapplypruneremove。 除了一些特殊情况外,您无需为新的剪裁技术重新实现这些方法。 但是,您将必须实现__init__(构造器)和compute_mask(有关如何根据剪裁技术的逻辑为给定张量计算掩码的说明)。 另外,您将必须指定此技术实现的剪裁类型(支持的选项为globalstructuredunstructured)。 需要确定在迭代应用剪裁的情况下如何组合蒙版。 换句话说,当剪裁预剪裁的参数时,当前的剪裁技术应作用于参数的未剪裁部分。 指定PRUNING_TYPE将使PruningContainer(处理剪裁掩码的迭代应用)正确识别要剪裁的参数。

例如,假设您要实现一种剪裁技术,以剪裁张量中的所有其他条目(或者-如果先前已剪裁过张量,则剪裁张量的其余未剪裁部分)。 这将是PRUNING_TYPE='unstructured',因为它作用于层中的单个连接,而不作用于整个单元/通道('structured'),或作用于不同的参数('global')。

  1. class FooBarPruningMethod(prune.BasePruningMethod):
  2. """Prune every other entry in a tensor
  3. """
  4. PRUNING_TYPE = 'unstructured'
  5. def compute_mask(self, t, default_mask):
  6. mask = default_mask.clone()
  7. mask.view(-1)[::2] = 0
  8. return mask

现在,要将其应用于nn.Module中的参数,还应该提供一个简单的函数来实例化该方法并将其应用。

  1. def foobar_unstructured(module, name):
  2. """Prunes tensor corresponding to parameter called `name` in `module`
  3. by removing every other entry in the tensors.
  4. Modifies module in place (and also return the modified module)
  5. by:
  6. 1) adding a named buffer called `name+'_mask'` corresponding to the
  7. binary mask applied to the parameter `name` by the pruning method.
  8. The parameter `name` is replaced by its pruned version, while the
  9. original (unpruned) parameter is stored in a new parameter named
  10. `name+'_orig'`.
  11. Args:
  12. module (nn.Module): module containing the tensor to prune
  13. name (string): parameter name within `module` on which pruning
  14. will act.
  15. Returns:
  16. module (nn.Module): modified (i.e. pruned) version of the input
  17. module
  18. Examples:
  19. >>> m = nn.Linear(3, 4)
  20. >>> foobar_unstructured(m, name='bias')
  21. """
  22. FooBarPruningMethod.apply(module, name)
  23. return module

试试吧!

  1. model = LeNet()
  2. foobar_unstructured(model.fc3, name='bias')
  3. print(model.fc3.bias_mask)

出:

  1. tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

脚本的总运行时间:(0 分钟 0.135 秒)

下载 Python 源码:pruning_tutorial.py

下载 Jupyter 笔记本:pruning_tutorial.ipynb

由 Sphinx 画廊生成的画廊