《深度学习之Pytorch实战计算机视觉》阅读笔记 第8章:图像风格迁移实战
1. 图像风格迁移
首先选取一幅图像作为基准图像,又称为内容图像,然后选取另一幅或者多幅图像作为希望获取相应风格的图像,称其为风格图像。图像风格迁移就是在保证内容图像的内容完整性的前提下,将风格图像的风格融入内容图像中,使内容图像的原始风格发生转变,最终输出的图像呈现的是内容图像和风格图像之间的理想融合。
2. 图像内容损失
1)图像的内容损失
# 使用均方误差作为损失函数
import torch
from torch import nn
class Content_loss(torch.nn.Module):
    """Pass-through layer that records the content loss of its input.

    The module returns its input unchanged, so it can sit between two layers
    of a sequential feature extractor. Every forward pass stores the
    mean-squared error between the weighted input and the weighted target in
    ``self.loss``.

    Args:
        weight: Factor multiplied into both the target and the input before
            the mean-squared error is computed.
        target: Features of the content image at the position where this
            module is inserted.
    """

    def __init__(self, weight, target):
        super().__init__()
        self.weight = weight
        # detach() freezes the extracted content features: the target is a
        # constant and must not take part in gradient computation.
        self.target = target.detach() * weight
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, input):
        """Store the content loss for ``input`` and return ``input`` as is."""
        self.loss = self.loss_fn(input * self.weight, self.target)
        return input

    def backward(self):
        """Back-propagate the loss stored by the last forward pass.

        Returns:
            The stored loss tensor.
        """
        # retain_graph=True keeps the graph alive, so several loss modules
        # fed by the same forward pass can each call backward().
        self.loss.backward(retain_graph=True)
        return self.loss
2)图像的风格损失
# 使用均方误差作为损失函数
class Style_loss(torch.nn.Module):
    """Pass-through layer that records the style loss of its input.

    The module returns its input unchanged. Every forward pass computes the
    Gram matrix of the input, scales it by ``weight`` and stores the
    mean-squared error against the weighted target in ``self.loss``.

    Args:
        weight: Factor multiplied into both the target and the Gram matrix
            of the input before the mean-squared error is computed.
        target: Reference value the weighted Gram matrix is compared with;
            it must therefore already be a Gram matrix (of the style image's
            features), not a raw feature map.
    """

    def __init__(self, weight, target):
        super().__init__()
        self.weight = weight
        # The target is a constant: detach it from any autograd graph.
        self.target = target.detach() * weight
        self.loss_fn = torch.nn.MSELoss()
        # Gram matrix: inner products between feature rows, which
        # emphasises the "style" of the features.
        self.gram = Gram_matrix()

    def forward(self, input):
        """Store the style loss for ``input`` and return ``input`` as is."""
        # The Gram matrix is computed on a copy of the input.
        # NOTE(review): presumably this protects the saved values from
        # in-place layers (e.g. ReLU(inplace=True)) that follow — confirm.
        self.Gram = self.gram(input.clone()) * self.weight
        self.loss = self.loss_fn(self.Gram, self.target)
        return input

    def backward(self):
        """Back-propagate the loss stored by the last forward pass.

        Returns:
            The stored loss tensor.
        """
        # retain_graph=True keeps the graph alive, so several loss modules
        # fed by the same forward pass can each call backward().
        self.loss.backward(retain_graph=True)
        return self.loss
class Gram_matrix(torch.nn.Module):
    """Compute the normalised Gram matrix of a 4-D feature tensor."""

    def forward(self, input):
        """Return ``F @ F.T / input.numel()`` for the flattened features ``F``.

        Args:
            input: 4-D tensor of sizes ``(a, b, c, d)``; presumably
                (batch, channels, height, width) — confirm against caller.

        Returns:
            Tensor of shape ``(a*b, a*b)``.

        Raises:
            ValueError: If ``input`` does not have exactly four dimensions.
        """
        a, b, c, d = input.size()
        # reshape() instead of view(): view() raises on non-contiguous
        # inputs (e.g. after permute/transpose); reshape() copies if needed.
        feature = input.reshape(a * b, c * d)
        gram = torch.mm(feature, feature.t())
        # Normalise by the number of elements so that larger feature maps
        # do not produce proportionally larger Gram values.
        return gram.div(a * b * c * d)
