一、torch.nn
torch.nn.Module
提供一个可调用类,包含状态属性(如:nn 层参数),能够知道内含的Parameter类型的属性,用于做权重更新。
torch.nn.functional
提供了一些函数的实现,如:卷积、线性层。
nn.Linear(784,10) # w*x +b
torch.nn.Parameter
tensor的封装,用于Module类中。
二、torch.optim
提供梯度优化算子,如,SGD。
三、数据
from torch.utils.data import TensorDataset,DataLoader
Dataset:包含__len__ 和 __getitem__ 方法,常用实现:TensorDataset
DataLoader:加载Dataset对象,创建可迭代的 batch 数据