Reading notes on 《深度学习之Pytorch实战计算机视觉》, Chapter 8: Image Style Transfer in Practice

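1. Image style transfer

First choose an image as the base image, called the content image. Then choose one or more other images whose style you want to borrow, called the style image. Image style transfer blends the style of the style image into the content image while keeping the content of the content image intact. The original style of the content image changes, and the final output is an ideal fusion of the content image and the style image.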
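One common way to write this trade-off as a single objective is sketched below. The notation is assumed here and does not come from the notes: x is the generated image, c the content image, s the style image, and α and β are hypothetical weights for the two terms.

```latex
\mathcal{L}_{\text{total}}(x) = \alpha\,\mathcal{L}_{\text{content}}(x, c) + \beta\,\mathcal{L}_{\text{style}}(x, s)
```

A larger β relative to α pushes the result toward the style image, and a larger α keeps it closer to the content image. The code later in these notes exposes two weights with a similar role, content_weight and style_weight.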

2. Image content and style losses

1) Content loss

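    # Use mean squared error (MSE) as the loss function
    import torch
    from torch import nn


    class Content_loss(torch.nn.Module):
        def __init__(self, weight, target):
            super(Content_loss, self).__init__()
            self.weight = weight
            # detach() freezes the extracted content features so that
            # no gradient is computed for them
            self.target = target.detach() * weight
            self.loss_fn = torch.nn.MSELoss()

        def forward(self, input):
            # Record the loss, then pass the input through unchanged so the
            # module can sit between the layers of a Sequential model
            self.loss = self.loss_fn(input * self.weight, self.target)
            return input

        def backward(self):
            # retain_graph=True keeps the graph so that the other loss
            # modules can backpropagate through it as well
            self.loss.backward(retain_graph=True)
            return self.loss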
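A small worked example, offered only as a sketch: because the class multiplies both the input and the target by weight before taking the mean squared error, the loss grows with the square of the weight. The tensors and the helper name weighted_mse are made up for illustration.

```python
import torch


def weighted_mse(inp, target, weight):
    # Same arithmetic as Content_loss.forward: scale both sides, then MSE.
    loss_fn = torch.nn.MSELoss()
    return loss_fn(inp * weight, target.detach() * weight)


inp = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.0, 2.0], [3.0, 6.0]])

# Only one of the four elements differs, by 2, so the MSE is 4 / 4.
loss_w1 = weighted_mse(inp, target, 1.0).item()
# Doubling the weight doubles that difference, so the MSE is 16 / 4.
loss_w2 = weighted_mse(inp, target, 2.0).item()
print(loss_w1, loss_w2)
```

Doubling the weight therefore quadruples the loss, which is worth keeping in mind when comparing the content and style weights.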

2) Style loss

    # Use mean squared error (MSE) as the loss function
    class Style_loss(torch.nn.Module):
        def __init__(self, weight, target):
            super(Style_loss, self).__init__()
            self.weight = weight
            self.target = target.detach() * weight
            self.loss_fn = torch.nn.MSELoss()
            # Gram matrix: inner products between the feature maps,
            # which accentuate the "style"
            self.gram = Gram_matrix()

        def forward(self, input):
            # Work on a copy so that later in-place layers do not disturb
            # the tensors saved for the backward pass
            self.Gram = self.gram(input.clone())
            self.Gram.mul_(self.weight)
            self.loss = self.loss_fn(self.Gram, self.target)
            return input

        def backward(self):
            self.loss.backward(retain_graph=True)
            return self.loss


    class Gram_matrix(torch.nn.Module):
        def forward(self, input):
            # a: batch size, b: number of feature maps, (c, d): map size
            a, b, c, d = input.size()
            feature = input.view(a * b, c * d)
            gram = torch.mm(feature, feature.t())
            # Normalize by the total number of elements
            return gram.div(a * b * c * d)
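To make the Gram matrix concrete, here is a minimal sketch on a tiny, made-up feature map. The function gram_matrix is a hypothetical standalone version of the same steps as Gram_matrix.forward above.

```python
import torch


def gram_matrix(feature_maps):
    # Flatten each channel, take all pairwise inner products,
    # then normalize by the total number of elements.
    a, b, c, d = feature_maps.size()
    feature = feature_maps.view(a * b, c * d)
    return torch.mm(feature, feature.t()).div(a * b * c * d)


# Hypothetical feature map: batch of 1, 2 channels, each of size 1x2.
x = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])
g = gram_matrix(x)
print(g)  # 2x2 matrix: one row and one column per channel
```

Here the flattened channels are [1, 2] and [3, 4]. Their inner products are 5, 11 and 25, and dividing by the 4 elements gives [[1.25, 2.75], [2.75, 6.25]]. The result depends only on how the channels correlate with each other, not on where in the image a feature appears.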

3. Building the model

1) Feature extraction

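    from torchvision import models

    # Load the pretrained VGG16 and keep only its convolutional feature
    # extractor (newer torchvision releases prefer the weights= argument
    # over pretrained=True)
    cnn = models.vgg16(pretrained=True).features
    cnn = cnn.cuda()
    # Convolutional layers used for the content features and the style features
    content_layer = ['Conv_3']
    style_layer = ['Conv_1', 'Conv_2', 'Conv_3', 'Conv_4']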

2) Transferring the model

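    import copy

    # Assumed setup, not shown in these notes: content_image and style_image
    # are image tensors of shape (1, 3, H, W) on the same device as cnn, and
    # the two weights below are illustrative values only.
    content_weight = 1
    style_weight = 1000
    content_losses = []
    style_losses = []
    new_model = torch.nn.Sequential()
    model = copy.deepcopy(cnn)
    gram = Gram_matrix()

    index = 1
    # Only the first 8 layers of the VGG16 feature extractor are used
    for layer in list(model)[:8]:
        if isinstance(layer, torch.nn.Conv2d):
            name = 'Conv_' + str(index)
            new_model.add_module(name, layer)
            if name in content_layer:
                # Content features of the content image at this depth
                target = new_model(content_image).clone()
                content_loss = Content_loss(content_weight, target)
                new_model.add_module('content_loss_' + str(index), content_loss)
                content_losses.append(content_loss)
            if name in style_layer:
                # Gram matrix of the style image's features at this depth
                target = new_model(style_image).clone()
                target = gram(target)
                style_loss = Style_loss(style_weight, target)
                new_model.add_module('style_loss_' + str(index), style_loss)
                style_losses.append(style_loss)
        if isinstance(layer, torch.nn.ReLU):
            name = 'Relu_' + str(index)
            new_model.add_module(name, layer)
            index += 1
        if isinstance(layer, torch.nn.MaxPool2d):
            name = 'MaxPool_' + str(index)
            new_model.add_module(name, layer)
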
At this point, printing the model gives the following output:
[Image: 20190828094519.png]
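The order of the modules can also be traced without loading VGG16. The sketch below replays only the naming logic of the loop, assuming that the first eight modules of vgg16().features are conv, relu, conv, relu, max-pool, conv, relu, conv. The string labels in first_eight are made up for this trace.

```python
# Layer types of the first 8 modules of torchvision's vgg16().features,
# written as plain strings so that no model has to be loaded.
first_eight = ["conv", "relu", "conv", "relu", "pool", "conv", "relu", "conv"]
content_layer = ["Conv_3"]
style_layer = ["Conv_1", "Conv_2", "Conv_3", "Conv_4"]

names = []
index = 1
for kind in first_eight:
    if kind == "conv":
        name = "Conv_" + str(index)
        names.append(name)
        # Loss modules are inserted directly after the matching conv layer.
        if name in content_layer:
            names.append("content_loss_" + str(index))
        if name in style_layer:
            names.append("style_loss_" + str(index))
    if kind == "relu":
        names.append("Relu_" + str(index))
        index += 1  # the counter advances only after a ReLU
    if kind == "pool":
        names.append("MaxPool_" + str(index))

print(names)
```

The trace yields 13 modules: every conv layer is followed by a style loss, Conv_3 is additionally followed by the content loss, and the pooling layer is named MaxPool_3 because the counter has already advanced twice by then.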

3) Parameter optimization

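    # Start from a copy of the content image; its pixels are the only
    # parameters being optimized, the network weights stay fixed
    input_img = content_image.clone()
    parameter = torch.nn.Parameter(input_img.data)
    optimizer = torch.optim.LBFGS([parameter])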

4. Training

    epoch_n = 300
    # A one-element list so that the closure can update the counter
    epoch = [0]
    while epoch[0] <= epoch_n:

        def closure():
            optimizer.zero_grad()
            style_score = 0
            content_score = 0
            # Keep the pixel values in the valid [0, 1] range
            parameter.data.clamp_(0, 1)
            # The forward pass makes every loss module record its loss
            new_model(parameter)
            for sl in style_losses:
                style_score += sl.backward()
            for cl in content_losses:
                content_score += cl.backward()
            epoch[0] += 1
            if epoch[0] % 50 == 0:
                print('Epoch:{} Style Loss:{:.4f} Content Loss:{:.4f}'.format(
                    epoch[0], style_score.item(), content_score.item()))
            return style_score + content_score

        # L-BFGS may evaluate the closure several times per step
        optimizer.step(closure)
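The closure pattern above is what torch.optim.LBFGS requires: the optimizer re-evaluates the loss itself, so step() takes a function that clears the gradients, computes the loss, backpropagates, and returns the loss. A minimal sketch on a toy problem, with a made-up target tensor standing in for the image:

```python
import torch

target = torch.tensor([0.2, 0.5, 0.8])          # hypothetical "pixels" to reach
parameter = torch.nn.Parameter(torch.zeros(3))  # the tensor being optimized
optimizer = torch.optim.LBFGS([parameter])
calls = [0]  # a list so that the closure can update it, like epoch[0] above


def closure():
    optimizer.zero_grad()
    loss = ((parameter - target) ** 2).sum()
    loss.backward()
    calls[0] += 1
    return loss


for _ in range(5):
    optimizer.step(closure)

print(calls[0], parameter.data)
```

The closure can run more than once per call to step(), so a counter incremented inside it (here calls[0], in the training loop epoch[0]) counts loss evaluations rather than optimizer steps.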