一、torch.nn

torch.nn.Module

  1. 提供一个可调用类,包含状态属性(如:nn 层参数),能够知道内含的Parameter类型的属性,用于做权重更新。

torch.nn.functional

提供了一些函数的实现,如:卷积、线性层。

nn.Linear(784,10) # w*x +b

torch.nn.Parameter

tensor的封装,用于Module类中。

二、torch.optim

提供梯度优化算子,如,SGD。

三、数据

from torch.utils.data import TensorDataset,DataLoader
Dataset:包含__len__ 和 __getitem__ 方法,常用实现:TensorDataset
DataLoader:加载Dataset对象,创建可迭代的 batch 数据