参考来源:
知乎:一文学会 Pytorch 中的 einsum
实例代码:https://github.com/Ldpe2G/CodingForFun/tree/master/einsum_ex

爱因斯坦求和约定

爱因斯坦求和约定(einsum)提供了一套既简洁又优雅的规则,可实现包括但不限于:向量内积、向量外积、矩阵乘法、转置和张量收缩(tensor contraction)等张量操作,熟练运用 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。

三条基本规则

首先看下 einsum 实现矩阵乘法的例子:

  1. a = torch.rand(2,3)
  2. b = torch.rand(3,4)
  3. c = torch.einsum("ik,kj->ij", [a, b])
  4. # 等价操作 torch.mm(a, b)

其中需要重点关注的是 einsum 的第一个参数 "ik,kj->ij",该字符串(下文以 equation 表示)表示了输入和输出张量的维度。equation 中的箭头左边表示输入张量,以逗号分割每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是 26 个英文字母 ‘a‘ - ‘z‘。
einsum 的第二个参数表示实际的输入张量列表,其数量要与 equation 中的输入数量对应。同时对应每个张量的子 equation 的字符个数要与张量的真实维度对应,比如 "ik,kj->ij" 表示输入和输出张量都是两维的。
equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k]b[i, k] 沿着 k 这个维度做内积得到的。

接着介绍两个基本概念,自由索引**Free indices**)和求和索引**Summation indices**):

  • 自由索引:出现在箭头右边的索引,比如上面的例子就是 ij
  • 求和索引:只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k

接着是介绍三条基本规则:

  1. 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij"k 在输入中重复出现,所以就是把 ab 沿着 k 这个维度作相乘操作;
  2. 规则二,只出现在 **equation** 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  3. 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

特殊规则

特殊规则有两条:

  1. equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  2. equation 中支持 "..." 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:
    1. a = torch.randn(2,3,5,7,9)
    2. # i = 7, j = 9
    3. b = torch.einsum('...ij->...ji', [a])

实际例子解读

接下来将展示 13 个具体的例子,在这些例子中会将 Pytorch einsum 与对应的 Pytorch 张量接口和 python 简单的循环展开实现做对比,希望读者看完这些例子之后能轻松掌握 einsum 的基本用法。
实验代码 github 链接:https://github.com/Ldpe2G/CodingForFun/tree/master/einsum_ex

1. 提取矩阵对角线元素

  1. import torch
  2. import numpy as np
  3. a = torch.arange(9).reshape(3, 3)
  4. # i = 3
  5. torch_ein_out = torch.einsum('ii->i', [a]).numpy()
  6. torch_org_out = torch.diagonal(a, 0).numpy()
  7. np_a = a.numpy()
  8. # 循环展开实现
  9. np_out = np.empty((3,), dtype=np.int32)
  10. # 自由索引外循环
  11. for i in range(0, 3):
  12. # 求和索引内循环
  13. # 这个例子并没有求和索引,
  14. # 所以相当于是1
  15. sum_result = 0
  16. for inner in range(0, 1):
  17. sum_result += np_a[i, i]
  18. np_out[i] = sum_result
  19. print("input:\n", np_a)
  20. print("torch ein out: \n", torch_ein_out)
  21. print("torch org out: \n", torch_org_out)
  22. print("numpy out: \n", np_out)
  23. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  24. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
  25. # 终端打印结果
  26. # input:
  27. # [[0 1 2]
  28. # [3 4 5]
  29. # [6 7 8]]
  30. # torch ein out:
  31. # [0 4 8]
  32. # torch org out:
  33. # [0 4 8]
  34. # numpy out:
  35. # [0 4 8]
  36. # is np_out == torch_ein_out ? True
  37. # is torch_org_out == torch_ein_out ? True

