网络效果
网络结构
这部分主要参考:https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/CycleGAN
Generator
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + InstanceNorm + optional ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        down: If True, use ``nn.Conv2d`` with reflect padding; otherwise use
            ``nn.ConvTranspose2d`` (used by the generator's up-sampling path).
        use_act: If True, apply an in-place ReLU after normalisation;
            otherwise the activation slot is an ``nn.Identity``.
        **kwargs: Forwarded unchanged to the convolution layer
            (``kernel_size``, ``stride``, ``padding``, ``output_padding`` ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        # The Sequential layout (conv, norm, activation) is kept exactly as in
        # the original so previously saved state dicts keep loading.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s with a skip connection; shape is preserved.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            # No activation on the second conv: its output is added to the input.
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return ``x + block(x)``."""
        return x + self.block(x)


class Generator(nn.Module):
    """CycleGAN generator: 7x7 stem, 2 down blocks, residual trunk, 2 up blocks.

    The output has the same shape as the input and is squashed by ``tanh``.

    Args:
        img_channels: Number of channels of the input and output images.
        num_features: Channel width after the initial 7x7 convolution.
        num_residuals: Number of ``ResidualBlock``s in the trunk.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=6):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                img_channels,
                num_features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )
        # Each down block halves the spatial size and doubles the channels.
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),
            ]
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        )
        # Each up block doubles the spatial size (output_padding=1 makes the
        # stride-2 transposed conv produce exactly 2x) and halves the channels.
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(
                    num_features * 4,
                    num_features * 2,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                ConvBlock(
                    num_features * 2,
                    num_features,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
            ]
        )
        self.last = nn.Conv2d(
            num_features,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        """Translate a batch of images; the result lies in ``[-1, 1]``."""
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        x = self.last(x)
        return torch.tanh(x)


def test():
    """Smoke test: run a random batch through the generator and print its shape."""
    img_channels = 3
    img_size = 256
    # Use img_size for both spatial dimensions (the width used to be a
    # hard-coded 256 that silently ignored img_size).
    x = torch.randn((2, img_channels, img_size, img_size))
    model = Generator(img_channels, num_residuals=6)
    # A shape check does not need autograd bookkeeping.
    with torch.no_grad():
        preds = model(x)
    print(preds.shape)


if __name__ == "__main__":
    test()
Discriminator
import torch
import torch.nn as nn


class Block(nn.Module):
    """4x4 reflect-padded convolution + InstanceNorm + LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                4,
                stride,
                1,
                bias=True,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        return self.conv(x)


### Input image size: 3x256x256
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores.

    With the default ``features`` a 256x256 input yields a 1x30x30 score map
    per image; scores are passed through a sigmoid, so they lie in ``[0, 1]``.

    Args:
        in_channels: Number of channels of the input image.
        features: Channel widths of the successive convolution stages. The
            first entry is used by the initial (un-normalised) convolution,
            the remaining ones by ``Block``s.
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Only the final block uses stride 1. Selecting it by position
            # (rather than by comparing channel widths) stays correct when
            # several stages share the same width.
            stride = 1 if index == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-activated patch score map for ``x``."""
        x = self.initial(x)
        x = self.model(x)
        return torch.sigmoid(x)


def test():
    """Smoke test: run a random batch through the discriminator and print its shape."""
    x = torch.randn((5, 3, 256, 256))
    model = Discriminator()
    # A shape check does not need autograd bookkeeping.
    with torch.no_grad():
        preds = model(x)
    print(preds.shape)


if __name__ == "__main__":
    test()
Dataset
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np


class HorseZebraDataset(Dataset):
    """Unpaired horse/zebra image dataset for CycleGAN training.

    Every item pairs one horse image (key ``'A'``) with one zebra image
    (key ``'B'``). The dataset length is the size of the larger domain; the
    smaller domain is cycled through with a modulo on the index.

    Args:
        root_zebra: Directory containing the zebra images.
        root_horse: Directory containing the horse images.
        transform: Optional callable invoked as
            ``transform(image=zebra, image0=horse)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: If exactly one of the two directories is empty; indexing
            would otherwise fail later with a ``ZeroDivisionError``.
    """

    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sort them so that an
        # index refers to the same file on every run and platform.
        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        self.length_dataset = max(self.zebra_len, self.horse_len)

        if self.length_dataset > 0 and min(self.zebra_len, self.horse_len) == 0:
            raise ValueError(
                "Both image directories must contain images, got "
                f"{self.zebra_len} in {root_zebra!r} and "
                f"{self.horse_len} in {root_horse!r}"
            )

    def __len__(self):
        """Return the number of images in the larger of the two domains."""
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB numpy array, closing the file."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, idx):
        """Return ``{'A': horse_image, 'B': zebra_image}`` for index ``idx``."""
        # The modulo wraps the shorter domain around.
        zebra_name = self.zebra_images[idx % self.zebra_len]
        horse_name = self.horse_images[idx % self.horse_len]

        zebra_img = self._load_rgb(os.path.join(self.root_zebra, zebra_name))
        horse_img = self._load_rgb(os.path.join(self.root_horse, horse_name))

        if self.transform:
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            horse_img = augmentations["image0"]
            zebra_img = augmentations["image"]

        return {'A': horse_img, 'B': zebra_img}
网络训练
这部分主要参考:https://github.com/aitorzip/PyTorch-CycleGAN
Training
import argparse
import itertools
import os
from torch.utils.data import DataLoader
from torch.autograd import Variable  # no longer used here; import kept as in the original
import torch
import torch.nn as nn
import torch.optim as optim
from generator_model import Generator
from discriminator_model import Discriminator
from dataset import HorseZebraDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils import ReplayBuffer


def main(opt):
    """Train the two CycleGAN generators and discriminators.

    Args:
        opt: Parsed command-line options. Reads ``data_root``, ``batch_size``,
            ``size``, ``input_nc``, ``lr`` and ``n_epochs``.

    Side effects:
        Prints the losses every 50 batches and writes the four network state
        dicts to ``./output/`` at the end of every epoch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    ### Load the data
    # Image pre-processing; both domains receive the same augmentation
    # pipeline through the additional "image0" target.
    transforms = A.Compose(
        [
            # Honour --size instead of a hard-coded 256.
            A.Resize(width=opt.size, height=opt.size),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ],
        additional_targets={"image0": "image"},
    )
    # os.path.join works whether or not data_root ends with a separator.
    datasets = HorseZebraDataset(
        root_horse=os.path.join(opt.data_root, "trainA"),
        root_zebra=os.path.join(opt.data_root, "trainB"),
        transform=transforms,
    )
    loader = DataLoader(datasets, batch_size=opt.batch_size, shuffle=True, num_workers=4)

    ### Building the Network
    netG_A2B = Generator(opt.input_nc).to(device)
    netG_B2A = Generator(opt.input_nc).to(device)
    netD_A = Discriminator(opt.input_nc).to(device)
    netD_B = Discriminator(opt.input_nc).to(device)

    # Losses
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # Optimizers (both generators share a single optimizer)
    optimizer_G = optim.Adam(
        itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
        lr=opt.lr,
        betas=(0.5, 0.999),
    )
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    # Buffers of previously generated images fed to the discriminators.
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Create the checkpoint directory up front so the first save cannot fail
    # after a whole epoch of training.
    os.makedirs('./output', exist_ok=True)

    for epoch in range(opt.n_epochs):
        for idx, batch in enumerate(loader):
            # Move the batch to the training device. Unlike copying into a
            # pre-allocated fixed-size buffer this also handles a final batch
            # that is smaller than batch_size.
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            ### Generators A2B and B2A ###
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle loss
            cycle_A = netG_B2A(fake_B)
            cycle_B = netG_A2B(fake_A)
            loss_cycle = criterion_cycle(cycle_A, real_A) + criterion_cycle(cycle_B, real_B)
            loss_cycle *= 10.0

            loss_G = (
                loss_identity_A
                + loss_identity_B
                + loss_GAN_A2B
                + loss_GAN_B2A
                + loss_cycle
            )
            loss_G.backward()
            optimizer_G.step()

            ### Discriminator A ###
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            # Fake loss (detach so no gradient flows back into the generator)
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            # The target is shaped after the prediction it is compared with.
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            ### Discriminator B ###
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            if idx % 50 == 0:
                print(
                    f"Epoch [{epoch}/{opt.n_epochs}] Batch {idx}/{len(loader)} "
                    f"Loss G: {loss_G.item():.4f}, loss_cycle: {loss_cycle.item():.4f}, "
                    f"loss_D_A: {loss_D_A.item():.4f},"
                )

        # Save the latest weights at the end of every epoch.
        torch.save(netG_A2B.state_dict(), './output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), './output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), './output/netD_A.pth')
        torch.save(netD_B.state_dict(), './output/netD_B.pth')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_epochs', type=int, default=200, help="number of epochs of training")
    parser.add_argument('--batch_size', type=int, default=2, help="size of the batches")
    parser.add_argument('--data_root', type=str, default='./data/horse2zebra/', help="root directory of the dataset")
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
    parser.add_argument('--size', type=int, default=256, help='size of data crop(squared assumed)')
    parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
    parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
    opt = parser.parse_args()
    print(opt)
    main(opt)
Predict
import argparse
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import numpy as np
from generator_model import Generator
from dataset import HorseZebraDataset
# Inference script: load the two trained generators, translate one horse
# image to a zebra and one zebra image to a horse, and show the four images.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
parser.add_argument('--n_cpu', type=int, default=1, help='number of cpu threads to use during batch generation')
parser.add_argument('--generator_A2B', type=str, default='./output/netG_A2B.pth', help='A2B generator checkpoint file')
parser.add_argument('--generator_B2A', type=str, default='./output/netG_B2A.pth', help='B2A generator checkpoint file')
# The two input images used to be hard-coded absolute paths; they are now
# options whose defaults are the original paths.
parser.add_argument('--horse_image', type=str, default="D:/d2l/CycleGAN/data/horse2zebra/trainA/n02381460_36.jpg", help='domain A (horse) image to translate')
parser.add_argument('--zebra_image', type=str, default="D:/d2l/CycleGAN/data/horse2zebra/trainB/n02391049_77.jpg", help='domain B (zebra) image to translate')
opt = parser.parse_args()
print(opt)

###### Definition of variables ######
# Networks
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG_A2B = Generator(opt.input_nc).to(device)
netG_B2A = Generator(opt.output_nc).to(device)

# Load state dicts; map_location lets GPU-trained checkpoints load on a
# CPU-only machine.
netG_A2B.load_state_dict(torch.load(opt.generator_A2B, map_location=device))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A, map_location=device))

# Set model's test mode
netG_A2B.eval()
netG_B2A.eval()

# Load the two input images as RGB arrays
horse_path = opt.horse_image
horse_img = np.array(Image.open(horse_path).convert("RGB"))
zebra_path = opt.zebra_image
zebra_img = np.array(Image.open(zebra_path).convert("RGB"))

# Same normalisation as in training: scale to [0, 1], then map to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Add the batch dimension and move to the selected device; a hard-coded
# .cuda() would crash on machines without a GPU.
real_A = transform(horse_img).unsqueeze(0).to(device)
real_B = transform(zebra_img).unsqueeze(0).to(device)

# Generate the translations without building an autograd graph and map the
# tanh output from [-1, 1] back to [0, 1] for display.
with torch.no_grad():
    fake_A = 0.5 * (netG_B2A(real_B) + 1.0)
    fake_B = 0.5 * (netG_A2B(real_A) + 1.0)

# Drop only the batch dimension (a bare squeeze() would also remove a
# size-1 channel dimension), then convert CHW -> HWC for matplotlib.
out = fake_B.squeeze(0).cpu().numpy()
img_1 = np.transpose(out, (1, 2, 0))
out = fake_A.squeeze(0).cpu().numpy()
img_2 = np.transpose(out, (1, 2, 0))

import matplotlib.pyplot as plt

# Top row: horse -> zebra; bottom row: zebra -> horse.
panels = [
    (221, horse_img, "input image"),
    (222, img_1, "output image"),
    (223, zebra_img, "input image"),
    (224, img_2, "output image"),
]
for position, image, title in panels:
    plt.subplot(position)
    plt.imshow(image)
    plt.title(title)
    plt.axis("off")
# Without show() nothing is displayed when this runs as a plain script.
plt.show()
