您的位置 首页 PyTorch 教程

PyTorch 实现风格迁移 (style transfer)

PyTorch入门实战教程

背景介绍

不知道大家是否用过prisma,就算没有用过,也一定看见别人用过这个软件,下面是一张这个软件得到的一个效果图:

example

 

官方宣传的卖点是一秒钟让你的作品拥有名家风格,什么毕加索,梵高,都不在话下。通过这个效果再将你的照片发到朋友圈,是不是效果爆棚,简直是各种装逼界的一股清流,秒杀各种修图ps好吗。而且可以完美的掩饰掉一些瑕疵,又比ps更自然,更有逼格,是不是很棒。

这个软件将使用的方法发了一篇论文,并且这个软件在发布的时候就取得了上千万的融资,是不是瞬间感觉现在学习了知识也能成为千万富豪了。如今在这个高速发展的时代,知识付费的时代确实已经到来了,所以我们现在努力学习各种知识就是在赚钱啊,有木有。这样大家的学习的时候就能够有着更大的动力了。

这篇论文感兴趣的同学可以去查看一下,里面主要涉及的是卷积神经网络。今天这篇文章要做的是什么呢?我们希望自己能够简单的实现这个风格迁移算法,并且用自己的算法来得到新的风格图片。一想到我们放到朋友圈的照片是自己写的算法来实现的就感觉成就感爆棚,有没有。

环境配置

废话不多说,我们先来看看需要的基本配置。首先需要python环境,建议使用anaconda;然后我们使用的深度学习框架是pytorch,当然你也可以用tensorflow,具体框架的介绍可以去看看之前写的文章,需要安装pytorch和torchvision,这里查看安装帮助;同时需要一些其他的包,如果缺什么就pip安装就好。

这篇文章主要参考于pytorch的官方tutorial,感兴趣的同学可以直接移步至官方教程的地方,这篇文章我会说一些自己的理解,代码部分基本都是参考这个教程,但是我会做一些说明,力求更加清楚。

原理分析

其实要实现的东西很清晰,就是需要将两张图片融合在一起,这个时候就需要定义怎么才算融合在一起。首先需要的就是内容上是相近的,然后风格上是相似的。这样来我们就知道我们需要做的事情是什么了,我们需要计算融合图片和内容图片的相似度,或者说差异性,然后尽可能降低这个差异性;同时我们也需要计算融合图片和风格图片在风格上的差异性,然后也降低这个差异性就可以了。这样我们就能够量化我们的目标了。

对于内容的差异性我们该如何定义呢?其实我们能够很简答的想到就是两张图片每个像素点进行比较,也就是求一下差,因为简单的计算他们之间的差会有正负,所以我们可以加一个平方,使得差全部是正的,也可以加绝对值,但是数学上绝对值会破坏函数的可微性,所以大家都用平方,这个地方不理解也没关系,记住普遍都是使用平方就行了。

对于风格的差异性我们该如何定义呢?这才是一个难点。这也是这篇文章提出的创新点,引入了Gram矩阵计算风格的差异。我尽量不使用数学的语言来解释,而使用通俗的语言。 首先需要的预先知识是卷积网络的知识,这里不细讲了,不了解的同学可以看之前的卷积网络的文章。我们知道一张图片通过卷积网络之后可以的到一个特征图,Gram矩阵就是在这个特征图上面定义出来的。每个特征图的大小一般是 MxNxC 或者是 CxMxN 这种大小,这里C表示的时候厚度,放在前面和后面都可以,MxN 表示的是一个矩阵的大小,其实就是有 C 个 MxN 这样的矩阵叠在一起。

Gram矩阵是如何定义的呢?首先Gram矩阵的大小是有特征图的厚度决定的,等于 CxC,那么每一个Gram矩阵的元素,也就是 Gram(i, j) 等于多少呢?先把特征图中第 i 层和第 j 层取出来,这样就得到了两个 MxN的矩阵,然后将这两个矩阵对应元素相乘然后求和就得到了 Gram(i, j),同理 Gram 的所有元素都可以通过这个方式得到。这样 Gram 中每个元素都可以表示两层特征图的一种组合,就可以定义为它的风格。

然后风格的差异就是两幅图的 Gram 矩阵的差异,就像内容的差异的计算方法一样,计算一下这两个矩阵的差就可以量化风格的差异。

实现

以下的内容都是用pytorch实现的,如果对pytorch不熟悉的同学可以看一下我之前的pytorch介绍文章,看看官方教程,如果不想了解pytorch的同学可以用自己熟悉的框架实现这个算法,理论部分前面已经讲完了。

内容差异的loss定义

其中有个变量weight,这个是表示权重,内容和风格你可以选择一个权重,比如你想风格上更像,内容上多一点差别没关系,那么内容的权重你可以定义小一点,风格的权重可以定义大一点;反之你可以把风格的权重定义小一点,内容的权重定义大一点。

风格差异的loss定义

Gram 矩阵的定义

style loss定义

建立模型

使用19层的 vgg 作为提取特征的卷积网络,并且定义哪几层为需要的特征。

训练模型

需要特别注意的是这个模型里面参数不再是网络里面的参数,因为网络使用的是已经预训练好的 vgg 网络,这个算法里面的参数是合成图片里面的每个像素点,我们可以将内容图片直接 copy 成合成图片,然后训练使得他的风格和我们的风格图片相似,同时也可以随机化一张图片作为合成图片,然后训练他使得他与内容图片以及风格图片具有相似性。

实验结果

我们使用的风格图片为

style.png

 

内容图片为

content.png

得到的合成效果为

demo.png

结语

通过这篇文章,我们利用pytorch实现了基本的风格转移算法,得到的效果也是满意的,所以我们可以把自己的图片通过这个算法做一个风格转移,实现你想要的作品的风格,逼格满满,大家学习之后肯定会有特别大的成就感,在完成项目的同时也学习到了新的知识,同时也会对这个产生更浓厚的感兴趣,兴趣才是各种的动力,比任何鸡汤都有用,希望大家都能够找到自己的兴趣,热爱自己所做的事。


本文代码已经上传到了github上。

文章来源:知乎专栏

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(1)

  1. TypeError: backward() got an uTypeError: backward() got an unexpected keyword argument ‘retain_variables’nexpected keyword argument ‘retain_variables’

返回顶部