1. torch.cat()

pytorch 中,常见的拼接函数主要是两个,分别是:

  1. **torch.stack()**
  2. **torch.cat()**

一般 torch.cat() 是为了把函数 torch.stack() 得到 tensor 进行拼接而存在的。

torch.cat()

函数目的: 在给定维度上对输入的张量序列 **seq** 进行连接操作。

  1. outputs = torch.cat(inputs, dim=0) Tensor

参数:

  • **inputs** : 待连接的张量序列,可以是任意相同 Tensor 类型的 python 序列
  • **dim** : 选择的扩维, 必须在 0len(inputs[0]) 之间,沿着此维连接张量序列。

重点

输入数据必须是序列,序列中数据是任意相同的 **shape** 的同类型 **tensor**
维度不可以超过输入数据的任一个张量的维度

例子

  1. 准备数据,每个的 shape 都是 [2,3]
  1. import torch
  2. x1 = torch.tensor([[11, 12, 13],
  3. [21, 22, 23]], dtype=torch.int)
  4. print("'x1.shape':", x1.shape)
  5. x2 = torch.tensor([[11, 12, 13],
  6. [21, 22, 23]], dtype=torch.int)
  7. print("'x2.shape':", x2.shape)
  8. """
  9. 'x1.shape': torch.Size([2, 3])
  10. 'x2.shape': torch.Size([2, 3])
  11. """
  1. 合成 inputs
  1. # inputs为2个形状为[2 , 3]的矩阵
  2. inputs = [x1, x2]
  3. print("inputs:\n",inputs)
  4. """
  5. inputs:
  6. [tensor([[11, 12, 13],
  7. [21, 22, 23]], dtype=torch.int32),
  8. tensor([[11, 12, 13],
  9. [21, 22, 23]], dtype=torch.int32)]
  10. """
  1. 查看结果,测试不同的 dim 拼接结果。
  1. y1 = torch.cat(inputs, dim=0)
  2. print("'torch.cat(inputs, dim=0).shape':",y1.shape)
  3. print("'torch.cat(inputs, dim=0)':\n",y1)
  4. y2 = torch.cat(inputs, dim=1)
  5. print("'torch.cat(inputs, dim=1).shape':",y2.shape)
  6. print("'torch.cat(inputs, dim=1)':\n",y2)
  7. y3 = torch.cat(inputs, dim=2).shape
  8. """
  9. 'torch.cat(inputs, dim=0).shape': torch.Size([4, 3])
  10. 'torch.cat(inputs, dim=0)':
  11. tensor([[11, 12, 13],
  12. [21, 22, 23],
  13. [11, 12, 13],
  14. [21, 22, 23]], dtype=torch.int32)
  15. 'torch.cat(inputs, dim=1).shape': torch.Size([2, 6])
  16. 'torch.cat(inputs, dim=1)':
  17. tensor([[11, 12, 13, 11, 12, 13],
  18. [21, 22, 23, 21, 22, 23]], dtype=torch.int32)
  19. Traceback (most recent call last):
  20. File "test.py", line 30, in <module>
  21. y3 = torch.cat(inputs, dim=2).shape
  22. IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
  23. """

完整代码如下:

  1. import torch
  2. x1 = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int)
  3. print("'x1.shape':", x1.shape)
  4. x2 = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int)
  5. print("'x2.shape':", x2.shape)
  6. inputs = [x1, x2]
  7. print("inputs:\n", inputs)
  8. y1 = torch.cat(inputs, dim=0)
  9. print("'torch.cat(inputs, dim=0).shape':",y1.shape)
  10. print("'torch.cat(inputs, dim=0)':\n",y1)
  11. y2 = torch.cat(inputs, dim=1)
  12. print("'torch.cat(inputs, dim=1).shape':",y2.shape)
  13. print("'torch.cat(inputs, dim=1)':\n",y2)
  14. y3 = torch.cat(inputs, dim=2).shape

总结

通常用来,把 torch.stack 得到 tensor 进行拼接而存在的。

2. torch.stack()

实际使用中,**torch.stack()****torch.cat()** 互相辅助。
函数的意义:使用 torch.stack() 可以保留两个信息:[1. 序列][2. 张量矩阵] 信息,属于【扩张再拼接】的函数。
形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。
该函数常出现在自然语言处理(**NLP**)和图像卷积神经网络(**CV**)中。

