unsqueeze()

  1. import torch
  2. a = torch.arange(0,6)
  3. print(a)
  4. print(a.view(1,6))
  5. print(a.view(6,1))
  6. print(a.unsqueeze(0))
  7. print(a.unsqueeze(1))
  8. print(a.unsqueeze(0).shape)
  9. print(a.unsqueeze(1).shape)

pytorch中squeeze()和unsqueeze()函数 - 图1
可以看到,原来就是一个一维的数组,我们可以shape=[6],然后unsqueeze(0),就是在第0位纬度加上1,shape=[1,6];unsqueeze(1),就是在第1位纬度加上1,shape=[6,1]

squeeze()

  1. print(a.unsqueeze(0).shape)
  2. print(a.unsqueeze(1).shape)
  3. print(a.unsqueeze(0).squeeze(0))
  4. print(a.unsqueeze(1).squeeze(0))
  5. print(a.unsqueeze(0).squeeze(0).shape)
  6. print(a.unsqueeze(1).squeeze(0).shape)

pytorch中squeeze()和unsqueeze()函数 - 图2
squeeze()是去掉一个维度,并且去掉的这个维度的数值必须是1,否则无效

torch.squeeze()

压缩矩阵,我理解为降维
a.squeeze(i) 压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩

  1. import torch
  2. a=torch.randn(1,3,4)
  3. print(a)
  4. b=a.squeeze(0)
  5. print(b)
  6. c=a.squeeze(1)
  7. print(c
  8. 输出:
  9. tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
  10. [-0.0080, 0.1794, 1.1898, -1.2525],
  11. [ 0.8281, -0.8166, 1.8846, 0.9008]]])
  12. 一页三行4列的矩阵
  13. 0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵
  14. tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
  15. [-0.0080, 0.1794, 1.1898, -1.2525],
  16. [ 0.8281, -0.8166, 1.8846, 0.9008]])
  17. 1维不为1,则不可以压缩
  18. tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
  19. [-0.0080, 0.1794, 1.1898, -1.2525],
  20. [ 0.8281, -0.8166, 1.8846, 0.9008]]])

torch.unsqueeze()

unsqueeze(i) 表示将第i维设置为1
对压缩为3行4列后的矩阵b进行操作,将第0维设置为1

  1. c=b.unsqueeze(0)
  2. print(c)
  3. 输出一个一页三行四列的矩阵
  4. tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
  5. [ 1.2210, -0.1084, -0.1166, -0.2379],
  6. [-1.0012, -0.4363, 1.0057, -1.5180]]])
  7. 将第一维设置为1
  8. c=b.unsqueeze(1)
  9. print(c)
  10. 输出一个3页,一行,4列的矩阵
  11. tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
  12. [[-2.3976, 0.9857, -0.3462, -0.3648]],
  13. [[ 1.1012, -0.4659, -0.0858, 1.6631]]])

另外,squeeze、unsqueeze操作不改变原矩阵