此函数用于张量计算,爱因斯坦指标求和比较重要的性质:
- 指标位置可以交换,交换之后不影响最终的结果(利用
permute
函数实现张量指标变换,下标即维度的意义) - 归根结底是矩阵的运算(典型的矩阵乘法—
torch.mm
,矩阵element-wise乘法— * ortorch.mul
),利用矩阵分析的方法来进行等价计算 ```python import torch
a = torch.randn((3, 2, 4, 6, 3)) # nkctv n, k, c, t, v = a.size() b = torch.randn((2, 3, 7)) # kvw , , w = b.size()
einsum1 = torch.einsum(“nkctv,kvw->nctw”, (a, b))
pa = a.permute(0, 2, 3, 1, 4).reshape(nct, kv) # (nct)(kv) pb = b.reshape(kv, w) # (kv)w einsum2 = pa.mm(pb).reshape(n, c, t, w)
einsum1 == einsum2
``` 上面的einsum1与einsum2恒等。