1.Tensor
torch.Tensor([[1,2,3],[1,2,3]])
tensor([[1.,2.,3.],
[1.,2.,3.]])
在pytorch中的tensor有点类似于numpy里面的ndarray
都是多维数组的形式。
使用torch内部的api创建tensor
torch.empty([3,4])
torch.ones([3,4])
torch.zeros([3,4])
torch.rand([3,4])
torch.randint(low=0,high=10,size=[3,4])
tensor的常用方法
1.获取tensor中的元素
当tensor中只有一个元素时,可以用item()来获取。
a = torch.Tensor([[[1]]])
a.item()
2.转化为Numpy数组
t1 = torch.randint(low=0,high=10,size=[3,4])
t1.numpy()
3.获取tensor形状
tensor.size()
4.形状的改变
tensor.view()——类似于numpy的reshape()
经常可以看到调用torch.view(-1,28*28)之类的调用,那么这里的-1是什么意思呢,经过查看文档view()得到了一下结果:
- view()返回的数据和传入的tensor一样,只是形状不同
- -1在这里的意思是让电脑帮我们计算,比如下面的例子,总长度是20,我们不想自己算20/5=420/5=4,就可以在不想算的位置放上-1,电脑就会自己计算对应的数字,这个在实际搭建网络的时候是很好用的
- 还要注意view()返回的tensor和传入的tensor共享内存,意思就是修改其中一个,数据都会变
5.获取维数
tensor.dim()
6.最大值最小值
tensor.max()
tensor.min()
7.转置
tensor.t()
10.tensor切片