Code

Network Architecture

  1. ### Train a DCGAN network on the MNIST dataset with the Discriminator
  2. ### and Generator imported from model.py
  3. import torch
  4. import torch.nn as nn
  5. # As in the DCGAN paper, no BatchNorm is used in the first layer of the discriminator
  6. class Discriminator(nn.Module):
  7. def __init__(self,channels_img, features_d):
  8. super(Discriminator,self).__init__()
  9. self.disc = nn.Sequential(
  10. nn.Conv2d(channels_img, features_d, kernel_size=4,stride=2,padding=1),
  11. nn.LeakyReLU(0.2),
  12. self._block(features_d, features_d*2, 4, 2, 1),
  13. self._block(features_d*2, features_d*4, 4, 2, 1),
  14. self._block(features_d*4, features_d*8, 4, 2, 1),
  15. nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),
  16. nn.Sigmoid(),
  17. )
  18. def _block(self, in_channels, out_channels, kernel_size, stride, padding):
  19. return nn.Sequential(
  20. nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
  21. nn.BatchNorm2d(out_channels),
  22. nn.LeakyReLU(0.2),
  23. )
  24. def forward(self,x):
  25. return self.disc(x)
  26. class Generator(nn.Module):
  27. def __init__(self, z_dim, channels_img, features_g):
  28. super(Generator,self).__init__()
  29. self.gen = nn.Sequential(
  30. # Input: N x z_dim x 1 x 1
  31. self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
  32. self._block(features_g*16, features_g*8, 4, 2 ,1),
  33. self._block(features_g*8, features_g*4, 4, 2, 1),
  34. self._block(features_g*4, features_g*2, 4, 2, 1),
  35. nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1),
  36. nn.Tanh(), #[-1,1]
  37. )
  38. def _block(self, in_channels, out_channels, kernel_size, stride, padding):
  39. return nn.Sequential(
  40. nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
  41. nn.BatchNorm2d(out_channels),
  42. nn.ReLU(),
  43. )
  44. def forward(self,x):
  45. return self.gen(x)
  46. def initialize_weights(model):
  47. for m in model.modules():
  48. if isinstance(m,(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
  49. nn.init.normal_(m.weight.data,0.0,0.02)
  50. def test():
  51. N, in_channels, H, W = 8, 3, 64, 64
  52. z_dim = 100
  53. x = torch.randn((N,in_channels, H, W))
  54. disc = Discriminator(in_channels, 8)
  55. initialize_weights(disc)
  56. assert disc(x).shape == (N, 1, 1, 1)
  57. gen = Generator(z_dim, in_channels, 8)
  58. initialize_weights(gen)
  59. z = torch.randn((N, z_dim, 1, 1))
  60. assert gen(z).shape == (N, in_channels, H, W)
  61. print("Success!")

Training

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. import torchvision.datasets as datasets
  6. import torchvision.transforms as transforms
  7. from torch.utils.data import DataLoader
  8. from torch.utils.tensorboard import SummaryWriter
  9. from model import Discriminator, Generator, initialize_weights
  10. ### Hyper parameters etc.
  11. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  12. print(torch.cuda.is_available())
  13. Learning_rate = 2e-4
  14. Batch_size = 128
  15. Image_size = 64
  16. CHANNELS_IMG = 1
  17. Noise_dim = 100
  18. Num_epochs = 5
  19. Features_disc = 64
  20. Features_gen = 64
  21. ### Image pre-processing
  22. transformx = transforms.Compose(
  23. [
  24. transforms.Resize(Image_size),
  25. transforms.ToTensor(),
  26. transforms.Normalize(
  27. [0.5 for _ in range(CHANNELS_IMG)],[0.5 for _ in range(CHANNELS_IMG)]
  28. ),
  29. ]
  30. )
  31. ### Load dataset
  32. dataset = datasets.MNIST(root="./data",train=True,transform=transformx,download=False)
  33. loader = DataLoader(dataset, batch_size=Batch_size, shuffle=True)
  34. ### Build the model
  35. gen = Generator(Noise_dim, CHANNELS_IMG, Features_gen).to(device)
  36. disc = Discriminator(CHANNELS_IMG,Features_disc).to(device)
  37. # initialize model parameter
  38. initialize_weights(gen)
  39. initialize_weights(disc)
  40. opt_gen = optim.Adam(gen.parameters(),lr=Learning_rate, betas=(0.5, 0.999))
  41. opt_disc = optim.Adam(disc.parameters(),lr=Learning_rate,betas=(0.5, 0.999))
  42. # loss function to calculate loss = y_n*log(x_n) + (1-y_n)*log(1-x_n)
  43. criterion = nn.BCELoss()
  44. # the fixed_noise is used to validate the model
  45. fixed_noise = torch.randn(32,Noise_dim,1,1).to(device)
  46. step = 0 # need a step for printing to tensorboard
  47. writer_real = SummaryWriter(f"logs/real")
  48. writer_fake = SummaryWriter(f"logs/fake")
  49. gen.train()
  50. disc.train()
  51. ## start training
  52. for epoch in range(Num_epochs):
  53. for batch_idx, (real, _) in enumerate(loader):
  54. # get real image from training set
  55. real = real.to(device)
  56. # generate fake image
  57. noise = torch.randn((Batch_size,Noise_dim,1,1)).to(device)
  58. fake = gen(noise)
  59. # Train discriminator max log(D(x)) + log(1-D(G(z)))
  60. disc_real = disc(real).reshape(-1)
  61. loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
  62. disc_fake = disc(fake).reshape(-1)
  63. loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
  64. # generate discriminator loss
  65. loss_disc = (loss_disc_real + loss_disc_fake)/2
  66. disc.zero_grad()
  67. loss_disc.backward(retain_graph=True)
  68. opt_disc.step()
  69. ### Train the Generator min log(1-D(G(z)) <--> max log(D(G(z)))
  70. output = disc(fake).reshape(-1)
  71. loss_gen = criterion(output,torch.ones_like(output))
  72. gen.zero_grad()
  73. loss_gen.backward()
  74. opt_gen.step()
  75. # Print losses occasionally and print to tensorboard
  76. if batch_idx % 100 == 0:
  77. print(
  78. f"Epoch [{epoch}/{Num_epochs}] Batch {batch_idx}/{len(loader)} \
  79. Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
  80. )
  81. with torch.no_grad():
  82. fake = gen(fixed_noise)
  83. # take out (up to) 32 examples
  84. img_grid_real = torchvision.utils.make_grid(
  85. real[:32], normalize=True
  86. )
  87. img_grid_fake = torchvision.utils.make_grid(
  88. fake[:32], normalize=True
  89. )
  90. writer_real.add_image("Real", img_grid_real, global_step=step)
  91. writer_fake.add_image("Fake", img_grid_fake, global_step=step)
  92. step += 1
  93. import matplotlib.pyplot as plt
  94. plt.figure(figsize=(10,5))
  95. plt.title("Generator and Discriminator Loss During Training")
  96. plt.plot(loss_gen.data.cpu().numpy(),label="G")
  97. plt.plot(loss_disc.data.cpu().numpy(),label="D")
  98. plt.xlabel("iterations")
  99. plt.ylabel("Loss")
  100. plt.legend()
  101. plt.show()