2. 矩阵转置

  1. import torch
  2. import numpy as np
  3. a = torch.arange(6).reshape(2, 3)
  4. # i = 2, j = 3
  5. torch_ein_out = torch.einsum('ij->ji', [a]).numpy()
  6. torch_org_out = torch.transpose(a, 0, 1).numpy()
  7. np_a = a.numpy()
  8. # 循环展开实现
  9. np_out = np.empty((3, 2), dtype=np.int32)
  10. # 自由索引外循环
  11. for j in range(0, 3):
  12. for i in range(0, 2):
  13. # 求和索引内循环
  14. # 这个例子并没有求和索引
  15. # 所以相当于是1
  16. sum_result = 0
  17. for inner in range(0, 1):
  18. sum_result += np_a[i, j]
  19. np_out[j, i] = sum_result
  20. print("input:\n", np_a)
  21. print("torch ein out: \n", torch_ein_out)
  22. print("torch org out: \n", torch_org_out)
  23. print("numpy out: \n", np_out)
  24. print("is np_out == torch_org_out ?", np.allclose(torch_ein_out, np_out))
  25. print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out))
  26. # 终端打印结果
  27. # input:
  28. # [[0 1 2]
  29. # [3 4 5]]
  30. # torch ein out:
  31. # [[0 3]
  32. # [1 4]
  33. # [2 5]]
  34. # torch org out:
  35. # [[0 3]
  36. # [1 4]
  37. # [2 5]]
  38. # numpy out:
  39. # [[0 3]
  40. # [1 4]
  41. # [2 5]]
  42. # is np_out == torch_org_out ? True
  43. # is torch_ein_out == torch_org_out ? True

3. permute 高维张量转置

  1. import torch
  2. import numpy as np
  3. a = torch.randn(2,3,5,7,9)
  4. # i = 7, j = 9
  5. torch_ein_out = torch.einsum('...ij->...ji', [a]).numpy()
  6. torch_org_out = a.permute(0, 1, 2, 4, 3).numpy()
  7. np_a = a.numpy()
  8. # 循环展开实现
  9. np_out = np.empty((2,3,5,9,7), dtype=np.float32)
  10. # 自由索引外循环
  11. for j in range(0, 9):
  12. for i in range(0, 7):
  13. # 求和索引内循环
  14. # 这个例子没有求和索引
  15. sum_result = 0
  16. for inner in range(0, 1):
  17. sum_result += np_a[..., i, j]
  18. np_out[..., j, i] = sum_result
  19. print("is np_out == torch_org_out ?", np.allclose(torch_ein_out, np_out))
  20. print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out))
  21. # 终端打印结果
  22. # is np_out == torch_org_out ? True
  23. # is torch_ein_out == torch_org_out ? True

4. reduce sum

  1. import torch
  2. import numpy as np
  3. a = torch.arange(6).reshape(2, 3)
  4. # i = 2, j = 3
  5. torch_ein_out = torch.einsum('ij->', [a]).numpy()
  6. torch_org_out = torch.sum(a).numpy()
  7. np_a = a.numpy()
  8. # 循环展开实现
  9. np_out = np.empty((1, ), dtype=np.int32)
  10. # 自由索引外循环
  11. # 这个例子中没有自由索引
  12. # 相当于所有维度都加一起
  13. for o in range(0 ,1):
  14. # 求和索引内循环
  15. # 这个例子中,i 和 j
  16. # 都是求和索引
  17. sum_result = 0
  18. for i in range(0, 2):
  19. for j in range(0, 3):
  20. sum_result += np_a[i, j]
  21. np_out[o] = sum_result
  22. print("input:\n", np_a)
  23. print("torch ein out: \n", torch_ein_out)
  24. print("torch org out: \n", torch_org_out)
  25. print("numpy out: \n", np_out)
  26. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  27. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
  28. # 终端打印结果
  29. # input:
  30. # [[0 1 2]
  31. # [3 4 5]]
  32. # torch ein out:
  33. # 15
  34. # torch org out:
  35. # 15
  36. # numpy out:
  37. # [15]
  38. # is np_out == torch_ein_out ? True
  39. # is torch_org_out == torch_ein_out ? True

