10分钟快速入门 PyTorch (10) – GAN

PyTorch入门实战教程

前面我们已经讲完了一般的深层网络,适用于图像的卷积神经网络,适用于序列的循环神经网络。但是要知道Lecun提出第一代卷积网络Lenet的时间是1998年,而循环神经网络提出的时间更早,是在1986年。这些网络在当时并没有火起来,如今随着计算能力的加强,数据集的增多,深度学习逐渐火了起来,随着越来越多的人的研究,各种各样的神经网络都在不断进步,CNN里面出现了inception net,resnet等等,RNN演变了LSTM和GRU,虽然神经网络不断在发展,但是本质上仍然是在CNN和RNN的基础上。

直到2014年,深度学习三巨头之一 Ian Goodfellow 提出了生成对抗网络(Generative Adversarial Networks, GANs),刚开始的时候并没有引起轰动,直到16年,学界、业界对其的兴趣出现了“井喷”,多篇重磅文章陆续发表,Lecun也形容GANs“adversarial training is the coolest thing since sliced bread.” 16年12月NIPS大会上,Goodfellow做了GANs的专题报告,使得GANs成为了当今最炙手可热的研究领域,等你看完了这篇文章你就会知道为什么GANs能够成为当今人工智能领域的主要课题之一。

GANs

GANs的全称叫做生成对抗网络,根据这个名字,你就可以猜测这个网络是由两部分组成的,第一部分是生成,第二部分是对抗。那么你已经基本猜对了,这个网络第一部分是生成网络,第二部分对抗模型严格来讲是一个判别器,简单来说呢,就是让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。

可以用这个图来简单的看一看这两个过程。

下面我们就来依次介绍。

Discriminator Network

首先我们来讲一下对抗过程,因为这个过程更加简单。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的label没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label是1表示真实的;而生成的假的图片的label是0表示假的。

我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片,这其实就是一个简单的二分类问题,对于这个问题可以用我们前面讲过的很多方法去处理,比如logistic回归,深层网络,卷积神经网络,循环神经网络都可以。

Generative Network

接着我们要看看如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,这个时候我们可以通过仿射变换,也就是xw b将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、池化、激活函数处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片,这个时候我们如何去训练这个生成器呢?就是通过判别器来得到结果,然后希望增大判别器判别这个结果为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

如下图所示

以上的过程已经简单的阐述了生成对抗网络的学习过程,如果仍然不太清楚这个过程,下面我们会通过代码来更清晰地展示整个过程。

Code

我们会使用mnist手写数字来做数据集,通过生成对抗网络我们希望生成一些“以假乱真”的手写字体。为了加快训练过程,我们不使用卷积网络来做判别器,我们使用简单的多层网络来进行判别。

Discriminator Network

以上这个网络是一个简单的多层神经网络,将图片28×28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。之所以使用LeakyRelu而不是用ReLU激活函数是因为经过实验LeakyReLU的表现更好。

Generative Network

输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。

Discriminator Train

判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。

首先我们需要定义loss的度量方式和优化函数,loss度量使用二分类的交叉熵,油画函数注意使用的学习率是0.0003

接着进入训练
我已经把每一步都注释在了代码上,这样更加便于大家阅读,这是一个判别器的训练过程,我们希望判别器能够正确辨别出真假图片。

Generative Train

在生成网络的训练中,我们希望生成一张假的图片,然后经过判别器之后希望他能够判断为真的图片,在这个过程中,我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过跟新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。

这样我们就写好了一个简单的生成网络,通过不断地训练我们希望能够生成很真的图片。

Result

通过不断训练,我们可以得到下面的图片

这是真实图片

第1幅为第一次生成的噪声图片,之后分别是跑完15次生成的图片,跑完30次,跑完50次,跑完70次,最后一个是跑完100次生成的图片

怎么样,是不是特别神奇,我们居然可以生成一副看着很真的图片,这里我们只是用了简单的多层感知器来生成和判别模型,我们可以用更复杂的卷积神经网络来做同样的事情,代码将和本文的代码放在一起,有兴趣的同学可以自己去看看,然后放几张卷积网络生成的图片

可以发现产生的噪声更少了,训练也更加稳定,主要是里面引入了Batchnormalization,另外gan的训练过程是特别困难的,两个对偶网络相互学习,这个时候有一些训练技巧可以使得训练生成更加稳定,详细见一下github。

最后我们来说一下为何Gans能够成为最近20年来机器学习以及深度学习界革命性的发现。这是因为不管是深度学习还是机器学习仍然很大一部分是监督学习,但是创建这么多有label的数据集所需要的人力物力是极大的,同时遇到的新的任务时我们很容易得到原始的没有label的数据集,这是我们需要花大量的时间去给其标定label,所以很多人都认为无监督学习才是机器学习的未来,这个时候Gans的出现为无监督学习提供了有力的支持,这当然引起了学界的大量关注,同时基于Gans的应用也越来越多,业界对其也非常狂热。

最后引用Yan Lecun的话:”它(Gans)为创建无监督学习模型提供了强有力的算法框架,有望帮助我们为 AI 加入常识(common sense)。我们认为,沿着这条路走下去,有不小的成功机会能开发出更智慧的 AI 。”

以上我们简单的介绍了Gans,通过网络实现了手写字体的生成,当然还有更多的变形和应用,有兴趣的同学可以自己阅读相关论文深入了解。

下一章我们将进入pytorch教程的最后一个部分,也是和AI联系最为紧密的一个部分,reinforcement learning,增强学习。

本文代码已经上传到了github上。

文章来源:知乎专栏

PyTorch入门实战教程
除特别注明外,本站所有文章均为 PyTorch 中文网原创,转载请注明出处:https://www.pytorchtutorial.com/10-minute-pytorch-10/

Leave a Reply

Your email address will not be published. Required fields are marked *

返回顶部