您的位置 首页 PyTorch 教程

用 PyTorch 实现一个基本 GAN 网络学习正态分布

PyTorch入门实战教程

这篇文章将用 PyTorch 实现一个基本的生成对抗网络(Generative Adversarial Network, GAN),来学习一个正态分布。代码提供 Jupiter Notebook,地址见文末。

首先我们导入一些必备的库:

定义我们要学习的正态分布

我们定一个正态分布,它的均值和标准差如下:

定义生成网络(Generator)

我们的生成网络接收一些随机输入,按照上面的定义生成正态分布。你可以在代码里改变这些变量的值来看它们对最终结果的影响:

定义对抗网络(Adversarial)

我们的对抗网络输出如下:

  1. True(1.0) 如果输入的数据符合定义的正态分布
  2. False(0.0) 如果输入的数据不符合定义的正态分布

定义数据输入方式

学习率

下面的学习率你也可以试着变一下做做实验,如果太小会影响收敛。

下面的两个函数一个可以得到真正的分布,一个可以得到噪声。真正的分布用来训练 Discriminator,噪声用来作为 Generator 的输入。

生成器(Generator)

生成器用来输出符合我们想要的正态分布的均值。很简单的一个 4 层网络。

鉴别器(Discriminator)

非常简单的 Linear 模型,返回 True 或者 False。

搭建网络

我们使用 BCE 损失函数,SGD 优化函数。

训练

结果

训练完成后我们展示一些结果:

代码地址

代码见 https://github.com/rcorbish/pytorch-notebooks

PyTorch入门实战教程

发表评论

电子邮件地址不会被公开。 必填项已用*标注

返回顶部