参考来源:
博客园:PyTorch笔记之 squeeze() 和 unsqueeze()
1. squeeze() 函数
squeeze() 用来去掉向量的一个维度,只有维度为 **1** 的那一维才能去掉。
Example:
初始化 1 个向量 shape 为 (1,2,3) 的向量
import torcha = torch.rand((1,2,3))

去掉第 0 维,第 0 维的大小是 1 ,所以可以去掉第 0 维,去掉后向量的 shape 是(2,3) 。

去掉最后一维,最后一维的大小是 3 ,所以不会操作成功,向量的 shape 仍然是 (1,2,3)
。
2. unsqueeze() 函数
从函数名字就可以看出,unsqueeze() 和 squeeze() 的功能是相反的,squeeze 是去掉 1 维,那 unsqueeze() 就是增加 1 维。
Example:
增加第 4 维,此时向量的 shape 是(1,2,3,1) 。
插入第 0 维,我们初始化一个 shape 为(2,3)的向量,然后在插入第 0 维,插入后向量的 shape 为(1,2,3) 。
