Base
All of the code is tested against the official MNIST example and extended from it.
```python
import os

from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning import LightningModule, Trainer


class LitMNIST(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # A single linear layer mapping flattened 28x28 images to 10 classes.
        self.layer_1 = nn.Linear(28 * 28, 10)
        self.params = hparams

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        x = x.view(batch_size, -1)  # flatten (B, 1, 28, 28) -> (B, 784)
        x = self.layer_1(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)  # NLL loss on log-probabilities
        return loss

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(),
                         lr=self.params['lr'],
                         weight_decay=self.params['weight_decay'])
        return optimizer


hparams = {'lr': 0.001, 'weight_decay': 0.0001}

# Standard MNIST normalization (mean 0.1307, std 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train = DataLoader(mnist_train, batch_size=256, num_workers=8)

model = LitMNIST(hparams)
trainer = Trainer()
trainer.fit(model, mnist_train)
```
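
Before launching a full training run, it can help to sanity-check the model's input/output contract. The following is a minimal sketch that is not part of the original example; the random tensor shapes simply follow MNIST's 1×28×28 images:

```python
import torch

# Hypothetical smoke test (not in the original example): run the forward
# pass on a random MNIST-shaped batch and check the output shape.
model = LitMNIST({'lr': 0.001, 'weight_decay': 0.0001})
fake_batch = torch.randn(4, 1, 28, 28)  # (batch, channels, height, width)
log_probs = model(fake_batch)
print(log_probs.shape)  # expected: torch.Size([4, 10])
```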
