参考来源:
CSDN:pytorch中torch.transpose() 与 torch.tensor.permute() 的区别
CSDN:pytorch —- tensor.permute() 和 torch.transpose()
博客园:PyTorch 两大转置函数 transpose() 和 permute()

1. tensor.permute(dim1, dim2, dim3, …)

permute 可以对任意高维矩阵进行转置。但只有 tensor.permute() 这个调用方式:

  1. x = torch.rand(2,3,4)
  2. print("x.shape:", x.shape)
  3. x = x.permute(2,1,0)
  4. print("x.shape:", x.shape)
  5. """
  6. 输出:
  7. x.shape: torch.Size([2, 3, 4])
  8. x.shape: torch.Size([4, 3, 2])
  9. [Finished in 1.0s]
  10. """

例 2:

  1. t.rand(2,3,4,5).permute(3,2,0,1).shape
  2. Out[669]: torch.Size([5, 4, 2, 3])

总结
传入 permute 方法的参数是维度, 未进行变换前的 dim[0, 1, 2] 的方式, 转换后表示将第 0 维度和第 2 维度调换。

2. torch.transpose(input, dim0, dim1, out=None)

函数返回输入矩阵 input 的转置。交换维度 dim0dim1
参数:

  • input (Tensor):输入张量,必填。
  • dim0 (int):转置的第一维,默认 0,可选。
  • dim1 (int):转置的第二维,默认 1,可选。

transpose 只能操作 2D 矩阵的转置(就是每次 transpose 只能在两个维度之间转换,其他维度保持不变)。
有两种调用方式:tensor.transpose()torch.transpose(tensor, dim1, dim2)
连续使用 transpose 也可实现 permute 的效果

  1. torch.transpose(Tensor, 1, 0)
  2. t.rand(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape
  3. Out[672]: torch.Size([5, 4, 2, 3])
  4. t.rand(2,3,4,5).transpose(1,0).transpose(2,1).transpose(3,1).shape
  5. Out[670]: torch.Size([3, 5, 2, 4])

3. tensor.permute() 和 torch.transpose()

相同点:交换张量的维度
不同点:主要区别是 transpose 只能在两个维度之间转换, permute 可以一下转换好几个维度。

  • 参数列表:**torch.transpose(input, dim0, dim1, out=None)** 只能传入两个维度参数,tensor 在这两个维度之间交换。
  • 参数列表:**torch.tensor.permute(dims)** 要求传入所有维度,tensor 按维度排列顺序进行交换。
  • 内存:**torch.transpose(dim1,dim2)** 得到的张量与原张量共享内存,而 **torch.tensor.permute(dims)** 不具备这个性质

总结permute 相比 transpose 更加灵活,transpose 具有共享内存机制。

  1. a=torch.tensor([[[1,2,3],[4,5,6]]])
  2. b=torch.tensor([[[1,2,3],[4,5,6]]])
  3. c=a.transpose(2,1)
  4. c=c.transpose(2,1)
  5. d=b.permute(0,2,1)
  6. d=d.permute(0,1,2)
  7. print(c)
  8. print(a)
  9. print(d)
  10. print(b)
  11. """
  12. #输出结果为:
  13. tensor([[[1, 2, 3],
  14. [4, 5, 6]]])
  15. tensor([[[1, 2, 3],
  16. [4, 5, 6]]])
  17. tensor([[[1, 4],
  18. [2, 5],
  19. [3, 6]]])
  20. tensor([[[1, 2, 3],
  21. [4, 5, 6]]])
  22. """