5. 矩阵按列求和

  1. import torch
  2. import numpy as np
  3. a = torch.arange(6).reshape(2, 3)
  4. # i = 2, j = 3
  5. torch_ein_out = torch.einsum('ij->j', [a]).numpy()
  6. torch_org_out = torch.sum(a, dim=0).numpy()
  7. np_a = a.numpy()
  8. # 循环展开实现
  9. np_out = np.empty((3, ), dtype=np.int32)
  10. # 自由索引外循环
  11. # 这个例子中是 j
  12. for j in range(0, 3):
  13. # 求和索引内循环
  14. # 这个例子中是 i
  15. sum_result = 0
  16. for i in range(0, 2):
  17. sum_result += np_a[i, j]
  18. np_out[j] = sum_result
  19. print("input:\n", np_a)
  20. print("torch ein out: \n", torch_ein_out)
  21. print("torch org out: \n", torch_org_out)
  22. print("numpy out: \n", np_out)
  23. print("is np_out == torch_ein_out ?", np.allclose(torch_org_out, np_out))
  24. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
  25. # 终端打印输出
  26. # input:
  27. # [[0 1 2]
  28. # [3 4 5]]
  29. # torch ein out:
  30. # [3 5 7]
  31. # torch org out:
  32. # [3 5 7]
  33. # numpy out:
  34. # [3 5 7]
  35. # is np_out == torch_ein_out ? True
  36. # is torch_org_out == torch_ein_out ? True

6. 矩阵向量乘法

  1. import torch
  2. import numpy as np
  3. a = torch.arange(6).reshape(2, 3)
  4. b = torch.arange(3)
  5. # i = 2, k = 3
  6. torch_ein_out = torch.einsum('ik,k->i', [a, b]).numpy()
  7. # 等价形式,可以省略箭头和输出
  8. torch_ein_out2 = torch.einsum('ik,k', [a, b]).numpy()
  9. torch_org_out = torch.mv(a, b).numpy()
  10. np_a = a.numpy()
  11. np_b = b.numpy()
  12. # 循环展开实现
  13. np_out = np.empty((2, ), dtype=np.int32)
  14. # 自由索引外循环
  15. # 这个例子是 i
  16. for i in range(0, 2):
  17. # 求和索引内循环
  18. # 这个例子中是 k
  19. sum_result = 0
  20. for k in range(0, 3):
  21. sum_result += np_a[i, k] * np_b[k]
  22. np_out[i] = sum_result
  23. print("matrix a:\n", np_a)
  24. print("vector b:\n", np_b)
  25. print("torch ein out: \n", torch_ein_out)
  26. print("torch ein out2: \n", torch_ein_out2)
  27. print("torch org out: \n", torch_org_out)
  28. print("numpy out: \n", np_out)
  29. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  30. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
  31. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
  32. # 终端打印输出
  33. # matrix a:
  34. # [[0 1 2]
  35. # [3 4 5]]
  36. # vector b:
  37. # [0 1 2]
  38. # torch ein out:
  39. # [ 5 14]
  40. # torch ein out2:
  41. # [ 5 14]
  42. # torch org out:
  43. # [ 5 14]
  44. # numpy out:
  45. # [ 5 14]
  46. # is np_out == torch_ein_out ? True
  47. # is torch_ein_out2 == torch_ein_out ? True
  48. # is torch_org_out == torch_ein_out ? True

