这个东西的主要作用,就是增加一个维度。
现在我们假设有一个数组A,数组A是一个两行三列的矩阵。大小我们记成(2, 3)。
图1. 矩阵
如果我们设置axis = 0,即 np.expand_dims(A, 0)
,我们的矩阵就从(2, 3)->(1, 2, 3);如果,我们设置axis = 1,即np.expand_dims(A, 1)
,我们的矩阵就从(2, 3)->(2, 1, 3);如果,我们设置axis = 2,即np.expand_dims(A, 2)
,我们的矩阵就从(2, 3)->(2, 3, 1)。
从上面来看,我们的数据就发生了一个增维的过程,那么这个数据增维会产生什么变化?
通过上面一篇笔记,我们知道,Python中是通过[]套娃的个数来表示数据的维度的,比如我们的一维表示:1, 2然后二维表示:[1, 2, 3], [4, 5, 6],扩展到三维:[[[1, 2, 3], [4, 5, 6]], [7, 8, 9], [1, 1, 1]]。
所以,通俗点说,我们就相当于通过np.expand_dims(A, 0)
来给A的维度加[],然后在shape中加一个×。