Code — DCGAN on MNIST
Network architecture (model.py)
### Train of DCGAN network on MNIST dataset with Discriminator
### and Generator imported from models.py
import torch
import torch.nn as nn
# Per the DCGAN paper, batch norm is not applied to the first layer of the discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image batch to a per-image realness score in (0, 1).

    For 64x64 inputs the output has shape (N, 1, 1, 1).
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        # First conv deliberately has no batch norm (see note above the class).
        layers = [
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three conv-bn-lrelu stages doubling the channel count each time.
        for mult in (1, 2, 4):
            layers.append(self._block(features_d * mult, features_d * mult * 2, 4, 2, 1))
        layers += [
            # Collapse the remaining 4x4 map to a single logit, then squash to (0, 1).
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv -> BatchNorm -> LeakyReLU(0.2) building block."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run the image batch through the convolutional stack."""
        return self.disc(x)
class Generator(nn.Module):
    """DCGAN generator: maps latent noise (N, z_dim, 1, 1) to images in [-1, 1].

    For the default stack the output is a 64x64 image batch.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        # Input: N x z_dim x 1 x 1 -> N x features_g*16 x 4 x 4
        blocks = [self._block(z_dim, features_g * 16, 4, 1, 0)]
        # Repeatedly halve channels while doubling spatial size, down to features_g*2.
        channels = features_g * 16
        while channels > features_g * 2:
            blocks.append(self._block(channels, channels // 2, 4, 2, 1))
            channels //= 2
        blocks += [
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),  # squash pixel values into [-1, 1]
        ]
        self.gen = nn.Sequential(*blocks)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose -> BatchNorm -> ReLU building block."""
        deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Run the latent batch through the transposed-conv stack."""
        return self.gen(x)
def initialize_weights(model):
    """Apply DCGAN weight initialization in place (Radford et al., 2016).

    Conv / transposed-conv weights are drawn from N(0, 0.02). BatchNorm scale
    (gamma) is drawn from N(1.0, 0.02) and bias (beta) set to 0 — the original
    code drew BatchNorm scales from N(0, 0.02), which starts every
    normalization layer near-dead (outputs scaled by ~0).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)
def test():
    """Smoke-test the DCGAN shapes: D maps images to (N, 1, 1, 1) and G maps
    latent noise back to full-size images."""
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    images = torch.randn((batch, channels, height, width))
    disc = Discriminator(channels, 8)
    initialize_weights(disc)
    assert disc(images).shape == (batch, 1, 1, 1)

    gen = Generator(latent_dim, channels, 8)
    initialize_weights(gen)
    latent = torch.randn((batch, latent_dim, 1, 1))
    assert gen(latent).shape == (batch, channels, height, width)

    print("Success!")
Training script (train.py)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
### Hyper parameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
# DCGAN paper settings: Adam lr=2e-4, 64x64 images, 100-dim latent vector.
Learning_rate = 2e-4
Batch_size = 128
Image_size = 64
CHANNELS_IMG = 1  # MNIST is single-channel
Noise_dim = 100
Num_epochs = 5
Features_disc = 64
Features_gen = 64

### Image pre-processing: resize to 64x64 and scale pixels to [-1, 1]
### to match the generator's Tanh output range.
transformx = transforms.Compose(
    [
        transforms.Resize(Image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

### Load dataset. download=True fetches MNIST on first run instead of
### crashing when ./data does not exist yet (a no-op once cached).
dataset = datasets.MNIST(root="./data", train=True, transform=transformx, download=True)
loader = DataLoader(dataset, batch_size=Batch_size, shuffle=True)

### Build the models and apply DCGAN initialization
gen = Generator(Noise_dim, CHANNELS_IMG, Features_gen).to(device)
disc = Discriminator(CHANNELS_IMG, Features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)

# betas=(0.5, 0.999) per the DCGAN paper
opt_gen = optim.Adam(gen.parameters(), lr=Learning_rate, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=Learning_rate, betas=(0.5, 0.999))
# BCE: loss = -[y*log(x) + (1-y)*log(1-x)]
criterion = nn.BCELoss()

# Fixed noise reused at every logging step so TensorBoard shows the
# generator's progress on the same latent samples over time.
fixed_noise = torch.randn(32, Noise_dim, 1, 1).to(device)
step = 0  # global step counter for TensorBoard
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
gen.train()
disc.train()
## start training
for epoch in range(Num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # get real image batch from the training set
        real = real.to(device)
        # Size the noise batch to the real batch: the loader's last batch can
        # be smaller than Batch_size, and a hard-coded 128 there would train D
        # on mismatched real/fake batch counts.
        cur_batch = real.size(0)
        noise = torch.randn((cur_batch, Noise_dim, 1, 1)).to(device)
        fake = gen(noise)

        # Train discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        # retain_graph=True: `fake`'s graph is reused below for the generator step
        loss_disc.backward(retain_graph=True)
        opt_disc.step()

        ### Train the generator: min log(1-D(G(z))) <--> max log(D(G(z)))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log image grids to TensorBoard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{Num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )
            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1
# Plot the training losses.
# NOTE(review): loss_gen / loss_disc here are the scalar losses from the
# *last* training batch only, so each plt.plot call draws a single point,
# not a curve — the figure will look empty despite its title. To get real
# curves, append loss_gen.item() / loss_disc.item() to lists inside the
# training loop and plot those lists instead.
import matplotlib.pyplot as plt
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(loss_gen.data.cpu().numpy(),label="G")
plt.plot(loss_disc.data.cpu().numpy(),label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()