1. torch.cat()
在 pytorch 中,常见的拼接函数主要是两个,分别是:
**torch.stack()****torch.cat()**
一般 torch.cat() 是为了把函数 torch.stack() 得到 tensor 进行拼接而存在的。
torch.cat()
函数目的: 在给定维度上对输入的张量序列 **seq** 进行连接操作。
outputs = torch.cat(inputs, dim=0) → Tensor
参数:
**inputs**: 待连接的张量序列,可以是任意相同Tensor类型的python序列**dim**: 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。
重点
输入数据必须是序列,序列中数据是任意相同的 **shape** 的同类型 **tensor** 。
维度不可以超过输入数据的任一个张量的维度
例子
- 准备数据,每个的
shape都是[2,3]。
import torchx1 = torch.tensor([[11, 12, 13],[21, 22, 23]], dtype=torch.int)print("'x1.shape':", x1.shape)x2 = torch.tensor([[11, 12, 13],[21, 22, 23]], dtype=torch.int)print("'x2.shape':", x2.shape)"""'x1.shape': torch.Size([2, 3])'x2.shape': torch.Size([2, 3])"""
- 合成
inputs。
# inputs为2个形状为[2 , 3]的矩阵inputs = [x1, x2]print("inputs:\n",inputs)"""inputs:[tensor([[11, 12, 13],[21, 22, 23]], dtype=torch.int32),tensor([[11, 12, 13],[21, 22, 23]], dtype=torch.int32)]"""
- 查看结果,测试不同的
dim拼接结果。
y1 = torch.cat(inputs, dim=0)print("'torch.cat(inputs, dim=0).shape':",y1.shape)print("'torch.cat(inputs, dim=0)':\n",y1)y2 = torch.cat(inputs, dim=1)print("'torch.cat(inputs, dim=1).shape':",y2.shape)print("'torch.cat(inputs, dim=1)':\n",y2)y3 = torch.cat(inputs, dim=2).shape"""'torch.cat(inputs, dim=0).shape': torch.Size([4, 3])'torch.cat(inputs, dim=0)':tensor([[11, 12, 13],[21, 22, 23],[11, 12, 13],[21, 22, 23]], dtype=torch.int32)'torch.cat(inputs, dim=1).shape': torch.Size([2, 6])'torch.cat(inputs, dim=1)':tensor([[11, 12, 13, 11, 12, 13],[21, 22, 23, 21, 22, 23]], dtype=torch.int32)Traceback (most recent call last):File "test.py", line 30, in <module>y3 = torch.cat(inputs, dim=2).shapeIndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)"""
完整代码如下:
import torchx1 = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int)print("'x1.shape':", x1.shape)x2 = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int)print("'x2.shape':", x2.shape)inputs = [x1, x2]print("inputs:\n", inputs)y1 = torch.cat(inputs, dim=0)print("'torch.cat(inputs, dim=0).shape':",y1.shape)print("'torch.cat(inputs, dim=0)':\n",y1)y2 = torch.cat(inputs, dim=1)print("'torch.cat(inputs, dim=1).shape':",y2.shape)print("'torch.cat(inputs, dim=1)':\n",y2)y3 = torch.cat(inputs, dim=2).shape
总结
通常用来,把 torch.stack 得到 tensor 进行拼接而存在的。
2. torch.stack()
实际使用中,**torch.stack()** 和 **torch.cat()** 互相辅助。
函数的意义:使用 torch.stack() 可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。
形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。
该函数常出现在自然语言处理(**NLP**)和图像卷积神经网络(**CV**)中。
torch.stack()
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个 2 维的张量凑成一个 3 维的张量;多个 3 维的凑成一个4 维的张量 … 以此类推,也就是在增加新的维度进行堆叠。
outputs = torch.stack(inputs, dim=?) → Tensor
参数
**inputs**: 待连接的张量序列。
注:python 的序列数据只有 list 和 tuple 。
**dim**: 新的维度, 必须在0到len(outputs)之间。
注:len(outputs) 是生成数据的维度大小,也就是 outputs 的维度值。
重点
函数中的输入 inputs 只允许是序列;且序列内部的张量元素,必须 shape 相等
举例:[tensor_1, tensor_2,..] 或者 (tensor_1, tensor_2,..),且必须 tensor_1.shape == tensor_2.shape 。dim 是选择生成的维度,必须满足 0<=dim<len(outputs);len(outputs) 是输出后的 tensor 的维度大小。
例子
- 准备
2个tensor数据,每个的shape都是[3,3]。
import torch# 假设是时间步T1的输出T1 = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])print("'T1.shape':", T1.shape)# 假设是时间步T2的输出T2 = torch.tensor([[10, 20, 30],[40, 50, 60],[70, 80, 90]])print("'T2.shape':", T2.shape)"""'T1.shape': torch.Size([3, 3])'T2.shape': torch.Size([3, 3])"""
- 测试
torch.stack()函数。
Y0 = torch.stack((T1, T2), dim=0)print("'torch.stack((T1, T2), dim=0).shape':", Y0.shape)print("'torch.stack((T1, T2), dim=0)':\n", Y0)Y1 = torch.stack((T1, T2), dim=1)print("'torch.stack((T1, T2), dim=1).shape':", Y1.shape)print("'torch.stack((T1, T2), dim=1)':\n", Y1)Y2 = torch.stack((T1, T2), dim=2)print("'torch.stack((T1, T2), dim=2).shape':", Y2.shape)print("'torch.stack((T1, T2), dim=2)':\n", Y2)Y3 = torch.stack((T1, T2), dim=3) # 选择的 dim > len(outputs),所以报错"""'torch.stack((T1, T2), dim=0).shape': torch.Size([2, 3, 3])'torch.stack((T1, T2), dim=0)':tensor([[[ 1, 2, 3],[ 4, 5, 6],[ 7, 8, 9]],[[10, 20, 30],[40, 50, 60],[70, 80, 90]]])'torch.stack((T1, T2), dim=1).shape': torch.Size([3, 2, 3])'torch.stack((T1, T2), dim=1)':tensor([[[ 1, 2, 3],[10, 20, 30]],[[ 4, 5, 6],[40, 50, 60]],[[ 7, 8, 9],[70, 80, 90]]])'torch.stack((T1, T2), dim=2).shape': torch.Size([3, 3, 2])'torch.stack((T1, T2), dim=2)':tensor([[[ 1, 10],[ 2, 20],[ 3, 30]],[[ 4, 40],[ 5, 50],[ 6, 60]],[[ 7, 70],[ 8, 80],[ 9, 90]]])Traceback (most recent call last):File "test.py", line 34, in <module>Y3 = torch.stack((T1, T2), dim=3)IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)"""
可以复制代码运行试试:拼接后的 tensor 形状,会根据不同的 dim 发生变化。
| dim | shape |
|---|---|
0 |
[2, 3, 3] |
1 |
[3, 2, 3] |
2 |
[3, 3, 2] |
3 |
溢出报错 |
总结
函数作用:
函数 torch.stack() 对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留 [序列(先后)信息] 和 [张量的矩阵信息] 才会使用 torch.stack() 。
手写过 RNN 的同学,知道在循环神经网络中输出数据是:一个 **list** ,该列表插入了 **seq_len** 个形状是 **[batch_size, output_size]** 的 **tensor** ,不利于计算,需要使用 torch.stack 进行拼接,保留–[1.**seq_len** 这个时间步]和–[2.张量属性 **[batch_size, output_size]**] 。
