12.Linear

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

以vgg16网络举例,图中这一步骤就用到了线性层,in_fearure = 4096 , out_feature = 10000

如下图是一个线性神经网络

in_feature就是输入的x的个数

out_feature就是输出的个数

线性层会根据神经网络计算公式:Ax+bias得到output,在pytorch框架下使用时只需要设置好in_feature,out_feature 和 bias

torch.reshape()函数的使用

https://pytorch.org/docs/1.2.0/torch.html#torch.reshape

  1. #返回一个张量,内容和原来的张量相同,但是具有不同形状.
  2. #并且尽可能返回视图,否则才会返回拷贝,
  3. #因此,需要注意内存共享问题.
  4. #传入的参数可以有一个-1,
  5. #表示其具体值由其他维度信息和元素总个数推断出来.

torch.flatten()

https://pytorch.org/docs/stable/generated/torch.flatten.html?highlight=flatten#torch.flatten
用于将数据展平

  1. torch.flatten(input, start_dim=0, end_dim=-1)
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torch.nn import Linear
  5. from torch.utils.data import DataLoader
  6. dataset = torchvision.datasets.CIFAR10('./dataset',train=False,transform=torchvision.transforms.ToTensor(),
  7. download=True)
  8. #这里设置drop_last = True的原因是最后一个Batch_size不足64,展平后不是196608不符合设置的线性层输入 因此设置一下
  9. dataloader = DataLoader(dataset,batch_size=64,drop_last=True)
  10. class DEMO(nn.Module):
  11. def __init__(self):
  12. super(DEMO, self).__init__()
  13. self.linear1 = Linear(in_features=196608, out_features=10, bias=False)
  14. def forward(self,x):
  15. x = self.linear1(x)
  16. return x
  17. demo = DEMO()
  18. #torch.reshape()函数官方文档
  19. # 返回一个张量,内容和原来的张量相同,但是具有不同形状.
  20. # 并且尽可能返回视图,否则才会返回拷贝,
  21. # 因此,需要注意内存共享问题.
  22. # 传入的参数可以有一个-1,
  23. # 表示其具体值由其他维度信息和元素总个数推断出来.
  24. for data in dataloader:
  25. imgs,targets = data
  26. print(imgs.shape)
  27. # output = torch.reshape(imgs,(1,1,1,-1))
  28. output = torch.flatten(imgs)
  29. print(output.shape)
  30. output_linear = demo(output)
  31. print(output_linear.shape)

在实际使用时,torchvision中有很多已经搭建好的模型可以直接使用