参考来源:
博客园:PyTorch笔记之 squeeze() 和 unsqueeze()

1. squeeze() 函数

squeeze() 用来去掉向量的一个维度,只有维度为 **1** 的那一维才能去掉。
Example
初始化 1 个向量 shape(1,2,3) 的向量

  1. import torch
  2. a = torch.rand((1,2,3))

squeeze() 和 unsqueeze() - 图1
去掉第 0 维,第 0 维的大小是 1 ,所以可以去掉第 0 维,去掉后向量的 shape(2,3)
squeeze() 和 unsqueeze() - 图2
去掉最后一维,最后一维的大小是 3 ,所以不会操作成功,向量的 shape 仍然是 (1,2,3)
squeeze() 和 unsqueeze() - 图3

2. unsqueeze() 函数

从函数名字就可以看出,unsqueeze()squeeze() 的功能是相反的,squeeze 是去掉 1 维,那 unsqueeze() 就是增加 1 维。
Example
增加第 4 维,此时向量的 shape(1,2,3,1)
squeeze() 和 unsqueeze() - 图4
插入第 0 维,我们初始化一个 shape(2,3)的向量,然后在插入第 0 维,插入后向量的 shape(1,2,3)
squeeze() 和 unsqueeze() - 图5