7. 矩阵乘法

  1. import torch
  2. import numpy as np
  3. a = torch.arange(6).reshape(2, 3)
  4. b = torch.arange(15).reshape(3, 5)
  5. # i = 2, k = 3, j = 5
  6. torch_ein_out = torch.einsum('ik,kj->ij', [a, b]).numpy()
  7. # 等价形式,可以省略箭头和输出
  8. torch_ein_out2 = torch.einsum('ik,kj', [a, b]).numpy()
  9. torch_org_out = torch.mm(a, b).numpy()
  10. np_a = a.numpy()
  11. np_b = b.numpy()
  12. # 循环展开实现
  13. np_out = np.empty((2, 5), dtype=np.int32)
  14. # 自由索引外循环
  15. # 这个例子是 i 和 j
  16. for i in range(0, 2):
  17. for j in range(0, 5):
  18. # 求和索引内循环
  19. # 这个例子是 k
  20. sum_result = 0
  21. for k in range(0, 3):
  22. sum_result += np_a[i, k] * np_b[k, j]
  23. np_out[i, j] = sum_result
  24. print("matrix a:\n", np_a)
  25. print("matrix b:\n", np_b)
  26. print("torch ein out: \n", torch_ein_out)
  27. print("torch ein out2: \n", torch_ein_out2)
  28. print("torch org out: \n", torch_org_out)
  29. print("numpy out: \n", np_out)
  30. print("is numpy == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  31. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
  32. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
  33. # 终端打印输出
  34. # matrix a:
  35. # [[0 1 2]
  36. # [3 4 5]]
  37. # matrix b:
  38. # [[ 0 1 2 3 4]
  39. # [ 5 6 7 8 9]
  40. # [10 11 12 13 14]]
  41. # torch ein out:
  42. # [[ 25 28 31 34 37]
  43. # [ 70 82 94 106 118]]
  44. # torch ein out2:
  45. # [[ 25 28 31 34 37]
  46. # [ 70 82 94 106 118]]
  47. # torch org out:
  48. # [[ 25 28 31 34 37]
  49. # [ 70 82 94 106 118]]
  50. # numpy out:
  51. # [[ 25 28 31 34 37]
  52. # [ 70 82 94 106 118]]
  53. # is numpy == torch_ein_out ? True
  54. # is torch_ein_out2 == torch_ein_out ? True
  55. # is torch_org_out == torch_ein_out ? True

8. 向量内积

  1. import torch
  2. import numpy as np
  3. a = torch.arange(3)
  4. b = torch.arange(3, 6) # [3, 4, 5]
  5. # i = 3
  6. torch_ein_out = torch.einsum('i,i->', [a, b]).numpy()
  7. # 等价形式,可以省略箭头和输出
  8. torch_ein_out2 = torch.einsum('i,i', [a, b]).numpy()
  9. torch_org_out = torch.dot(a, b).numpy()
  10. np_a = a.numpy()
  11. np_b = b.numpy()
  12. # 循环展开实现
  13. np_out = np.empty((1, ), dtype=np.int32)
  14. # 自由索引外循环
  15. # 这个例子没有自由索引
  16. for o in range(0, 1):
  17. # 求和索引内循环
  18. # 这个例子是 i
  19. sum_result = 0
  20. for i in range(0, 3):
  21. sum_result += np_a[i] * np_b[i]
  22. np_out[o] = sum_result
  23. print("vector a:\n", np_a)
  24. print("vector b:\n", np_b)
  25. print("torch ein out: \n", torch_ein_out)
  26. print("torch ein out2: \n", torch_ein_out2)
  27. print("torch org out: \n", torch_org_out)
  28. print("numpy out: \n", np_out)
  29. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  30. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
  31. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
  32. # 终端打印输出
  33. # vector a:
  34. # [0 1 2]
  35. # vector b:
  36. # [3 4 5]
  37. # torch ein out:
  38. # 14
  39. # torch ein out2:
  40. # 14
  41. # torch org out:
  42. # 14
  43. # numpy out:
  44. # [14]
  45. # is np_out == torch_ein_out ? True
  46. # is torch_ein_out2 == torch_ein_out ? True
  47. # is torch_org_out == torch_ein_out ? True

9. 矩阵元素对应相乘并求 reduce sum

  1. import torch
  2. import numpy as np
  3. a = torch.arange(6).reshape(2, 3)
  4. b = torch.arange(6,12).reshape(2, 3)
  5. # i = 2, j = 3
  6. torch_ein_out = torch.einsum('ij,ij->', [a, b]).numpy()
  7. # 等价形式,可以省略箭头和输出
  8. torch_ein_out2 = torch.einsum('ij,ij', [a, b]).numpy()
  9. torch_org_out = (a * b).sum().numpy()
  10. np_a = a.numpy()
  11. np_b = b.numpy()
  12. # 循环展开实现
  13. np_out = np.empty((1, ), dtype=np.int32)
  14. # 自由索引外循环
  15. # 这个例子没有自由索引
  16. for o in range(0, 1):
  17. # 求和索引内循环
  18. # 这个例子是 i 和 j
  19. sum_result = 0
  20. for i in range(0, 2):
  21. for j in range(0, 3):
  22. sum_result += np_a[i,j] * np_b[i,j]
  23. np_out[o] = sum_result
  24. print("matrix a:\n", np_a)
  25. print("matrix b:\n", np_b)
  26. print("torch ein out: \n", torch_ein_out)
  27. print("torch ein out2: \n", torch_ein_out2)
  28. print("torch org out: \n", torch_org_out)
  29. print("numpy out: \n", np_out)
  30. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  31. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
  32. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
  33. # 终端打印输出
  34. # matrix a:
  35. # [[0 1 2]
  36. # [3 4 5]]
  37. # matrix b:
  38. # [[ 6 7 8]
  39. # [ 9 10 11]]
  40. # torch ein out:
  41. # 145
  42. # torch ein out2:
  43. # 145
  44. # torch org out:
  45. # 145
  46. # numpy out:
  47. # [145]
  48. # is np_out == torch_ein_out ? True
  49. # is torch_ein_out2 == torch_ein_out ? True
  50. # is torch_org_out == torch_ein_out ? True

10. 向量外积

  1. import torch
  2. import numpy as np
  3. a = torch.arange(3)
  4. b = torch.arange(3,7) # [3, 4, 5, 6]
  5. # i = 3, j = 4
  6. torch_ein_out = torch.einsum('i,j->ij', [a, b]).numpy()
  7. # 等价形式,可以省略箭头和输出
  8. torch_ein_out2 = torch.einsum('i,j', [a, b]).numpy()
  9. torch_org_out = torch.outer(a, b).numpy()
  10. np_a = a.numpy()
  11. np_b = b.numpy()
  12. # 循环展开实现
  13. np_out = np.empty((3, 4), dtype=np.int32)
  14. # 自由索引外循环
  15. # 这个例子是 i 和 j
  16. for i in range(0, 3):
  17. for j in range(0, 4):
  18. # 求和索引内循环
  19. # 这个例子没有求和索引
  20. sum_result = 0
  21. for inner in range(0, 1):
  22. sum_result += np_a[i] * np_b[j]
  23. np_out[i, j] = sum_result
  24. print("vector a:\n", np_a)
  25. print("vector b:\n", np_b)
  26. print("torch ein out: \n", torch_ein_out)
  27. print("torch ein out2: \n", torch_ein_out2)
  28. print("torch org out: \n", torch_org_out)
  29. print("numpy out: \n", np_out)
  30. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  31. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
  32. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
  33. # 终端打印输出
  34. # vector a:
  35. # [0 1 2]
  36. # vector b:
  37. # [3 4 5 6]
  38. # torch ein out:
  39. # [[ 0 0 0 0]
  40. # [ 3 4 5 6]
  41. # [ 6 8 10 12]]
  42. # torch ein out2:
  43. # [[ 0 0 0 0]
  44. # [ 3 4 5 6]
  45. # [ 6 8 10 12]]
  46. # torch org out:
  47. # [[ 0 0 0 0]
  48. # [ 3 4 5 6]
  49. # [ 6 8 10 12]]
  50. # numpy out:
  51. # [[ 0 0 0 0]
  52. # [ 3 4 5 6]
  53. # [ 6 8 10 12]]
  54. # is np_out == torch_ein_out ? True
  55. # is torch_ein_out2 == torch_ein_out ? True
  56. # is torch_org_out == torch_ein_out ? True

11. batch 矩阵乘法

  1. import torch
  2. import numpy as np
  3. a = torch.randn(2,3,5)
  4. b = torch.randn(2,5,4)
  5. # i = 2, j = 3, k = 5, l = 4
  6. torch_ein_out = torch.einsum('ijk,ikl->ijl', [a, b]).numpy()
  7. torch_org_out = torch.bmm(a, b).numpy()
  8. np_a = a.numpy()
  9. np_b = b.numpy()
  10. # 循环展开实现
  11. np_out = np.empty((2, 3, 4), dtype=np.float32)
  12. # 自由索引外循环
  13. # 这个例子是 i,j和l
  14. for i in range(0, 2):
  15. for j in range(0, 3):
  16. for l in range(0, 4):
  17. # 求和索引内循环
  18. # 这个例子是 k
  19. sum_result = 0
  20. for k in range(0, 5):
  21. sum_result += np_a[i, j, k] * np_b[i, k, l]
  22. np_out[i, j, l] = sum_result
  23. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  24. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
  25. # 终端打印输出
  26. # is np_out == torch_ein_out ? True
  27. # is torch_org_out == torch_ein_out ? True

12. 张量收缩(tensor contraction)

  1. import torch
  2. import numpy as np
  3. a = torch.randn(2,3,5,7)
  4. b = torch.randn(11,13,3,17,5)
  5. # p = 2, q = 3, r = 5, s = 7
  6. # t = 11, u = 13, v = 17, r = 5
  7. torch_ein_out = torch.einsum('pqrs,tuqvr->pstuv', [a, b]).numpy()
  8. torch_org_out = torch.tensordot(a, b, dims=([1, 2], [2, 4])).numpy()
  9. np_a = a.numpy()
  10. np_b = b.numpy()
  11. # 循环展开实现
  12. np_out = np.empty((2, 7, 11, 13, 17), dtype=np.float32)
  13. # 自由索引外循环
  14. # 这里就是 p,s,t,u和v
  15. for p in range(0, 2):
  16. for s in range(0, 7):
  17. for t in range(0, 11):
  18. for u in range(0, 13):
  19. for v in range(0, 17):
  20. # 求和索引内循环
  21. # 这里是 q和r
  22. sum_result = 0
  23. for q in range(0, 3):
  24. for r in range(0, 5):
  25. sum_result += np_a[p, q, r, s] * np_b[t, u, q, v, r]
  26. np_out[p, s, t, u, v] = sum_result
  27. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out, atol=1e-6))
  28. print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out, atol=1e-6))
  29. # 终端打印输出
  30. # is np_out == torch_ein_out ? True
  31. # is torch_ein_out == torch_org_out ? True

