unsqueeze()
import torch
a = torch.arange(0,6)
print(a)
print(a.view(1,6))
print(a.view(6,1))
print(a.unsqueeze(0))
print(a.unsqueeze(1))
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
可以看到,原来就是一个一维的数组,我们可以shape=[6],然后unsqueeze(0),就是在第0位纬度加上1,shape=[1,6];unsqueeze(1),就是在第1位纬度加上1,shape=[6,1]
squeeze()
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(0).squeeze(0))
print(a.unsqueeze(1).squeeze(0))
print(a.unsqueeze(0).squeeze(0).shape)
print(a.unsqueeze(1).squeeze(0).shape)
squeeze()是去掉一个维度,并且去掉的这个维度的数值必须是1,否则无效
torch.squeeze()
压缩矩阵,我理解为降维
a.squeeze(i) 压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩
import torch
a=torch.randn(1,3,4)
print(a)
b=a.squeeze(0)
print(b)
c=a.squeeze(1)
print(c
输出:
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
一页三行4列的矩阵
第0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵
tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]])
第1维不为1,则不可以压缩
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
torch.unsqueeze()
unsqueeze(i) 表示将第i维设置为1
对压缩为3行4列后的矩阵b进行操作,将第0维设置为1
c=b.unsqueeze(0)
print(c)
输出一个一页三行四列的矩阵
tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
[ 1.2210, -0.1084, -0.1166, -0.2379],
[-1.0012, -0.4363, 1.0057, -1.5180]]])
将第一维设置为1
c=b.unsqueeze(1)
print(c)
输出一个3页,一行,4列的矩阵
tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
[[-2.3976, 0.9857, -0.3462, -0.3648]],
[[ 1.1012, -0.4659, -0.0858, 1.6631]]])
另外,squeeze、unsqueeze操作不改变原矩阵