MLP Model

B16692_02_2.jpg

输入层 中间层 输出层
XOR输入 XOR输出

构建模型

XOR模型

  1. 导入库
  2. 准备数据
  3. 配置模型
  4. 训练模型
  5. 加载模型
  6. 做出预测

导入库

  1. import pytorch_lightning as pl
  2. import torch
  3. from torch import nn, optim
  4. from torch.autograd import Variable
  5. import pytorch_lightning as pl
  6. from pytorch_lightning.callbacks import ModelCheckpoint
  7. from torch.utils.data import DataLoader
  8. print("torch version:",torch.__version__)
  9. print("pytorch ligthening version:",pl.__version__)

准备数据

  1. # 对应异或4种输入 A B为两个输入(特征)
  2. xor_input = [Variable(torch.Tensor([0, 0])),
  3. Variable(torch.Tensor([0, 1])),
  4. Variable(torch.Tensor([1, 0])),
  5. Variable(torch.Tensor([1, 1]))]
  6. # 对应目标变量
  7. xor_target = [Variable(torch.Tensor([0])),
  8. Variable(torch.Tensor([1])),
  9. Variable(torch.Tensor([1])),
  10. Variable(torch.Tensor([0]))]
  11. # 创造数据加载器 我们可以通过多种方式创建数据集并将其作为数据加载器传递给 PyTorch Lightning
  12. xor_data = list(zip(xor_input, xor_target))
  13. train_loader = DataLoader(xor_data, batch_size=1)
  14. # [(tensor([0., 0.]), tensor([0.])),
  15. # (tensor([0., 1.]), tensor([1.])),
  16. # (tensor([1., 0.]), tensor([1.])),
  17. # (tensor([1., 1.]), tensor([0.]))]

配置模型

  1. 初始化模型
  2. 将输入映射到模型
  3. 配置优化器
  4. 设置训练参数

初始化模型

  1. class XORModel(pl.LightningModule):
  2. def __init__(self):
  3. # super().__init__()
  4. super(XORModel,self).__init__()
  5. self.input_layer = nn.Linear(2, 4)
  6. self.output_layer = nn.Linear(4,1)
  7. self.sigmoid = nn.Sigmoid()
  8. # mean squared error
  9. self.loss = nn.MSELoss()

将输入映射到模型

  1. # 将输入映射到模型
  2. # 将接收的特征传递到输入层
  3. # 输入层生成结果传递到激活函数中
  4. # 激活函数生成的结果传递到输出层中
  5. def forward(self, input):
  6. #print("INPUT:", input.shape)
  7. x = self.input_layer(input)
  8. #print("FIRST:", x.shape)
  9. x = self.sigmoid(x)
  10. #print("SECOND:", x.shape)
  11. output = self.output_layer(x)
  12. #print("THIRD:", output.shape)
  13. return output

配置优化器

  1. def configure_optimizers(self):
  2. # 模型所有参数
  3. params = self.parameters()
  4. # 创建优化器
  5. optimizer = optim.Adam(params=params, lr = 0.01)
  6. return optimizer

设置训练参数

  1. # batch_idx为当前批次序号
  2. def training_step(self, batch, batch_idx):
  3. # 批量数据batch 里面包括输入/特征以及target目标
  4. xor_input, xor_target = batch
  5. # print("XOR INPUT:", xor_input.shape)
  6. # print("XOR TARGET:", xor_target.shape)
  7. # self会间接调用forward方法
  8. outputs = self(xor_input)
  9. # print("XOR OUTPUT:", outputs.shape)
  10. loss = self.loss(outputs, xor_target)
  11. return loss

训练模型

  1. # 训练模型
  2. from pytorch_lightning.utilities.types import TRAIN_DATALOADERS
  3. # 创建模型检查点回调函数 保存模型并不限于训练模型后 在训练模型中也要保存
  4. checkpoint_callback = ModelCheckpoint()
  5. model = XORModel()
  6. # Trainer是一些关键事物的抽象,例如循环数据集、反向传播、清除梯度和优化器步骤
  7. # Trainer类支持许多帮助构建模型的功能
  8. # 其中一些功能是各种回调、模型检查点、提前停止、开发运行单元测试、对GPU和TPU、记录器、日志、时期等的支持
  9. # 创建训练器
  10. trainer = pl.Trainer(max_epochs=100, callbacks=[checkpoint_callback])
  11. # 传递模型和训练数据 开始训练
  12. trainer.fit(model, train_dataloaders=train_loader)

image.png
image.png
image.png

加载模型

  1. # 加载模型
  2. # 获取最新版本模型路径
  3. print(checkpoint_callback.best_model_path)
  4. # 从检测点中加载模型
  5. train_model = model.load_from_checkpoint(checkpoint_callback.best_model_path)

做出预测

  1. # 获取最新版本模型路径
  2. print(checkpoint_callback.best_model_path)
  3. # 从检测点中加载模型
  4. train_model = model.load_from_checkpoint(checkpoint_callback.best_model_path)
  5. test = torch.utils.data.DataLoader(xor_input, batch_size=1)
  6. for val in xor_input:
  7. _ = train_model(val)
  8. print([int(val[0]),int(val[1])], int(_.round()))
  9. from torchmetrics.functional import accuracy
  10. print(checkpoint_callback.best_model_path)
  11. train_model = model.load_from_checkpoint(checkpoint_callback.best_model_path)
  12. total_accuracy = []
  13. for xor_input, xor_target in train_loader:
  14. for i in range(100):
  15. output_tensor = train_model(xor_input)
  16. test_accuracy = accuracy(output_tensor, xor_target.int())
  17. total_accuracy.append(test_accuracy)
  18. total_accuracy = torch.mean(torch.stack(total_accuracy))
  19. print("TOTAL ACCURACY FOR 100 ITERATIONS: ", total_accuracy.item())