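Reading Notes on 《深度学习之Pytorch实战计算机视觉》 (Deep Learning with PyTorch: Computer Vision in Practice), Chapter 8: Image Style Transfer in Practice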
1. Image Style Transfer
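First, choose one image as the reference image, also called the content image. Then choose another image whose style the result should take on, called the style image. Image style transfer blends the style of the style image into the content image while keeping the content of the content image intact, so that the content image's original style is transformed. The final output image is an ideal fusion of the content image and the style image.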
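The goal can be written as an optimization over the pixels of the output image $x$. The formulation below is the commonly used one (after Gatys et al.), with notation introduced here for illustration only: $x_c$ is the content image, $x_s$ the style image, $F^{l}(x) \in \mathbb{R}^{N_l \times M_l}$ the feature map of layer $l$ flattened to $N_l$ channels by $M_l$ spatial positions, $\mathcal{C}$ and $\mathcal{S}$ the sets of content and style layers, and $\alpha$, $\beta$ weights that play the role of `content_weight` and `style_weight` in the code.

```latex
\hat{x} = \arg\min_{x} \; \alpha \sum_{l \in \mathcal{C}} \operatorname{MSE}\big(F^{l}(x),\, F^{l}(x_c)\big)
        + \beta \sum_{l \in \mathcal{S}} \operatorname{MSE}\big(G^{l}(x),\, G^{l}(x_s)\big),
\qquad
G^{l}(x) = \frac{F^{l}(x)\, F^{l}(x)^{\top}}{N_l M_l}
```

The normalization $N_l M_l$ matches the `a*b*c*d` divisor used by `Gram_matrix` when the batch size is 1.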
2. Content Loss and Style Loss
1) Content loss
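```python
import torch
from torch import nn


# Use mean squared error (MSE) as the loss function
class Content_loss(torch.nn.Module):
    def __init__(self, weight, target):
        super(Content_loss, self).__init__()
        self.weight = weight
        # detach() locks the extracted content features: the target needs no gradient
        self.target = target.detach() * weight
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, input):
        # Record the loss, then pass the input through unchanged
        self.loss = self.loss_fn(input * self.weight, self.target)
        return input

    def backward(self):
        # retain_graph=True because several loss modules share one graph
        self.loss.backward(retain_graph=True)
        return self.loss
```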
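Because `Content_loss` multiplies both the input features and the target by `weight` before the MSE, the effective scale on the loss is `weight**2`, not `weight`. The short check below illustrates this with random tensors standing in for VGG feature maps; the shapes and the weight value are arbitrary.

```python
import torch

torch.manual_seed(0)
mse = torch.nn.MSELoss()

# Random stand-ins for the content target and the generated image's features
target = torch.rand(1, 8, 4, 4)
features = torch.rand(1, 8, 4, 4)
w = 3.0

# What Content_loss computes: both sides scaled by the weight before the MSE
scaled_inside = mse(features * w, target * w)
# Plain MSE, scaled by the weight afterwards
scaled_outside = w * mse(features, target)

print(scaled_inside.item(), scaled_outside.item())
print(torch.allclose(scaled_inside, w ** 2 * mse(features, target)))  # True
```

The same reasoning applies to `Style_loss`, which scales both Gram matrices by `weight` before comparing them.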
2) Style loss
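```python
import torch


# Use mean squared error as the loss function
class Style_loss(torch.nn.Module):
    def __init__(self, weight, target):
        super(Style_loss, self).__init__()
        self.weight = weight
        # target is already a Gram matrix (computed when the model is built)
        self.target = target.detach() * weight
        self.loss_fn = torch.nn.MSELoss()
        # Gram matrix: inner products between feature maps, which amplify the "style"
        self.gram = Gram_matrix()

    def forward(self, input):
        self.Gram = self.gram(input.clone())
        self.Gram.mul_(self.weight)
        self.loss = self.loss_fn(self.Gram, self.target)
        return input  # pass-through, like Content_loss

    def backward(self):
        self.loss.backward(retain_graph=True)
        return self.loss


class Gram_matrix(torch.nn.Module):
    def forward(self, input):
        a, b, c, d = input.size()           # batch, channels, height, width
        feature = input.view(a * b, c * d)  # one row per flattened feature map
        gram = torch.mm(feature, feature.t())
        return gram.div(a * b * c * d)      # normalize by the number of elements
```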
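To see what `Gram_matrix` computes, the sketch below repeats its steps on a tiny hand-made tensor (values chosen for illustration). Each row of `feature` is one flattened channel, so entry (i, j) is the dot product of channels i and j divided by the element count. Shuffling pixel positions in the same way for every channel leaves the result unchanged, which is why the Gram matrix captures texture and style rather than layout.

```python
import torch


def gram_matrix(x: torch.Tensor) -> torch.Tensor:
    # Same steps as Gram_matrix.forward
    a, b, c, d = x.size()               # batch, channels, height, width
    feature = x.view(a * b, c * d)      # one row per (batch, channel)
    gram = torch.mm(feature, feature.t())
    return gram.div(a * b * c * d)      # normalize by the number of elements


# One image, two channels, 2x2 spatial size
x = torch.tensor([[[[1., 2.],
                    [3., 4.]],
                   [[0., 1.],
                    [0., 1.]]]])

g = gram_matrix(x)
print(g)  # values: [[3.75, 0.75], [0.75, 0.25]]  (30/8, 6/8, 6/8, 2/8)

# Apply one spatial permutation to every channel: the Gram matrix does not change
perm = torch.tensor([3, 0, 2, 1])
x_shuffled = x.reshape(1, 2, 4)[:, :, perm].reshape(1, 2, 2, 2)
print(torch.allclose(gram_matrix(x_shuffled), g))  # True
```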
3. Building the Model
1) Feature extraction
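```python
from torchvision import models

# VGG16 pretrained on ImageNet; keep only the convolutional feature extractor
# (torchvision >= 0.13 prefers weights=models.VGG16_Weights.IMAGENET1K_V1)
cnn = models.vgg16(pretrained=True).features
cnn = cnn.cuda()  # assumes a CUDA GPU
# Conv layers from which content features and style features are extracted
content_layer = ['Conv_3']
style_layer = ['Conv_1', 'Conv_2', 'Conv_3', 'Conv_4']
```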
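The names `Conv_1` to `Conv_4` are not VGG16's own; they come from the counter in the model-building loop, which walks the first eight modules of `vgg16().features`. The sketch below rebuilds just those eight modules by hand (randomly initialized so no weights are downloaded; layer sizes follow the standard VGG16 configuration) and records the feature-map shape each named conv layer produces for an assumed 1×3×64×64 input.

```python
import torch
from torch import nn

# Hand-built replica of vgg16().features[:8] (random weights, same layer sizes)
first8 = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(128, 128, kernel_size=3, padding=1),
)

x = torch.rand(1, 3, 64, 64)  # assumed input size
shapes = {}
conv_index = 1
with torch.no_grad():
    for layer in first8:
        x = layer(x)
        if isinstance(layer, nn.Conv2d):
            shapes['Conv_' + str(conv_index)] = tuple(x.shape)
        if isinstance(layer, nn.ReLU):
            conv_index += 1  # same rule as the book's loop: advance after each ReLU

for name, shape in shapes.items():
    print(name, shape)
# Conv_1 (1, 64, 64, 64)
# Conv_2 (1, 64, 64, 64)
# Conv_3 (1, 128, 32, 32)
# Conv_4 (1, 128, 32, 32)
```

So the content layer `Conv_3` sits just after the first pooling layer, and the style layers span both the full and the half resolution.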
2) Model transfer
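```python
import copy

model = copy.deepcopy(cnn)         # copy of the VGG16 feature extractor
new_model = torch.nn.Sequential()  # VGG layers with loss modules inserted
gram = Gram_matrix()
content_losses = []
style_losses = []
# content_image and style_image: preprocessed [1, 3, H, W] tensors on the GPU;
# content_weight and style_weight: user-chosen loss weights

index = 1
for layer in list(model)[:8]:
    if isinstance(layer, torch.nn.Conv2d):
        name = 'Conv_' + str(index)
        new_model.add_module(name, layer)
        if name in content_layer:
            # Target: this layer's output for the content image
            target = new_model(content_image).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module('content_loss_' + str(index), content_loss)
            content_losses.append(content_loss)
        if name in style_layer:
            # Target: Gram matrix of this layer's output for the style image
            target = new_model(style_image).clone()
            target = gram(target)
            style_loss = Style_loss(style_weight, target)
            new_model.add_module('style_loss_' + str(index), style_loss)
            style_losses.append(style_loss)
    if isinstance(layer, torch.nn.ReLU):
        name = 'Relu_' + str(index)
        new_model.add_module(name, layer)
        index += 1  # the index advances after each ReLU
    if isinstance(layer, torch.nn.MaxPool2d):
        name = 'MaxPool_' + str(index)
        new_model.add_module(name, layer)
```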
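Inserting pass-through loss modules into a rebuilt `Sequential` is one way to tap intermediate activations. A commonly used alternative, shown here for comparison and not the book's method, is to register forward hooks with `register_forward_hook` on the layers of interest and leave the network itself untouched. The sketch uses a tiny randomly initialized stand-in network; the layer choice and names are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Tiny stand-in feature extractor (in practice this would be VGG16 features).
# Non-inplace ReLU, so it cannot overwrite the conv outputs stored by the hooks.
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
)
for p in net.parameters():
    p.requires_grad_(False)  # only the image is optimized

captured = {}


def save_output(name):
    # Build a hook that stores the layer's output under `name`
    def hook(module, inputs, output):
        captured[name] = output
    return hook


handles = [
    net[0].register_forward_hook(save_output('Conv_1')),
    net[2].register_forward_hook(save_output('Conv_2')),
]

image = torch.rand(1, 3, 16, 16, requires_grad=True)
net(image)
print({k: tuple(v.shape) for k, v in captured.items()})

# Captured activations keep their autograd history, so losses on them reach the image
loss = captured['Conv_2'].pow(2).mean()
loss.backward()
print(image.grad is not None)  # True

for h in handles:
    h.remove()  # detach the hooks when done
```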
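At this point, printing `new_model` shows the following modules, in order: Conv_1, style_loss_1, Relu_1, Conv_2, style_loss_2, Relu_2, MaxPool_3, Conv_3, content_loss_3, style_loss_3, Relu_3, Conv_4, style_loss_4.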
3) Parameter optimization
```python
input_img = content_image.clone()  # start from the content image
# Only the image pixels are optimized
parameter = torch.nn.Parameter(input_img.data)
optimizer = torch.optim.LBFGS([parameter])
```
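`torch.optim.LBFGS` differs from optimizers such as SGD in that `step()` must be given a closure that recomputes the loss, and it may call that closure several times within one step (up to `max_iter`, 20 by default). The toy sketch below fits a tensor to a made-up target with MSE and counts the closure calls; the target and the objective are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
target = torch.rand(1, 3, 8, 8)                           # made-up "image" to match
parameter = torch.nn.Parameter(torch.zeros(1, 3, 8, 8))  # the tensor being optimized
optimizer = torch.optim.LBFGS([parameter])

calls = [0]


def closure():
    # LBFGS re-evaluates the objective through this function
    optimizer.zero_grad()
    loss = F.mse_loss(parameter, target)
    loss.backward()
    calls[0] += 1
    return loss


for step in range(3):
    loss = optimizer.step(closure)  # loss from the first closure call of this step
    print(step, loss.item(), calls[0])

final = F.mse_loss(parameter, target).item()
print(final)
```

Counting inside the closure, as the training loop does, therefore tracks loss evaluations rather than calls to `optimizer.step`.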
4. Training
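```python
epoch_n = 300
epoch = [0]  # a list so the closure can update the counter in place
while epoch[0] <= epoch_n:
    def closure():
        optimizer.zero_grad()
        style_score = 0
        content_score = 0
        parameter.data.clamp_(0, 1)  # keep pixel values in [0, 1]
        new_model(parameter)          # forward pass; loss modules record their losses
        for sl in style_losses:
            style_score += sl.backward()
        for cl in content_losses:
            content_score += cl.backward()
        epoch[0] += 1  # counts closure calls (L-BFGS may call this several times per step)
        if epoch[0] % 50 == 0:
            print('Epoch:{} Style Loss:{:.4f} Content Loss:{:.4f}'.format(
                epoch[0], style_score.item(), content_score.item()))
        return style_score + content_score

    optimizer.step(closure)

# The last update may leave pixels outside [0, 1], so clamp once more
parameter.data.clamp_(0, 1)
```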
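To check the whole recipe without VGG16 weights or real images, the sketch below runs the same steps (MSE content loss on one layer, Gram-matrix style loss on several, L-BFGS on the image pixels, clamping to [0, 1]) on a tiny randomly initialized network. Everything here is an assumption chosen for a fast run: the network, the random content image, the striped style image, the weights, the `strong_wolfe` line search, and the step count. It only demonstrates that the total loss decreases, not that the output looks stylized.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)


def gram(x):
    # Gram matrix normalized by the element count, as in Gram_matrix
    a, b, c, d = x.size()
    f = x.view(a * b, c * d)
    return f @ f.t() / (a * b * c * d)


# Tiny random stand-in for VGG16 features (hypothetical)
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
)
for p in net.parameters():
    p.requires_grad_(False)  # only the image is optimized


def features(x):
    # Outputs of the two conv layers (before their ReLUs)
    out1 = net[0](x)
    out2 = net[2](net[1](out1))
    return [out1, out2]


content_image = torch.rand(1, 3, 32, 32)  # random stand-in for a photo
stripes = (torch.arange(32) % 2).float().view(1, 1, 32, 1)
style_image = stripes.repeat(1, 3, 1, 32)  # horizontal stripes as the "style"
content_weight, style_weight = 1.0, 1e5   # assumed weights

with torch.no_grad():
    content_target = features(content_image)[1]               # content: 2nd conv
    style_targets = [gram(f) for f in features(style_image)]  # style: both convs


def total_loss(img):
    feats = features(img)
    content = content_weight * F.mse_loss(feats[1], content_target)
    style = style_weight * sum(F.mse_loss(gram(f), g) for f, g in zip(feats, style_targets))
    return content + style


parameter = nn.Parameter(content_image.clone())  # start from the content image
optimizer = torch.optim.LBFGS([parameter], line_search_fn='strong_wolfe')


def closure():
    optimizer.zero_grad()
    loss = total_loss(parameter)
    loss.backward()
    return loss


with torch.no_grad():
    initial = total_loss(parameter).item()
for _ in range(5):
    optimizer.step(closure)
with torch.no_grad():
    optimized = total_loss(parameter).item()
    parameter.clamp_(0, 1)  # keep pixels in [0, 1] after optimization
print(initial, optimized)
```

Clamping is done once at the end here rather than inside the closure, so the line search sees a consistent objective; the book's loop clamps before every evaluation instead.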