13. 二次变换(bilinear transformation)

  1. import torch
  2. import numpy as np
  3. a = torch.randn(2,3)
  4. b = torch.randn(5,3,7)
  5. c = torch.randn(2,7)
  6. # i = 2, k = 3, j = 5, l = 7
  7. torch_ein_out = torch.einsum('ik,jkl,il->ij', [a, b, c]).numpy()
  8. m = torch.nn.Bilinear(3, 7, 5, bias=False)
  9. m.weight.data = b
  10. torch_org_out = m(a, c).detach().numpy()
  11. np_a = a.numpy()
  12. np_b = b.numpy()
  13. np_c = c.numpy()
  14. # 循环展开实现
  15. np_out = np.empty((2, 5), dtype=np.float32)
  16. # 自由索引外循环
  17. # 这里是 i 和 j
  18. for i in range(0, 2):
  19. for j in range(0, 5):
  20. # 求和索引内循环
  21. # 这里是 k 和 l
  22. sum_result = 0
  23. for k in range(0, 3):
  24. for l in range(0, 7):
  25. sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
  26. np_out[i, j] = sum_result
  27. # print("matrix a:\n", np_a)
  28. # print("matrix b:\n", np_b)
  29. print("torch ein out: \n", torch_ein_out)
  30. print("torch org out: \n", torch_org_out)
  31. print("numpy out: \n", np_out)
  32. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
  33. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
  34. # 终端打印输出
  35. # torch ein out:
  36. # [[-2.9185116 0.17024004 -0.43915534 1.5860008 10.016678 ]
  37. # [-0.48688257 -3.5114982 -0.7543343 -0.46790922 1.4816089 ]]
  38. # torch org out:
  39. # [[-2.9185116 0.17024004 -0.43915534 1.5860008 10.016678 ]
  40. # [-0.48688257 -3.5114982 -0.7543343 -0.46790922 1.4816089 ]]
  41. # numpy out:
  42. # [[-2.9185114 0.17023998 -0.4391551 1.5860008 10.016678 ]
  43. # [-0.4868826 -3.5114982 -0.7543342 -0.4679092 1.4816089 ]]
  44. # is np_out == torch_ein_out ? True
  45. # is torch_org_out == torch_ein_out ? True

