1.Tensor

  1. torch.Tensor([[1,2,3],[1,2,3]])
  2. tensor([[1.,2.,3.],
  3. [1.,2.,3.]])

在pytorch中的tensor有点类似于numpy里面的ndarray
都是多维数组的形式。

使用torch内部的api创建tensor

  1. torch.empty([3,4])
  2. torch.ones([3,4])
  3. torch.zeros([3,4])
  4. torch.rand([3,4])
  5. torch.randint(low=0,high=10,size=[3,4])

image.png

tensor的常用方法

1.获取tensor中的元素
当tensor中只有一个元素时,可以用item()来获取。

  1. a = torch.Tensor([[[1]]])
  2. a.item()

image.png

2.转化为Numpy数组

  1. t1 = torch.randint(low=0,high=10,size=[3,4])
  2. t1.numpy()

image.png
image.png
3.获取tensor形状
tensor.size()

image.png
4.形状的改变
tensor.view()——类似于numpy的reshape()

image.png
经常可以看到调用torch.view(-1,28*28)之类的调用,那么这里的-1是什么意思呢,经过查看文档view()得到了一下结果:

  • view()返回的数据和传入的tensor一样,只是形状不同
  • -1在这里的意思是让电脑帮我们计算,比如下面的例子,总长度是20,我们不想自己算20/5=420/5=4,就可以在不想算的位置放上-1,电脑就会自己计算对应的数字,这个在实际搭建网络的时候是很好用的
  • 还要注意view()返回的tensor和传入的tensor共享内存,意思就是修改其中一个,数据都会变

5.获取维数
tensor.dim()

6.最大值最小值
tensor.max()
tensor.min()

7.转置
tensor.t()

10.tensor切片