torch.bmm()
torch.matmul()


torch.bmm()强制规定维度和大小相同
torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作
当进行操作的两个tensor都是3D时,两者等同。

torch.bmm()

官网:https://pytorch.org/docs/stable/torch.html#torch.bmm

torch.``bmm(input, mat2, out=None) → Tensor

torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的AB。
参数:
input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。
output:输出结果
并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n)
mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。
例子:

  1. 1. import torch
  2. 2. x = torch.rand(2,3,6)
  3. 3. y = torch.rand(2,6,7)
  4. 4. print(torch.bmm(x,y).size())
  5. 5.
  6. 6. output:
  7. 7. torch.Size([2, 3, 7])
  8. 8.
  9. 9. ###############################
  10. 10. y = torch.rand(2,5,7) ##维度不匹配,报错
  11. 11. print(torch.bmm(x,y).size())
  12. 12.
  13. 13. output:
  14. 14. Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)

torch.matmul()

torch.``matmul(input, other, out=None) → Tensor

torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。
参数:
input,other:两个要进行操作的tensor结构
output:结果
一些规则约定:
(1)若两个都是1D(向量)的,则返回两个向量的点积

  1. 1. import torch
  2. 2. x = torch.rand(2)
  3. 3. y = torch.rand(2)
  4. 4. print(torch.matmul(x,y),torch.matmul(x,y).size())
  5. 5. output
  6. 6. tensor(0.1353) torch.Size([])

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D

  1. 1. x = torch.rand(2,4)
  2. 2. y = torch.rand(4,3) ###维度也要对应才可以乘
  3. 3. print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
  4. 4.
  5. 5. output:
  6. 6. tensor([[0.9128, 0.8425, 0.7269],
  7. 7. [1.4441, 1.5334, 1.3273]])
  8. 8. torch.Size([2, 3])

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系。

  1. 1. import torch
  2. 2. x = torch.rand(4) #1D
  3. 3. y = torch.rand(4,3) #2D
  4. 4. print(x.size())
  5. 5. print(y.size())
  6. 6. print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
  7. 7.
  8. 8. ### 扩充x =>(,4)
  9. 9. ### 相乘x(,4) * y(4,3) =>(,3)
  10. 10. ### 去掉1D =>(3)
  11. 11.
  12. 12. output:
  13. 13. torch.Size([4])
  14. 14. torch.Size([4, 3])
  15. 15. tensor([0.9600, 0.5736, 1.0430])
  16. 16. torch.Size([3])

(4)若input是2D,other是1D,则返回两者的点积结果。(个人觉得这块也可以理解成给other添加了维度,然后再去掉此维度,只不过维度是(3, )而不是规则(3)中的( ,4)了,但是可能就是因为内部机制不同,所以官方说的是点积而不是维度的升高和下降)

  1. 1. import torch
  2. 2. x = torch.rand(3) #1D
  3. 3. y = torch.rand(4,3) #2D
  4. 4. print(torch.matmul(y,x),'\n',torch.matmul(y,x).size()) #2D*1D
  5. 5.
  6. 6. output:
  7. 7. torch.Size([3])
  8. 8. torch.Size([4, 3])
  9. 9. tensor([0.8278, 0.5970, 1.0370, 0.2681])
  10. 10. torch.Size([4])

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)。
(a)若input是1D,other是大于2D的,则类似于规则(3)。

  1. 1. import torch
  2. 2. x = torch.randn(2, 3, 4)
  3. 3. y = torch.randn(3)
  4. 4. print(torch.matmul(y, x),'\n',torch.matmul(y, x).size()) #1D*3D
  5. 5.
  6. 6. output:
  7. 7. tensor([[-0.9747, -0.6660, -1.1704, -1.0522],
  8. 8. [ 0.0901, -1.5353, 1.5601, -0.0252]])
  9. 9. torch.Size([2, 4])
  1. b)若other1Dinput是大于2D的,则类似于规则(4)。
  1. 1. import torch
  2. 2. x = torch.randn(2, 3, 4)
  3. 3. y = torch.randn(4)
  4. 4.
  5. 5. print(torch.matmul(x, y),'\n',torch.matmul(x, y).size()) # 3D*1D
  6. 6.
  7. 7. output:
  8. 8. tensor([[ 0.6217, -0.1259, -0.2377],
  9. 9. [ 0.6874, 0.0733, 0.1793]])
  10. 10. torch.Size([2, 3])
  1. c)若inputother都是3D的,则与torch.bmm()函数功能一样。
  1. 1. import torch
  2. 2. x = torch.randn(2,2,4)
  3. 3. y = torch.randn(2,4,5)
  4. 4.
  5. 5. print(torch.matmul(x, y).size(),'\n',torch.bmm(x, y).size())
  6. 6. print(torch.equal(torch.matmul(x,y),torch.bmm(x,y)))
  7. 7.
  8. 8. output:
  9. 9. torch.Size([2, 2, 5])
  10. 10. torch.Size([2, 2, 5])
  11. 11. True
  1. d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 inputj,1,n,m)* other (k,m,p) = output(j,k,n,p)。
  1. 1. import torch
  2. 2. x = torch.randn(10,1,2,4)
  3. 3. y = torch.randn(2,4,5)
  4. 4.
  5. 5. print(torch.matmul(x, y).size())
  6. 6.
  7. 7. output
  8. 8. torch.Size([10, 2, 2, 5])

这个例子中,可以理解为x中dim=1这个维度可以扩充(广播),y中可以添加一个维度,然后在进行批乘操作。