此函数用于张量计算,爱因斯坦指标求和比较重要的性质:

  • 指标位置可以交换,交换之后不影响最终的结果(利用permute函数实现张量指标变换,下标即维度的意义)
  • 归根结底是矩阵的运算(典型的矩阵乘法—torch.mm,矩阵element-wise乘法— * or torch.mul),利用矩阵分析的方法来进行等价计算 ```python import torch

a = torch.randn((3, 2, 4, 6, 3)) # nkctv n, k, c, t, v = a.size() b = torch.randn((2, 3, 7)) # kvw , , w = b.size()

einsum1 = torch.einsum(“nkctv,kvw->nctw”, (a, b))

pa = a.permute(0, 2, 3, 1, 4).reshape(nct, kv) # (nct)(kv) pb = b.reshape(kv, w) # (kv)w einsum2 = pa.mm(pb).reshape(n, c, t, w)

einsum1 == einsum2

``` 上面的einsum1与einsum2恒等。