您的位置 首页 PyTorch 教程

PyTorch 实现 LR-GAN

PyTorch入门实战教程

本文主要是介绍2017年国际学习表征会议(ICLR 2017)的论文《LR-GAN:分层递归生成对抗网络进行图像生成》在pytorch上的实现。该论文由弗吉尼亚理工大学、FAIR(脸书AI实验室)、佐治亚理工学院合著。该文的第一作者Jianwei Yang近日在GitHub上发文,说明该论文如何用Pytorch实现。

在论文中,我们提出,鉴于图像本身带有结构和内容,可采用LR-GAN(分层递归生成对抗网络)以递归的方式逐层生成图像。如下图所示,LR-GAN首先生成背景图像,然后生成具有外观、姿态和形状的前景。之后,LR-GAN将以一种与情境相关的方式将前景置于相应的背景之上,从而生成一个完整的自然图像。整个模型都是无监督的,并且采用梯度下降的方法,以端到端的方式进行训练。

通过这种方式,LR-GAN便可以明显减少背景和前景之间的混合。而无论是定性或定量的比较,都表明,相较于基准DCGAN模型,LR-GAN可以产生更好、更清晰的图像。

免责声明

这是基于Pytorch的LR-GAN代码。它是以Pytorch DCGAN为基础进行开发的。我们的原始代码是在第一作者实习期间基于Torch实现的。本文所呈现的所有结果都是基于Torch代码获得的,由于版权限制不能将其发布。此项目旨在重现论文中所得到的结果。

引文

如果你发现此代码对你有所帮助,想要了解更多,请引用以下文章:

实验基础

1.PyTorch,请使用正确的命令安装PyTorch,同时确保你也安装了torchvision。

2. STNM(Spatial transformer network with mask,具有掩码的空间变换神经网络)。我们已在此项目中提供了这个模块。但是如果你想对此做些改变,请参考此项目。

训练LR-GAN

准备

将此项目置于你自己的机器上,并确保Pytorch已成功安装。然后,你需要创建一个保存训练集的数据集文件夹;一个保存生成结果的图像文件夹,以及一个保存模型(生成器和鉴别器)的模型文件夹:

接下来,你就可以尝试在数据集上训练LR-GAN模型:1)MNIST-ONE; 2)MNIST-TWO; 3)CUB-200; 4)CIFAR-10。

样本图像如下所示:

在数据集文件夹中,分别为所有这些数据集创建子文件夹:

训练

我们的模型是分别在MNIST-ONE、MNIST-TWO、CIFAR-10和CUB-200四个数据集上进行训练的。接下来我们将分别介绍在不同数据集上进行训练的操作和结果。

1.MNIST-ONE

我们首先在MNIST-ONE上进行实验,可以从这里下载。将其解压缩到datasets / mnist-one文件夹中,然后运行以下命令:

其中,ntimestep指定递归层的数量,例如,2表示一个背景层和一个前景层;imageSize是指训练图像的比例尺寸;maxobjscale是指最大目标(前景)比例,值越大,目标尺寸越小;session指定了训练会话;niter指定训练时期的数量。以下是在训练周期为50,使用训练模型的随机生成结果:

生成背景图像

前景图像

前景掩码

最终图像

2.CUB200,我们在64×64的CUB200上运行。这是处理过的数据集。同样,将其下载并解压缩到datasets / cub200。然后,运行以下命令:

基于上述指令,我们获得了与论文中所提及的相同的模型。以下是随机生成的图像:

由此我们可以看到,此中布局类似于MNIST-ONE。正如我们所看到的那样,生成器产生鸟形掩码,从而使最终的图像更为清晰。

3.CIFAR-10,CIFAR-10可以使用pytorch dataloader自动下载。我们生成只需要使用两个时间步长。为了训练模型,需要运行:

以下是一些随机采样的生成结果:

生成背景图像

前景图像

前景掩码

最终的图像

从上面几幅图片中,我们可以清楚地发现一些生成的马形,鸟形和船形掩码图像,同时,最终生成的图像也更为清晰。

4.MNIST-TWO,图像为64×64,且包含两位数字。我们使用以下命令训练模型:

可以看到,整体布局与我们论文中的布局是一样的。

5.LFW,我们训练的是64×64的图像,图像资源可以从这里下载。同样,我们需要将其解压缩到文件夹datasets / lfw。我们使用以下命令训练模型:

以下是生成结果:

真实的图像

生成背景图像

生成前景图像

掩码图像

最终图像

测试LR-GAN

训练结束后,检查点将保存到models 中。你可以再附加两个选项(netGevaluate)到用于训练模型的命令中。以cifar10为例,其代码将如下所示:

然后,你将在images文件夹中获得模型的会话为1、周期为100的生成结果。

文章来源:Github

PyTorch入门实战教程

发表评论

电子邮件地址不会被公开。 必填项已用*标注

返回顶部