使用 PyTorch 实现 Kronecker Product

原始文档:https://www.yuque.com/lart/ugkv9f/gb2h93

计算过程

实现Kronecker Product - 图1

PyTorch 实现

起因是看到了这篇文章:https://zhuanlan.zhihu.com/p/79295551介绍了一种新颖的卷积方式,其中使用了 kronecker product 方法来实现。这种计算理解很容易,但是实现起来该如何编程,这是一个值得思考的问题。文章结尾作者推荐的代码中给出了一种实现https://github.com/d-li14/dgconv.pytorch/blob/master/dgconv.py#L26

  1. def kronecker_product(mat1, mat2):
  2. # 在pytorch1.7版本之后,torch.ger就被torch.outer所代替了
  3. out_mat = torch.ger(mat1.view(-1), mat2.view(-1))
  4. # 这里的(mat1.size() + mat2.size())表示的是将两个list拼接起来
  5. out_mat = out_mat.reshape(*(mat1.size() + mat2.size())).permute([0, 2, 1, 3])
  6. out_mat = out_mat.reshape(mat1.size(0) * mat2.size(0), mat1.size(1) * mat2.size(1))
  7. return out_mat

这里应该参考的是这里的方法https://discuss.pytorch.org/t/kronecker-product/3919/7

但是该帖子后面给出了一种更简单的方法https://discuss.pytorch.org/t/kronecker-product/3919/10

  1. def kronecker(A, B):
  2. AB = torch.einsum("ab,cd->acbd", A, B)
  3. AB = AB.view(A.size(0)*B.size(0), A.size(1)*B.size(1))
  4. return AB

二者实际上是一致的,也就是说这里的

  1. out_mat = torch.ger(mat1.view(-1), mat2.view(-1))
  2. out_mat = out_mat.reshape(*(mat1.size() + mat2.size())).permute([0, 2, 1, 3])

与基于 enisum 的实现

  1. AB = torch.einsum("ab,cd->acbd", A, B)

表示的是一样的行为。

在前者中,假设 A=mat1B=mat2 ,二者分别为 axb 和 cxd 大小的矩阵。计算过程可以表述如下:

  1. 对于二者先通过矢量化(.view(-1))。
  2. 利用外积操作(torch.ger)计算出各个元素之间的乘积构成的矩阵,大小为 abxcd。
  3. 将结果调整为 axbxcxd 大小的形状。
  4. 利用 permute 操作变成 axcxbxd 大小的形状。

这个过程实际上与 einsum 中的维度索引的调整 'ab,cd->acbd' 是一致的。

殊途同归。

从另一个角度看einsum的模式设定

我们最终的目标是得到这样一个结果:

  1. A={lab} a,b={1,2}
  2. B={mcd} c,d={1,2}
  3. b: ┌──────1──────┐ ┌──────2──────┐
  4. d: 1 2 1 2
  5. a,c ┌──┴─────────────┴──┴─────────────┴─►
  6. : :
  7. 1 1 l11xm11 l11xm12 l12xm11 l12xm12
  8. 2 l11xm21 l11xm22 l12xm21 l12xm22
  9. 2 1 l21xm11 l21xm12 l22xm11 l22xm12
  10. 2 l21xm21 l21xm22 l22xm21 l22xm22

这里我标注了 A 和 B 对应的下标的变化范围。

由于我们利用 PyTorch 实现这些处理,那我们必定是要使用对应的矩阵运算和形状变换的。

从这里的图可以看出来,若是对于这里最终的结果形式反着使用变形操作就可以得到这样的结果:

  A={lab} a,b={1,2}
  B={mcd} c,d={1,2}

       d:   1       2
   a,c ┌─────────────────► b
   : : │                 │ :
   │ 1 │ l11xm11 l11xm12 │ 1
   │   │ l12xm11 l12xm12 │ 2
   1───┤                 │
   │ 2 │ l11xm21 l11xm22 │ 1
   │   │ l12xm21 l12xm22 │ 2
   ├───┤                 │
   │ 1 │ l21xm11 l21xm12 │ 1
   │   │ l22xm11 l22xm12 │ 2
   2───┤                 │
   │ 2 │ l21xm21 l21xm22 │ 1
   │   │ l22xm21 l22xm22 │ 2
   ▼   ▼                 ▼

这实际上也就是索引为 (i,m,j,n) 的 2D 矩阵,也是对应于 torch.einsum("ab,cd->acbd", A, B) 结果的表示。

前面的两种形式的代码实际上都是为了得到后面这个结果。再通过一个reshape操作从而获得最终结果。。

参考资料