参考来源:
CSDN:Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
CSDN:torch.bmm() 与 torch.matmul()
CSDN:关于 torch.bmm() 函数计算过程
CSDN:torch.bmm() 函数解读
SDN:torch.matmul() 用法介绍

torch.bmm() 强制规定维度和大小相同。
torch.matmul() 没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作。
当进行操作的两个 tensor 都是 3D 时,两者等同。

  • **torch.mul(a, b)** 是矩阵 a 和 b 对应位相乘,a 和 b 的维度必须相等,比如a的维度是 (1, 2),b 的维度是 (1, 2),返回的仍是 (1, 2) 的矩阵。
  • **torch.mm(a, b)** 是矩阵 a 和 b 矩阵相乘,比如a的维度是 (1, 2),b 的维度是 (2, 3),返回的就是 (1, 3) 的矩阵。
  • **torch.bmm()** 强制规定维度和大小相同。
  • **torch.matmul()** 没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作。

1、torch.mul(a, b) 和 torch.mm(a, b)

示例:

  1. import torch
  2. a = torch.rand(3, 4)
  3. b = torch.rand(3, 4)
  4. c = torch.rand(4, 5)
  5. print(torch.mul(a, b).size()) # 返回 1*2 的tensor
  6. print(torch.mm(a, c).size()) # 返回 1*3 的tensor
  7. print(torch.mul(a, c).size()) # 由于a、b维度不同,报错

结果:

  1. torch.Size([3, 4])
  2. torch.Size([3, 5])
  3. ---------------------------------------------------------------------------
  4. RuntimeError Traceback (most recent call last)
  5. <ipython-input-27-aea68cb5481f> in <module>
  6. 7 print(torch.mul(a, b).size()) # 返回 1*2 的tensor
  7. 8 print(torch.mm(a, c).size()) # 返回 1*3 的tensor
  8. ----> 9 print(torch.mul(a, c).size()) # 由于a、b维度不同,报错
  9. RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1

2. torch.bmm()

官网:https://pytorch.org/docs/stable/torch.html#torch.bmm

  1. torch.bmm(input, mat2, out=None) Tensor

函数作用
torch.bmm() 是 tensor 中的一个相乘操作,类似于矩阵中的 A*B 。
参数:

  • input,mat2:两个要进行相乘的 tensor 结构,两者必须是 3D 维度的,每个维度中的大小是相同的。
  • output:输出结果

并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘

例子:

  1. import torch
  2. x = torch.rand(2, 3, 6)
  3. y = torch.rand(2, 6, 7)
  4. print(torch.bmm(x, y).size())
  5. ###############################
  6. y = torch.rand(2, 5, 7) ##维度不匹配,报错
  7. print(torch.bmm(x, y).size())
  8. """output:
  9. torch.Size([2, 3, 7])
  10. RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
  11. """

当 tensor 维度为 2 时会报错!

  1. import torch
  2. c = torch.randn((2, 5))
  3. print(c.shape)
  4. d = torch.reshape(c, (5, 2))
  5. print(d.shape)
  6. e = torch.bmm(c, d)
  7. """output:
  8. torch.Size([2, 5])
  9. torch.Size([5, 2])
  10. RuntimeError: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
  11. """

维度为 4 时也会报错!

  1. import torch
  2. ccc = torch.randn((1, 2, 2, 5))
  3. ddd = torch.randn((1, 2, 5, 2))
  4. e = torch.bmm(ccc, ddd)
  5. """output:
  6. RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
  7. """

3. torch.matmul()

  1. torch.matmul(input, other, out=None) Tensor

torch.matmul() 也是一种类似于矩阵相乘操作的 tensor 联乘操作。但是它可以利用 python 中的广播机制,处理一些维度不同的 tensor 结构进行相乘操作。这也是该函数与 torch.bmm() 区别所在。
参数:
input,other:两个要进行操作的 tensor 结构
output:结果

