(beta) Channels Last Memory Format in PyTorch

Source: https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html

Author: Vitaly Fedyunin

What is Channels Last

The channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (i.e. the image is stored pixel by pixel).

For example, the classic (contiguous) storage of an NCHW tensor (in our case, two 2x2 images with 3 color channels) looks like this:

[Figure: classic_memory_format]

The channels last memory format orders the data differently:

[Figure: channels_last_memory_format]
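
The figures above can be reproduced programmatically. The following minimal sketch (the tensor values are arbitrary, chosen only for illustration) builds the two 2x2, 3-channel images from the example, converts them to channels last, and prints the physical storage order to show that channels become the fastest-varying dimension:

    import torch

    # Two 2x2 images with 3 color channels, as in the figures above
    x = torch.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)
    cl = x.contiguous(memory_format=torch.channels_last)

    # The logical values are unchanged by the conversion ...
    print(torch.equal(x, cl))  # True

    # ... but the physical order in memory differs:
    print(x.flatten().tolist())                       # NCHW order: all of channel 0, then channel 1, ...
    print(cl.permute(0, 2, 3, 1).flatten().tolist())  # storage order of cl: channel values of pixel 0, then pixel 1, ...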

PyTorch supports memory formats (and provides backward compatibility with existing models, including eager, JIT, and TorchScript) by utilizing the existing strides structure. For example, a 10x3x16x16 batch in channels last format will have strides equal to (768, 1, 48, 3).
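
This claim can be verified with a short sketch using the conversion API described in the next section:

    import torch

    x = torch.empty(10, 3, 16, 16).contiguous(memory_format=torch.channels_last)
    print(x.stride())  # (768, 1, 48, 3), i.e. (H*W*C, 1, W*C, C)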

The channels last memory format is implemented for 4D NCHW tensors only.

    import torch
    N, C, H, W = 10, 3, 32, 32

Memory format API

Here is how to convert tensors between the contiguous and channels last memory formats.

Classic PyTorch contiguous tensor

    x = torch.empty(N, C, H, W)
    print(x.stride())  # Outputs: (3072, 1024, 32, 1)

Out:

    (3072, 1024, 32, 1)

Conversion operator

    x = x.contiguous(memory_format=torch.channels_last)
    print(x.shape)    # Outputs: (10, 3, 32, 32), as the dimension order is preserved
    print(x.stride()) # Outputs: (3072, 1, 96, 3)

Out:

    torch.Size([10, 3, 32, 32])
    (3072, 1, 96, 3)

Back to contiguous

    x = x.contiguous(memory_format=torch.contiguous_format)
    print(x.stride())  # Outputs: (3072, 1024, 32, 1)

Out:

    (3072, 1024, 32, 1)

Alternative option

    x = x.to(memory_format=torch.channels_last)
    print(x.stride())  # Outputs: (3072, 1, 96, 3)

Out:

    (3072, 1, 96, 3)

Format checks

    print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True

Out:

    True

Create as channels last

    x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
    print(x.stride())  # Outputs: (3072, 1, 96, 3)

Out:

    (3072, 1, 96, 3)

clone preserves memory format

    y = x.clone()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)

Out:

    (3072, 1, 96, 3)

to, cuda, float … preserve memory format

    if torch.cuda.is_available():
        y = x.cuda()
        print(y.stride())  # Outputs: (3072, 1, 96, 3)

Out:

    (3072, 1, 96, 3)

empty_like, *_like operators preserve memory format

    y = torch.empty_like(x)
    print(y.stride())  # Outputs: (3072, 1, 96, 3)

Out:

    (3072, 1, 96, 3)

Pointwise operators preserve memory format

    z = x + y
    print(z.stride())  # Outputs: (3072, 1, 96, 3)

Out:

    (3072, 1, 96, 3)

Conv, Batchnorm modules support channels last (only works for cuDNN >= 7.6)

    if torch.backends.cudnn.version() >= 7603:
        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
        model = torch.nn.Conv2d(8, 4, 3).cuda().float()
        input = input.contiguous(memory_format=torch.channels_last)
        model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last
        out = model(input)
        print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True

Out:

    True
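
The heading above also mentions Batchnorm; here is a minimal sketch of the same check for torch.nn.BatchNorm2d (the layer sizes are arbitrary, and the expected result assumes the cuDNN path propagates the format just like Conv2d above):

    if torch.backends.cudnn.version() >= 7603:
        bn = torch.nn.BatchNorm2d(8).cuda().float()
        x = torch.randn(2, 8, 4, 4, device="cuda").contiguous(memory_format=torch.channels_last)
        print(bn(x).is_contiguous(memory_format=torch.channels_last))  # Expected: True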

Performance gains

The most significant performance gains are observed on NVIDIA hardware with Tensor Cores support. We were able to achieve over 22% performance gains while running the AMP (Automated Mixed Precision) training scripts supplied by NVIDIA.

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data

    # opt_level = O2
    # keep_batchnorm_fp32 = None <class 'NoneType'>
    # loss_scale = None <class 'NoneType'>
    # CUDNN VERSION: 7603
    # => creating model 'resnet50'
    # Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
    # Defaults for this optimization level are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    # Processing user overrides (additional kwargs that are not None)...
    # After processing overrides, optimization options are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    # Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
    # Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
    # Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
    # Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
    # Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
    # Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
    # Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
    # Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

Passing --channels-last true allows running the model in the channels last format, with an observed 22% performance gain.

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data

    # opt_level = O2
    # keep_batchnorm_fp32 = None <class 'NoneType'>
    # loss_scale = None <class 'NoneType'>
    #
    # CUDNN VERSION: 7603
    #
    # => creating model 'resnet50'
    # Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
    #
    # Defaults for this optimization level are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    # Processing user overrides (additional kwargs that are not None)...
    # After processing overrides, optimization options are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    #
    # Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
    # Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
    # Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
    # Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
    # Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
    # Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
    # Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
    # Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

The following list of models fully supports channels last and shows 8%-35% performance gains on Volta devices: alexnet, mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3, mobilenet_v2, resnet101, resnet152, resnet18, resnet34, resnet50, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2.

Converting existing models

Channels last support is not limited to existing models: any model can be converted to channels last and will propagate the format through the graph as soon as the input is formatted correctly.

    # Need to be done once, after model initialization (or load)
    model = model.to(memory_format=torch.channels_last)  # Replace with your model

    # Need to be done for every input
    input = input.to(memory_format=torch.channels_last)  # Replace with your input
    output = model(input)
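
Putting this conversion together with mixed precision, as in the benchmark above, might look like the following sketch. It uses the native torch.cuda.amp API (available in PyTorch 1.6+) rather than the apex script from the benchmark, and the model, batch size, and optimizer settings are placeholders chosen only for illustration:

    import torch
    import torchvision

    # Assumes a CUDA device and torchvision are available
    model = torchvision.models.resnet50().cuda().to(memory_format=torch.channels_last)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()

    # Dummy batch standing in for a real data loader; inputs are converted per batch
    images = torch.randn(8, 3, 224, 224, device="cuda").contiguous(memory_format=torch.channels_last)
    labels = torch.randint(0, 1000, (8,), device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(images)
        loss = criterion(output, labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()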

However, not all operators have been fully converted to support channels last (they usually return contiguous output instead). That means you need to verify the list of operators you use against the list of supported operators, or introduce memory format checks into eager execution mode and run your model.

After running the code below, operators will raise an exception if the output of the operator doesn't match the memory format of the input.

    # Returns True if any tensor in args (searched recursively) is in channels_last format
    def contains_cl(args):
        for t in args:
            if isinstance(t, torch.Tensor):
                if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                    return True
            elif isinstance(t, list) or isinstance(t, tuple):
                if contains_cl(list(t)):
                    return True
        return False

    # Prints the strides, shapes, devices and dtypes of the (possibly nested) arguments
    def print_inputs(args, indent=''):
        for t in args:
            if isinstance(t, torch.Tensor):
                print(indent, t.stride(), t.shape, t.device, t.dtype)
            elif isinstance(t, list) or isinstance(t, tuple):
                print(indent, type(t))
                print_inputs(list(t), indent=indent + ' ')
            else:
                print(indent, t)

    # Wraps an operator so it raises if a channels_last input yields a 4D output
    # that is no longer channels_last
    def check_wrapper(fn):
        name = fn.__name__

        def check_cl(*args, **kwargs):
            was_cl = contains_cl(args)
            try:
                result = fn(*args, **kwargs)
            except Exception as e:
                print("`{}` inputs are:".format(name))
                print_inputs(args)
                print('-------------------')
                raise e
            failed = False
            if was_cl:
                if isinstance(result, torch.Tensor):
                    if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                        print("`{}` got channels_last input, but output is not channels_last:".format(name),
                              result.shape, result.stride(), result.device, result.dtype)
                        failed = True
            if failed and True:
                print("`{}` inputs are:".format(name))
                print_inputs(args)
                raise Exception(
                    'Operator `{}` lost channels_last property'.format(name))
            return result
        return check_cl

    old_attrs = dict()

    # Monkey-patches every public callable attribute of m with the checking wrapper,
    # remembering the originals in old_attrs so they can be restored later
    def attribute(m):
        old_attrs[m] = dict()
        for i in dir(m):
            e = getattr(m, i)
            exclude_functions = ['is_cuda', 'has_names', 'numel',
                                 'stride', 'Tensor', 'is_contiguous', '__class__']
            if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e):
                try:
                    old_attrs[m][i] = e
                    setattr(m, i, check_wrapper(e))
                except Exception as e:
                    print(i)
                    print(e)

    attribute(torch.Tensor)
    attribute(torch.nn.functional)
    attribute(torch)

Out:

    Optional
    '_Optional' object has no attribute '__name__'

If you find an operator that doesn't support channels last tensors and you want to contribute, feel free to use the PyTorch developer guide on writing memory-format-aware operators.

The code below restores the original attributes of torch.

    for (m, attrs) in old_attrs.items():
        for (k, v) in attrs.items():
            setattr(m, k, v)

Work to do

There are still many things to do, such as:

  • resolving the ambiguity of N1HW and NC11 tensors;
  • testing distributed training support;
  • improving operator coverage.

If you have feedback and/or suggestions for improvement, please let us know by creating an issue.

Total running time of the script: (0 minutes 2.300 seconds)

Download Python source code: memory_format_tutorial.py

Download Jupyter notebook: memory_format_tutorial.ipynb

Gallery generated by Sphinx-Gallery