1. torch.cat()
在 pytorch
中,常见的拼接函数主要是两个,分别是:
**torch.stack()**
**torch.cat()**
一般 torch.cat()
是为了把函数 torch.stack()
得到 tensor
进行拼接而存在的。
torch.cat()
函数目的: 在给定维度上对输入的张量序列 **seq**
进行连接操作。
outputs = torch.cat(inputs, dim=0) → Tensor
参数:
**inputs**
: 待连接的张量序列,可以是任意相同Tensor
类型的python
序列**dim**
: 选择的扩维, 必须在0
到len(inputs[0])
之间,沿着此维连接张量序列。
重点
输入数据必须是序列,序列中数据是任意相同的 **shape**
的同类型 **tensor**
。
维度不可以超过输入数据的任一个张量的维度
例子
- 准备数据,每个的
shape
都是[2,3]
。
import torch
x1 = torch.tensor([[11, 12, 13],
[21, 22, 23]], dtype=torch.int)
print("'x1.shape':", x1.shape)
x2 = torch.tensor([[11, 12, 13],
[21, 22, 23]], dtype=torch.int)
print("'x2.shape':", x2.shape)
"""
'x1.shape': torch.Size([2, 3])
'x2.shape': torch.Size([2, 3])
"""
- 合成
inputs
。
# inputs为2个形状为[2 , 3]的矩阵
inputs = [x1, x2]
print("inputs:\n",inputs)
"""
inputs:
[tensor([[11, 12, 13],
[21, 22, 23]], dtype=torch.int32),
tensor([[11, 12, 13],
[21, 22, 23]], dtype=torch.int32)]
"""
- 查看结果,测试不同的
dim
拼接结果。
y1 = torch.cat(inputs, dim=0)
print("'torch.cat(inputs, dim=0).shape':",y1.shape)
print("'torch.cat(inputs, dim=0)':\n",y1)
y2 = torch.cat(inputs, dim=1)
print("'torch.cat(inputs, dim=1).shape':",y2.shape)
print("'torch.cat(inputs, dim=1)':\n",y2)
y3 = torch.cat(inputs, dim=2).shape
"""
'torch.cat(inputs, dim=0).shape': torch.Size([4, 3])
'torch.cat(inputs, dim=0)':
tensor([[11, 12, 13],
[21, 22, 23],
[11, 12, 13],
[21, 22, 23]], dtype=torch.int32)
'torch.cat(inputs, dim=1).shape': torch.Size([2, 6])
'torch.cat(inputs, dim=1)':
tensor([[11, 12, 13, 11, 12, 13],
[21, 22, 23, 21, 22, 23]], dtype=torch.int32)
Traceback (most recent call last):
File "test.py", line 30, in <module>
y3 = torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
"""
完整代码如下:
import torch
x1 = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int)
print("'x1.shape':", x1.shape)
x2 = torch.tensor([[11, 12, 13], [21, 22, 23]], dtype=torch.int)
print("'x2.shape':", x2.shape)
inputs = [x1, x2]
print("inputs:\n", inputs)
y1 = torch.cat(inputs, dim=0)
print("'torch.cat(inputs, dim=0).shape':",y1.shape)
print("'torch.cat(inputs, dim=0)':\n",y1)
y2 = torch.cat(inputs, dim=1)
print("'torch.cat(inputs, dim=1).shape':",y2.shape)
print("'torch.cat(inputs, dim=1)':\n",y2)
y3 = torch.cat(inputs, dim=2).shape
总结
通常用来,把 torch.stack
得到 tensor
进行拼接而存在的。
2. torch.stack()
实际使用中,**torch.stack()**
和 **torch.cat()**
互相辅助。
函数的意义:使用 torch.stack()
可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。
形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。
该函数常出现在自然语言处理(**NLP**
)和图像卷积神经网络(**CV**
)中。
torch.stack()
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个 2 维的张量凑成一个 3 维的张量;多个 3 维的凑成一个4 维的张量 … 以此类推,也就是在增加新的维度进行堆叠。
outputs = torch.stack(inputs, dim=?) → Tensor
参数
**inputs**
: 待连接的张量序列。
注:python 的序列数据只有 list
和 tuple
。
**dim**
: 新的维度, 必须在0
到len(outputs)
之间。
注:len(outputs)
是生成数据的维度大小,也就是 outputs
的维度值。
重点
函数中的输入 inputs
只允许是序列;且序列内部的张量元素,必须 shape
相等
举例:[tensor_1, tensor_2,..]
或者 (tensor_1, tensor_2,..)
,且必须 tensor_1.shape == tensor_2.shape
。dim
是选择生成的维度,必须满足 0<=dim<len(outputs)
;len(outputs)
是输出后的 tensor
的维度大小。
例子
- 准备
2
个tensor
数据,每个的shape
都是[3,3]
。
import torch
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print("'T1.shape':", T1.shape)
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print("'T2.shape':", T2.shape)
"""
'T1.shape': torch.Size([3, 3])
'T2.shape': torch.Size([3, 3])
"""
- 测试
torch.stack()
函数。
Y0 = torch.stack((T1, T2), dim=0)
print("'torch.stack((T1, T2), dim=0).shape':", Y0.shape)
print("'torch.stack((T1, T2), dim=0)':\n", Y0)
Y1 = torch.stack((T1, T2), dim=1)
print("'torch.stack((T1, T2), dim=1).shape':", Y1.shape)
print("'torch.stack((T1, T2), dim=1)':\n", Y1)
Y2 = torch.stack((T1, T2), dim=2)
print("'torch.stack((T1, T2), dim=2).shape':", Y2.shape)
print("'torch.stack((T1, T2), dim=2)':\n", Y2)
Y3 = torch.stack((T1, T2), dim=3) # 选择的 dim > len(outputs),所以报错
"""
'torch.stack((T1, T2), dim=0).shape': torch.Size([2, 3, 3])
'torch.stack((T1, T2), dim=0)':
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
'torch.stack((T1, T2), dim=1).shape': torch.Size([3, 2, 3])
'torch.stack((T1, T2), dim=1)':
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
'torch.stack((T1, T2), dim=2).shape': torch.Size([3, 3, 2])
'torch.stack((T1, T2), dim=2)':
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
Traceback (most recent call last):
File "test.py", line 34, in <module>
Y3 = torch.stack((T1, T2), dim=3)
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
"""
可以复制代码运行试试:拼接后的 tensor
形状,会根据不同的 dim
发生变化。
dim | shape |
---|---|
0 |
[2, 3, 3] |
1 |
[3, 2, 3] |
2 |
[3, 3, 2] |
3 |
溢出报错 |
总结
函数作用:
函数 torch.stack()
对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留 [序列(先后)信息] 和 [张量的矩阵信息] 才会使用 torch.stack()
。
手写过 RNN
的同学,知道在循环神经网络中输出数据是:一个 **list**
,该列表插入了 **seq_len**
个形状是 **[batch_size, output_size]**
的 **tensor**
,不利于计算,需要使用 torch.stack
进行拼接,保留–[1.**seq_len**
这个时间步]和–[2.张量属性 **[batch_size, output_size]**
] 。