Network Results

(result image: image.png)

Network Architecture

This part is mainly based on: https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/CycleGAN

Generator

import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):
    def __init__(self, img_channels, num_features=64, num_residuals=6):
        super().__init__()
        # 7x7 conv, stride 1, reflection padding
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )
        # two downsampling blocks: 64 -> 128 -> 256 channels, each halving the spatial size
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),
            ]
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        )
        # two upsampling blocks (transposed convolutions) back to 64 channels
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(num_features * 4, num_features * 2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features * 2, num_features, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ]
        )
        self.last = nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        x = self.last(x)
        return torch.tanh(x)


def test():
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    model = Generator(img_channels, num_residuals=6)
    preds = model(x)
    print(preds.shape)  # expected: torch.Size([2, 3, 256, 256])


if __name__ == "__main__":
    test()

Discriminator

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


### Input image size: 3x256x256
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            # the last block uses stride 1, the rest stride 2 (70x70 PatchGAN)
            layers.append(Block(in_channels, feature, stride=1 if feature == features[-1] else 2))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        x = self.model(x)
        return torch.sigmoid(x)


def test():
    x = torch.randn((5, 3, 256, 256))
    model = Discriminator()
    preds = model(x)
    print(preds.shape)  # expected: torch.Size([5, 1, 30, 30]), one score per patch


if __name__ == "__main__":
    test()

Dataset

from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np


class HorseZebraDataset(Dataset):
    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform
        self.zebra_images = os.listdir(root_zebra)
        self.horse_images = os.listdir(root_horse)
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        # the two domains have different sizes; iterate over the larger one
        self.length_dataset = max(self.zebra_len, self.horse_len)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
        # wrap around the shorter list with modulo indexing
        zebra_img = self.zebra_images[idx % self.zebra_len]
        horse_img = self.horse_images[idx % self.horse_len]
        zebra_path = os.path.join(self.root_zebra, zebra_img)
        horse_path = os.path.join(self.root_horse, horse_img)
        zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
        horse_img = np.array(Image.open(horse_path).convert("RGB"))
        if self.transform:
            # albumentations applies the same augmentation to both images
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            zebra_img = augmentations["image"]
            horse_img = augmentations["image0"]
        return {'A': horse_img, 'B': zebra_img}
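
As a quick sanity check, the dataset can be wrapped in a DataLoader and one batch inspected. This is a minimal sketch: the ./data/horse2zebra/ paths are assumptions, and the albumentations pipeline mirrors the one used by the training script below.

# sanity check; assumes the class above is saved as dataset.py and the data
# lives under ./data/horse2zebra/ (adjust the paths to your layout)
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from dataset import HorseZebraDataset

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)
dataset = HorseZebraDataset(root_zebra="./data/horse2zebra/trainB",
                            root_horse="./data/horse2zebra/trainA",
                            transform=transforms)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["A"].shape, batch["B"].shape)  # expected: torch.Size([1, 3, 256, 256]) for each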

Network Training

This part is mainly based on: https://github.com/aitorzip/PyTorch-CycleGAN

Training

import argparse
import itertools
import os
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
from generator_model import Generator
from discriminator_model import Discriminator
from dataset import HorseZebraDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils import ReplayBuffer  # history buffer of generated images (a sketch is given after this listing)


def main(opt):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    ### Load the data
    # image pre-processing: both images of a pair get the same augmentation
    transforms = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ],
        additional_targets={"image0": "image"},
    )
    datasets = HorseZebraDataset(root_horse=opt.data_root + "trainA", root_zebra=opt.data_root + "trainB",
                                 transform=transforms)
    loader = DataLoader(datasets, batch_size=opt.batch_size, shuffle=True, num_workers=4)

    ### Build the networks
    netG_A2B = Generator(opt.input_nc).to(device)
    netG_B2A = Generator(opt.input_nc).to(device)
    netD_A = Discriminator(opt.input_nc).to(device)
    netD_B = Discriminator(opt.input_nc).to(device)

    # Losses
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # Optimizers (the referenced implementation also adds a linearly decaying LR scheduler, omitted here)
    optimizer_G = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                             lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    ### Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batch_size, opt.input_nc, opt.size, opt.size)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    os.makedirs("./output", exist_ok=True)

    for epoch in range(opt.n_epochs):
        for idx, batch in enumerate(loader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ### Generators A2B and B2A ###
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle-consistency loss
            cycle_A = netG_B2A(fake_B)
            cycle_B = netG_A2B(fake_A)
            loss_cycle = criterion_cycle(cycle_A, real_A) + criterion_cycle(cycle_B, real_B)
            loss_cycle *= 10.0

            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle
            loss_G.backward()
            optimizer_G.step()

            ### Discriminator A ###
            optimizer_D_A.zero_grad()
            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
            # Fake loss (sample from the replay buffer)
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            ### Discriminator B ###
            optimizer_D_B.zero_grad()
            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
            # Fake loss (sample from the replay buffer)
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            if idx % 50 == 0:
                print(
                    f"Epoch [{epoch}/{opt.n_epochs}] Batch {idx}/{len(loader)} "
                    f"loss_G: {loss_G:.4f}, loss_cycle: {loss_cycle:.4f}, loss_D_A: {loss_D_A:.4f}"
                )

        # Save checkpoints at the end of every epoch
        torch.save(netG_A2B.state_dict(), './output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), './output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), './output/netD_A.pth')
        torch.save(netD_B.state_dict(), './output/netD_B.pth')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_epochs', type=int, default=200, help="number of epochs of training")
    parser.add_argument('--batch_size', type=int, default=2, help="size of the batches")
    parser.add_argument('--data_root', type=str, default='./data/horse2zebra/', help="root directory of the dataset")
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
    parser.add_argument('--size', type=int, default=256, help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
    parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
    opt = parser.parse_args()
    print(opt)
    main(opt)
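
The script imports ReplayBuffer from utils, which is not shown above. Below is a minimal sketch modeled on the buffer in the referenced PyTorch-CycleGAN repository (the class name and behavior are taken from that implementation, the exact code here is a reconstruction): it stores up to 50 previously generated images and, half of the time, feeds the discriminator an old fake instead of the newest one, which helps stabilize training.

# utils.py -- minimal sketch of the replay buffer used by the training script
import random
import torch


class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer is not supported"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store newly generated images and return a batch mixing new and old ones."""
        to_return = []
        for element in data.detach():  # drop the autograd graph before storing
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # buffer not full yet: keep the new image and return it
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # with probability 0.5, return a random old image and replace it with the new one
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                # otherwise return the new image unchanged
                to_return.append(element)
        return torch.cat(to_return)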

Predict

import argparse
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import numpy as np
from generator_model import Generator
from dataset import HorseZebraDataset

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
parser.add_argument('--n_cpu', type=int, default=1, help='number of cpu threads to use during batch generation')
parser.add_argument('--generator_A2B', type=str, default='./output/netG_A2B.pth', help='A2B generator checkpoint file')
parser.add_argument('--generator_B2A', type=str, default='./output/netG_B2A.pth', help='B2A generator checkpoint file')
opt = parser.parse_args()
print(opt)


###### Definition of variables ######
# Networks
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

netG_A2B = Generator(opt.input_nc).to(device)
netG_B2A = Generator(opt.output_nc).to(device)

# Load state dicts
netG_A2B.load_state_dict(torch.load(opt.generator_A2B, map_location=device))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A, map_location=device))

# Set models to test mode
netG_A2B.eval()
netG_B2A.eval()

# Load one sample image from each domain
horse_path = "D:/d2l/CycleGAN/data/horse2zebra/trainA/n02381460_36.jpg"
horse_img = np.array(Image.open(horse_path).convert("RGB"))

zebra_path = "D:/d2l/CycleGAN/data/horse2zebra/trainB/n02391049_77.jpg"
zebra_img = np.array(Image.open(zebra_path).convert("RGB"))

# Same normalization as used during training
transform = transforms.Compose([
    transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # -> [-1.0, 1.0]
])

real_A = transform(horse_img).unsqueeze(0).to(device)
real_B = transform(zebra_img).unsqueeze(0).to(device)

# Map the generator outputs from [-1, 1] back to [0, 1]
with torch.no_grad():
    fake_A = 0.5 * (netG_B2A(real_B) + 1.0)
    fake_B = 0.5 * (netG_A2B(real_A) + 1.0)

out = fake_B.squeeze().cpu().numpy()
img_1 = np.transpose(out, (1, 2, 0))  # CHW -> HWC for matplotlib

out = fake_A.squeeze().cpu().numpy()
img_2 = np.transpose(out, (1, 2, 0))

import matplotlib.pyplot as plt
plt.subplot(221), plt.imshow(horse_img), plt.title("input image"), plt.axis("off")
plt.subplot(222), plt.imshow(img_1), plt.title("output image"), plt.axis("off")
plt.subplot(223), plt.imshow(zebra_img), plt.title("input image"), plt.axis("off")
plt.subplot(224), plt.imshow(img_2), plt.title("output image"), plt.axis("off")
plt.show()
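
save_image is imported above but never used; if you also want to write the generated images to disk rather than only plotting them, the following is a minimal sketch (the output filenames are just examples).

# optional: save the generated images (already scaled to [0, 1]) to disk
os.makedirs("./output", exist_ok=True)
save_image(fake_B, "./output/fake_zebra.png")  # A -> B result
save_image(fake_A, "./output/fake_horse.png")  # B -> A result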

Trained Network Weights

(attachment: output.zip)