torch.stack()

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个 2 维的张量凑成一个 3 维的张量;多个 3 维的凑成一个4 维的张量 … 以此类推,也就是在增加新的维度进行堆叠。

  1. outputs = torch.stack(inputs, dim=?) Tensor

参数

  • **inputs** : 待连接的张量序列。

注:python 的序列数据只有 listtuple

  • **dim** : 新的维度, 必须在 0len(outputs) 之间。

注:len(outputs) 是生成数据的维度大小,也就是 outputs 的维度值。

重点

函数中的输入 inputs 只允许是序列;且序列内部的张量元素,必须 shape 相等
举例:[tensor_1, tensor_2,..] 或者 (tensor_1, tensor_2,..),且必须 tensor_1.shape == tensor_2.shape
dim 是选择生成的维度,必须满足 0<=dim<len(outputs)len(outputs) 是输出后的 tensor 的维度大小。

例子

  1. 准备 2tensor 数据,每个的 shape 都是 [3,3]
  1. import torch
  2. # 假设是时间步T1的输出
  3. T1 = torch.tensor([[1, 2, 3],
  4. [4, 5, 6],
  5. [7, 8, 9]])
  6. print("'T1.shape':", T1.shape)
  7. # 假设是时间步T2的输出
  8. T2 = torch.tensor([[10, 20, 30],
  9. [40, 50, 60],
  10. [70, 80, 90]])
  11. print("'T2.shape':", T2.shape)
  12. """
  13. 'T1.shape': torch.Size([3, 3])
  14. 'T2.shape': torch.Size([3, 3])
  15. """
  1. 测试 torch.stack() 函数。
  1. Y0 = torch.stack((T1, T2), dim=0)
  2. print("'torch.stack((T1, T2), dim=0).shape':", Y0.shape)
  3. print("'torch.stack((T1, T2), dim=0)':\n", Y0)
  4. Y1 = torch.stack((T1, T2), dim=1)
  5. print("'torch.stack((T1, T2), dim=1).shape':", Y1.shape)
  6. print("'torch.stack((T1, T2), dim=1)':\n", Y1)
  7. Y2 = torch.stack((T1, T2), dim=2)
  8. print("'torch.stack((T1, T2), dim=2).shape':", Y2.shape)
  9. print("'torch.stack((T1, T2), dim=2)':\n", Y2)
  10. Y3 = torch.stack((T1, T2), dim=3) # 选择的 dim > len(outputs),所以报错
  11. """
  12. 'torch.stack((T1, T2), dim=0).shape': torch.Size([2, 3, 3])
  13. 'torch.stack((T1, T2), dim=0)':
  14. tensor([[[ 1, 2, 3],
  15. [ 4, 5, 6],
  16. [ 7, 8, 9]],
  17. [[10, 20, 30],
  18. [40, 50, 60],
  19. [70, 80, 90]]])
  20. 'torch.stack((T1, T2), dim=1).shape': torch.Size([3, 2, 3])
  21. 'torch.stack((T1, T2), dim=1)':
  22. tensor([[[ 1, 2, 3],
  23. [10, 20, 30]],
  24. [[ 4, 5, 6],
  25. [40, 50, 60]],
  26. [[ 7, 8, 9],
  27. [70, 80, 90]]])
  28. 'torch.stack((T1, T2), dim=2).shape': torch.Size([3, 3, 2])
  29. 'torch.stack((T1, T2), dim=2)':
  30. tensor([[[ 1, 10],
  31. [ 2, 20],
  32. [ 3, 30]],
  33. [[ 4, 40],
  34. [ 5, 50],
  35. [ 6, 60]],
  36. [[ 7, 70],
  37. [ 8, 80],
  38. [ 9, 90]]])
  39. Traceback (most recent call last):
  40. File "test.py", line 34, in <module>
  41. Y3 = torch.stack((T1, T2), dim=3)
  42. IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
  43. """

可以复制代码运行试试:拼接后的 tensor 形状,会根据不同的 dim 发生变化。

dim shape
0 [2, 3, 3]
1 [3, 2, 3]
2 [3, 3, 2]
3 溢出报错

总结

函数作用:
函数 torch.stack() 对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留 [序列(先后)信息][张量的矩阵信息] 才会使用 torch.stack()
手写过 RNN 的同学,知道在循环神经网络中输出数据是:一个 **list** ,该列表插入了 **seq_len** 个形状是 **[batch_size, output_size]****tensor** ,不利于计算,需要使用 torch.stack 进行拼接,保留–[1.**seq_len** 这个时间步]和–[2.张量属性 **[batch_size, output_size]**]