一些规则约定:
(1)若两个都是1D(向量)的,则返回两个向量的点积

  1. import torch
  2. x = torch.rand(2)
  3. y = torch.rand(2)
  4. print(torch.matmul(x,y),torch.matmul(x,y).size())
  5. output
  6. tensor(0.1353) torch.Size([])

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D

  1. x = torch.rand(2,4)
  2. y = torch.rand(4,3) ###维度也要对应才可以乘
  3. print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
  4. """output:
  5. tensor([[0.9128, 0.8425, 0.7269],
  6. [1.4441, 1.5334, 1.3273]])
  7. torch.Size([2, 3])
  8. """

(3)若 input 维度 1D,other 维度 2D,则先将 1D 的维度扩充到 2D(1D 的维数前面 +1),然后得到结果后再将此维度去掉,得到的与 input 的维度相同。即使作扩充(广播)处理,input 的维度也要和 other 维度做对应关系。

  1. import torch
  2. x = torch.rand(4) #1D
  3. y = torch.rand(4,3) #2D
  4. print(x.size())
  5. print(y.size())
  6. print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
  7. ### 扩充x =>(,4)
  8. ### 相乘x(,4) * y(4,3) =>(,3)
  9. ### 去掉1D =>(3)
  10. """output:
  11. torch.Size([4])
  12. torch.Size([4, 3])
  13. tensor([0.9600, 0.5736, 1.0430])
  14. torch.Size([3])
  15. """

(4)若 input 是2D,other 是 1D,则返回两者的点积结果。(个人觉得这块也可以理解成给 other 添加了维度,然后再去掉此维度,只不过维度是 (3, ) 而不是规则 (3) 中的 ( ,4) 了,但是可能就是因为内部机制不同,所以官方说的是点积而不是维度的升高和下降)

  1. import torch
  2. x = torch.rand(3) #1D
  3. y = torch.rand(4,3) #2D
  4. print(torch.matmul(y,x),'\n',torch.matmul(y,x).size()) #2D*1D
  5. """output:
  6. torch.Size([3])
  7. torch.Size([4, 3])
  8. tensor([0.8278, 0.5970, 1.0370, 0.2681])
  9. torch.Size([4])
  10. """

(5)如果一个维度至少是 1D,另外一个大于 2D,则返回的是一个批矩阵乘法( a batched matrix multiply)。
(a)若 input 是 1D,other 是大于 2D的,则类似于规则 (3)。

  1. import torch
  2. x = torch.randn(2, 3, 4)
  3. y = torch.randn(3)
  4. print(torch.matmul(y, x),'\n',torch.matmul(y, x).size()) #1D*3D
  5. """output:
  6. tensor([[-0.9747, -0.6660, -1.1704, -1.0522],
  7. [ 0.0901, -1.5353, 1.5601, -0.0252]])
  8. torch.Size([2, 4])
  9. """

(b)若 other 是1D,input 是大于 2D 的,则类似于规则 (4)。

  1. import torch
  2. x = torch.randn(2, 3, 4)
  3. y = torch.randn(4)
  4. print(torch.matmul(x, y),'\n',torch.matmul(x, y).size()) # 3D*1D
  5. """output:
  6. tensor([[ 0.6217, -0.1259, -0.2377],
  7. [ 0.6874, 0.0733, 0.1793]])
  8. torch.Size([2, 3])
  9. """

(c)若 input 和 other 都是 3D 的,则与 torch.bmm() 函数功能一样。

  1. import torch
  2. x = torch.randn(2,2,4)
  3. y = torch.randn(2,4,5)
  4. print(torch.matmul(x, y).size(),'\n',torch.bmm(x, y).size())
  5. print(torch.equal(torch.matmul(x,y),torch.bmm(x,y)))
  6. """output:
  7. torch.Size([2, 2, 5])
  8. torch.Size([2, 2, 5])
  9. True
  10. """

(d)如果 input 中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)*other(k,m,p)=output(j,k,n,p)

  1. import torch
  2. x = torch.randn(10,1,2,4)
  3. y = torch.randn(2,4,5)
  4. print(torch.matmul(x, y).size())
  5. """output:
  6. torch.Size([10, 2, 2, 5])
  7. """

这个例子中,可以理解为 xdim=1 这个维度可以扩充(广播),y 中可以添加一个维度,然后在进行批乘操作。