Base

所有代码均以官方的 MNIST 示例进行测试,并在其基础上进行扩展。

  1. from torch.nn import functional as F
  2. from torch import nn
  3. from torch.utils.data import DataLoader
  4. from torchvision.datasets import MNIST
  5. import os
  6. from torchvision import transforms
  7. from torch.optim import Adam
  8. from pytorch_lightning import Trainer
  9. from pytorch_lightning import LightningModule
  10. class LitMNIST(LightningModule):
  11. def __init__(self, hparams):
  12. super().__init__()
  13. self.layer_1 = nn.Linear(28 * 28, 10)
  14. self.params = hparams
  15. def forward(self, x):
  16. batch_size, channels, height, width = x.size()
  17. x = x.view(batch_size, -1)
  18. x = self.layer_1(x)
  19. x = F.log_softmax(x, dim=1)
  20. return x
  21. def training_step(self, batch, batch_idx):
  22. x, y = batch
  23. logits = self(x)
  24. loss = F.nll_loss(logits, y)
  25. return loss
  26. def configure_optimizers(self):
  27. optimizer = Adam(self.parameters(), lr=self.params['lr'], weight_decay= self.params['weight_decay'])
  28. return optimizer
  29. hparams={
  30. 'lr':0.001,
  31. 'weight_decay':0.0001
  32. }
  33. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  34. mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
  35. mnist_train = DataLoader(mnist_train, batch_size=256,num_workers=8)
  36. model = LitMNIST(hparams)
  37. trainer = Trainer()
  38. trainer.fit(model, mnist_train)