3. 模型搭建
1)特征提取
# NOTE(review): `models` is not imported anywhere in these notes; presumably
# it is `torchvision.models` (`from torchvision import models`) — confirm.
# Only the `features` sub-module of the pretrained VGG16 is kept.
cnn = models.vgg16(pretrained=True).features
# Move the network to the GPU only when one is available, so the snippet does
# not crash on CPU-only machines. The images fed to the network must live on
# the same device as the network.
if torch.cuda.is_available():
    cnn = cnn.cuda()
# Names of the convolution layers whose outputs feed the content loss and
# the style loss, respectively.
content_layer = ['Conv_3']
style_layer = ['Conv_1', 'Conv_2', 'Conv_3', 'Conv_4']
2)模型迁移
# Copy the first 8 layers of `model` into `new_model`, inserting a loss module
# directly after every convolution named in `content_layer` / `style_layer`.
#
# Names used below that these notes never define: `model`, `new_model`,
# `gram`, `content_image`, `style_image`, `content_weight`, `style_weight`,
# `content_losses`, `style_losses`.
# NOTE(review): presumably `model` is (a copy of) `cnn`, `new_model` an empty
# torch.nn.Sequential, `gram` a Gram_matrix() instance and the two `*_losses`
# names empty lists — confirm before running.
index = 1  # suffix shared by a convolution and the layers that follow it
for layer in list(model)[:8]:
    if isinstance(layer, torch.nn.Conv2d):
        name = 'Conv_' + str(index)
        new_model.add_module(name, layer)

        if name in content_layer:
            # At this point new_model ends with the convolution just added,
            # so its output is the content target for this layer.
            target = new_model(content_image).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module('content_loss_' + str(index), content_loss)
            content_losses.append(content_loss)

        if name in style_layer:
            # The style target is the Gram matrix of the style image's
            # features at this layer.
            target = new_model(style_image).clone()
            target = gram(target)
            style_loss = Style_loss(style_weight, target)
            new_model.add_module('style_loss_' + str(index), style_loss)
            style_losses.append(style_loss)

    elif isinstance(layer, torch.nn.ReLU):
        name = 'Relu_' + str(index)
        new_model.add_module(name, layer)
        # Advance the suffix only after a ReLU has been added.
        index += 1

    elif isinstance(layer, torch.nn.MaxPool2d):
        name = 'MaxPool_' + str(index)
        new_model.add_module(name, layer)
此时,模型输出结果:
3)参数优化
# The image itself is what gets optimised: start from a copy of the content
# image, wrap its data in a Parameter and hand that single Parameter to L-BFGS.
input_img = content_image.clone()
parameter = nn.Parameter(input_img.data)
optimizer = torch.optim.LBFGS(params=[parameter])
4. 训练
epoch_n = 300
# One-element list so that closure() can update the counter in place.
epoch = [0]


def closure():
    """Run one forward/backward pass and return the total loss for L-BFGS.

    Relies on the module-level `optimizer`, `parameter`, `new_model`,
    `style_losses` and `content_losses`.
    """
    optimizer.zero_grad()
    style_score = 0
    content_score = 0

    # Keep the optimised image inside the value range [0, 1].
    with torch.no_grad():
        parameter.clamp_(0, 1)

    # The forward pass is run only for its side effect: every loss module
    # inside new_model stores its loss for the current image.
    new_model(parameter)
    for sl in style_losses:
        style_score += sl.backward()
    for cl in content_losses:
        content_score += cl.backward()

    epoch[0] += 1
    if epoch[0] % 50 == 0:
        # float() handles both tensors and the plain 0 left over when a
        # loss list is empty; '.4f' prints four decimals (the former '4f'
        # only set a minimum field width).
        print('Epoch:{} Style Loss:{:.4f} Content Loss:{:.4f}'.format(
            epoch[0], float(style_score), float(content_score)))
    return style_score + content_score


# L-BFGS may evaluate the closure several times per step(), so `epoch`
# counts closure evaluations and can overshoot epoch_n.
while epoch[0] <= epoch_n:
    optimizer.step(closure)

# The last optimiser update happens after the last in-closure clamp, so
# clamp once more to leave the final image in [0, 1].
with torch.no_grad():
    parameter.clamp_(0, 1)