从上面的13个例子可以看出,只要确定了自由索引和求和索引,einsum 的输出计算都可以用一套比较通用的多层循来实现,外层的循环对应自由索引,内层循环对应求和索引。

Pytorch einsum 实现简要解读

C++ 代码解读:

github 代码链接:https://github.com/pytorch/pytorch/blob/53596cdb7359116e8c8ae18ffef06f2677ad1296/aten/src/ATen/native/Linear.cpp#L148

我只读懂了大概的实现思路,然后按照我自己的理解添加了注释(仅供参考):

  1. // 为了方便理解,我简化了大部分代码,
  2. // 并把对于 "..." 省略号的处理去掉了
  3. /**
  4. * 代码实现主要分为3大步:
  5. * 1. 解析 equation,分别得到输入和输出对应的字符串
  6. * 2. 补全输出和输入张量的维度,通过 permute 操作对齐输入和输出的维度
  7. * 3. 将维度对齐之后的输入张量相乘,然后根据求和索引累加
  8. */
  9. Tensor einsum(std::string equation, TensorList operands) {
  10. // ......
  11. // 把 equation 按照箭头分割
  12. // 得到箭头左边输入的部分
  13. const auto arrow_pos = equation.find("->");
  14. const auto lhs = equation.substr(0, arrow_pos);
  15. // 获取输入操作数个数
  16. const auto num_ops = operands.size();
  17. // 下面循环主要作用是解析 equation 左边输入部分,
  18. // 按 ',' 号分割得到每个输入张量对应的字符串,
  19. // 并把并把每个 char 字符转成 int, 范围 [0, 25]
  20. // 新建 vector 保存每个输入张量对应的字符数组
  21. std::vector<std::vector<int>> op_labels(num_ops);
  22. std::size_t curr_op = 0;
  23. for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) {
  24. switch (lhs[i]) {
  25. // ......
  26. case ',':
  27. // 遇到逗号,接下来解析下一个输入张量的字符串
  28. ++curr_op;
  29. // ......
  30. break;
  31. default:
  32. // ......
  33. // 把 char 字符转成 int
  34. op_labels[curr_op].push_back(lhs[i] - 'a');
  35. }
  36. }
  37. // TOTAL_LABELS = 26
  38. constexpr int TOTAL_LABELS = 'z' - 'a' + 1;
  39. std::vector<int> label_count(TOTAL_LABELS, 0);
  40. // 遍历所有输入操作数
  41. // 统计 equation 中 'a' - 'z' 每个字符的出现次数
  42. for(const auto i : c10::irange(num_ops)) {
  43. const auto labels = op_labels[i];
  44. for (const auto& label : labels) {
  45. // ......
  46. ++label_count[label];
  47. }
  48. // ......
  49. }
  50. // 创建一个 vector 用于保存 equation
  51. // 箭头右边输出的字符到索引的映射
  52. std::vector<int64_t> label_perm_index(TOTAL_LABELS, -1);
  53. int64_t perm_index = 0;
  54. // ......
  55. // 接下来解析输出字符串
  56. if (arrow_pos == std::string::npos) {
  57. // 处理用户省略了箭头的情况,
  58. // ......
  59. } else {
  60. // 一般情况
  61. // 得到箭头右边的输出
  62. const auto rhs = equation.substr(arrow_pos + 2);
  63. // 遍历输出字符串并解析
  64. for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) {
  65. switch (rhs[i]) {
  66. // ......
  67. default:
  68. // ......
  69. const auto label = rhs[i] - 'a';
  70. // ......
  71. // 建立字符到索引的映射,perm_index从0开始
  72. label_perm_index[label] = perm_index++;
  73. }
  74. }
  75. }
  76. // 保存原始的输出维度大小
  77. const int64_t out_size = perm_index;
  78. // 对齐输出张量的维度,使得对齐之后的维度等于
  79. // 自由索引加上求和索引的个数
  80. // 对输出补全省略掉的求和索引
  81. // 也就是在输入等式中出现,但是没有在输出等式中出现的字符
  82. for (const auto label : c10::irange(TOTAL_LABELS)) {
  83. if (label_count[label] > 0 && label_perm_index[label] == -1) {
  84. label_perm_index[label] = perm_index++;
  85. }
  86. }
  87. // 对所有输入张量,同样补齐维度至与输出维度大小相同
  88. // 最后对输入做 permute 操作,使得输入张量的每一维
  89. // 与输出张量的每一维能对上
  90. std::vector<Tensor> permuted_operands;
  91. for (const auto i: c10::irange(num_ops)) {
  92. // 保存输入张量最终做 permute 时候的维度映射
  93. std::vector<int64_t> perm_shape(perm_index, -1);
  94. Tensor operand = operands[i];
  95. // 取输入张量对应的 equation
  96. const auto labels = op_labels[i];
  97. std::size_t j = 0;
  98. for (const auto& label : labels) {
  99. // ......
  100. // 建立当前遍历到的输入张量字符到
  101. // 输出张量的字符到的映射
  102. // label: 当前遍历到的字符
  103. // label_perm_index: 保存了输出字符对应的索引
  104. // 所以 perm_shape 就是建立了输入张量的每一维度
  105. // 与输出张量维度的对应关系
  106. perm_shape[label_perm_index[label]] = j++;
  107. }
  108. // 如果输入张量的维度小于补全后的输出
  109. // 那么 perm_shape 中一定存在值为 -1 的元素
  110. // 那么相当于需要扩充输入张量的维度
  111. // 扩充的维度添加在张量的尾部
  112. for (int64_t& index : perm_shape) {
  113. if (index == -1) {
  114. // 在张量尾部插入维度1
  115. operand = operand.unsqueeze(-1);
  116. // 修改了perm_shape中的index,
  117. // 因为是引用取值
  118. index = j++;
  119. }
  120. }
  121. // 把输入张量的维度按照输出张量的维度重排,采用 permute 操作
  122. permuted_operands.push_back(operand.permute(perm_shape));
  123. }
  124. // ......
  125. Tensor result = permuted_operands[0];
  126. // .....
  127. // 计算最终结果
  128. for (const auto i: c10::irange(1, num_ops)) {
  129. Tensor operand = permuted_operands[i];
  130. // 新建 vector 用于保存求和索引
  131. std::vector<int64_t> sum_dims;
  132. // ......
  133. // 详细的代码可以阅读 Pytorch 源码
  134. // 这里我还没有完全理解 sumproduct_pair 的实现,
  135. // 里面用的是 permute + bmm,
  136. // 不过我觉得可以简单理解为
  137. // 将张量做广播乘法,再根据求和索引做累加
  138. result = sumproduct_pair(result, operand, sum_dims, false);
  139. }
  140. return result;
  141. }

图解实现

下面还是用矩阵乘法来说明 C++ 的实现思路,下图展示的是矩阵乘法的通用实现:
image.png
接下来展示 C++ 的实现思路:
image.png

总结

通过上面的实际例子和代码解读,可以看到 einsum 非常灵活,可以方便的实现各种常用的张量操作。希望读者通过这篇文章也可以轻松掌握 einsum 的基本用法。文中对于 Pytorch C++实现代码的解析是基于我自己的理解,如果觉得有误或者不理解的地方欢迎讨论。

参考资料