1. import torch
    2. from torch import nn
    class AE(nn.Module):
        """Deterministic fully-connected autoencoder for 28x28 single-channel images.

        The encoder maps a flattened image ``[b, 784]`` to a 20-dimensional code
        ``[b, 20]``; the decoder maps the code back to ``[b, 784]`` and a final
        ``Sigmoid`` keeps every reconstructed pixel in ``(0, 1)``.

        ``forward`` returns a ``(x_hat, None)`` pair so callers can unpack
        ``x_hat, kld = model(x)`` uniformly; a plain autoencoder has no KL term,
        so the second slot is always ``None``.
        """

        def __init__(self):
            super().__init__()
            # [b, 784] => [b, 20]
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 20),
                nn.ReLU(),
            )
            # [b, 20] => [b, 784]
            self.decoder = nn.Sequential(
                nn.Linear(20, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid(),
            )

        def forward(self, x):
            """Encode and reconstruct a batch of images.

            :param x: tensor of shape [b, 1, 28, 28] (any memory layout with
                784 elements per sample is accepted).
            :return: ``(x_hat, None)`` where ``x_hat`` has shape
                [b, 1, 28, 28]; ``None`` fills the KL-term slot.
            """
            batchsz = x.size(0)
            # reshape rather than view: view raises on non-contiguous inputs
            # (e.g. transposed or sliced batches), reshape copies if needed.
            x = x.reshape(batchsz, 784)
            x = self.encoder(x)
            x = self.decoder(x)
            # decoder output is freshly allocated and contiguous, so view is safe.
            return x.view(batchsz, 1, 28, 28), None
    1. import torch
    2. from torch.utils.data import DataLoader
    3. from torch import nn, optim
    4. from torchvision import transforms, datasets
    5. from ae import AE
    6. from vae import VAE
    7. import visdom
    def main():
        """Train an MNIST (variational) autoencoder and visualise reconstructions.

        Loads the MNIST train/test splits (downloading into ./mnist if
        needed), trains the selected model with Adam (lr=1e-3) for 1000 epochs,
        prints the last batch's loss (and KL term, if the model returns one)
        after each epoch, and pushes one test batch plus its reconstruction to
        a visdom server.
        """
        train_set = datasets.MNIST('mnist', True, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_train = DataLoader(train_set, batch_size=32, shuffle=True)
        test_set = datasets.MNIST('mnist', False, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_test = DataLoader(test_set, batch_size=32, shuffle=True)

        # Use the builtin next(): DataLoader iterators are Python 3 iterators
        # and recent PyTorch versions no longer provide a .next() method.
        x, _ = next(iter(mnist_train))
        print('x:', x.shape)

        # Fall back to CPU so the script also runs on machines without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Swap the two lines below to train the plain autoencoder instead.
        # model = AE().to(device)
        model = VAE().to(device)
        criteon = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        print(model)

        viz = visdom.Visdom()

        for epoch in range(1000):
            for x, _ in mnist_train:
                # [b, 1, 28, 28]
                x = x.to(device)
                x_hat, kld = model(x)
                loss = criteon(x_hat, x)
                if kld is not None:
                    # Maximise ELBO = -recon - 1.0 * kld, i.e. minimise
                    # recon + 1.0 * kld.
                    elbo = - loss - 1.0 * kld
                    loss = - elbo
                # backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The model may return kld=None (e.g. the plain AE); only call
            # .item() when a KL term actually exists.
            kld_value = kld.item() if kld is not None else None
            print(epoch, 'loss:', loss.item(), 'kld:', kld_value)

            x, _ = next(iter(mnist_test))
            x = x.to(device)
            with torch.no_grad():
                x_hat, kld = model(x)
            # Hand visdom host-memory tensors; harmless if it would already
            # move them itself.
            viz.images(x.cpu(), nrow=8, win='x', opts=dict(title='x'))
            viz.images(x_hat.cpu(), nrow=8, win='x_hat', opts=dict(title='x_hat'))


    if __name__ == '__main__':
        main()
    1. import torch
    2. from torch import nn
    class VAE(nn.Module):
        """Fully-connected variational autoencoder for 28x28 single-channel images.

        The encoder produces 20 values per sample, split into a 10-d mean ``mu``
        and a 10-d scale ``sigma`` of a diagonal Gaussian posterior. A latent
        ``h = mu + sigma * eps`` with ``eps ~ N(0, 1)`` is decoded back to
        ``[b, 784]`` and squashed into ``(0, 1)`` by a final ``Sigmoid``.
        """

        def __init__(self):
            super().__init__()
            # [b, 784] => [b, 20]; columns 0..9 are mu, columns 10..19 are sigma.
            # The last layer is deliberately linear. A trailing ReLU would
            # force mu >= 0, so the posterior could never centre on the
            # negative half of the N(0, 1) prior. It would also clamp sigma to
            # exactly 0 for every negative pre-activation. ReLU has no
            # parameters, so dropping it leaves state_dict keys unchanged.
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 20),
            )
            # [b, 10] => [b, 784]
            self.decoder = nn.Sequential(
                nn.Linear(10, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid(),
            )
            # NOTE(review): not used by forward(); kept so any external code
            # reading model.criteon keeps working.
            self.criteon = nn.MSELoss()

        def forward(self, x):
            """Encode, sample a latent with the reparameterization trick, decode.

            :param x: tensor of shape [b, 1, 28, 28] (any memory layout with
                784 elements per sample is accepted).
            :return: ``(x_hat, kld)``. ``x_hat`` has shape [b, 1, 28, 28].
                ``kld`` is a scalar tensor: KL(N(mu, sigma^2) || N(0, 1)),
                summed over batch and latent dims and then divided by
                ``b * 28 * 28``. The division presumably puts it on the
                per-pixel scale of a mean-reduced reconstruction loss.
            """
            batchsz = x.size(0)
            # reshape rather than view: view raises on non-contiguous inputs.
            x = x.reshape(batchsz, 784)
            # [b, 20] => [b, 10] mean and [b, 10] scale
            mu, sigma = self.encoder(x).chunk(2, dim=1)
            # Reparameterization trick, eps ~ N(0, 1): the randomness lives in
            # eps so gradients flow through mu and sigma. Only sigma**2 and
            # sigma * eps (eps symmetric) are used, so sigma's sign is irrelevant.
            h = mu + sigma * torch.randn_like(sigma)
            # decoder output is contiguous, so view is safe here.
            x_hat = self.decoder(h).view(batchsz, 1, 28, 28)
            sigma_sq = torch.pow(sigma, 2)
            # 1e-8 guards against log(0) when sigma collapses to zero.
            kld = 0.5 * torch.sum(
                torch.pow(mu, 2) + sigma_sq - torch.log(1e-8 + sigma_sq) - 1
            ) / (batchsz * 28 * 28)
            return x_hat, kld