This article implements a basic Generative Adversarial Network (GAN) in PyTorch to learn a normal distribution. The code is also available as a Jupyter Notebook; see the link at the end of the article.
First, import the required libraries:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
```
Defining the Normal Distribution to Learn
We define a normal distribution with the following mean and standard deviation:
```python
data_mean = 3.0
data_stddev = 0.4
Series_Length = 30
```
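As a quick aside (a hypothetical check, not part of the original notebook), we can sample from this target distribution directly and confirm that the empirical statistics match the values we just chose:

```python
import torch
from torch.distributions.normal import Normal

torch.manual_seed(0)

# Draw many samples from the target distribution N(3.0, 0.4)
# and check the empirical mean and standard deviation.
dist = Normal(3.0, 0.4)
samples = dist.sample((100000,))
print(samples.mean().item(), samples.std().item())  # both close to 3.0 and 0.4
```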
Defining the Generator
Our generator takes random input and produces output that should follow the normal distribution defined above. You can change these values in the code to see how they affect the final result:
```python
g_input_size = 20
g_hidden_size = 150
g_output_size = Series_Length
```
Defining the Discriminator
Our discriminator outputs:
- True (1.0) if the input data follows the defined normal distribution
- False (0.0) if the input data does not follow the defined normal distribution
```python
d_input_size = Series_Length
d_hidden_size = 75
d_output_size = 1
```
Defining the Batch and Training Settings
```python
d_minibatch_size = 15
g_minibatch_size = 10
num_epochs = 5000
print_interval = 1000
```
Learning Rates
You can also experiment with the learning rates below; if they are too small, convergence will suffer.
```python
d_learning_rate = 3e-3
g_learning_rate = 8e-3
```
The two functions below return a sampler for the real distribution and a sampler for noise, respectively. The real distribution is used to train the discriminator; the noise serves as the generator's input.
```python
def get_real_sampler(mu, sigma):
    dist = Normal(mu, sigma)
    return lambda m, n: dist.sample((m, n)).requires_grad_()

def get_noise_sampler():
    return lambda m, n: torch.rand(m, n).requires_grad_()  # Uniform-dist data into generator, _NOT_ Gaussian

actual_data = get_real_sampler(data_mean, data_stddev)
noise_data = get_noise_sampler()
```
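A small sanity check (hypothetical, not in the original notebook) confirms the shapes the two samplers produce and that the noise is uniform in [0, 1):

```python
import torch
from torch.distributions.normal import Normal

def get_real_sampler(mu, sigma):
    dist = Normal(mu, sigma)
    return lambda m, n: dist.sample((m, n)).requires_grad_()

def get_noise_sampler():
    return lambda m, n: torch.rand(m, n).requires_grad_()

# One discriminator batch of real series, one batch of generator noise,
# using the sizes defined earlier (15 x 30 and 15 x 20).
real = get_real_sampler(3.0, 0.4)(15, 30)
noise = get_noise_sampler()(15, 20)
print(real.shape, noise.shape)  # torch.Size([15, 30]) torch.Size([15, 20])
```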
The Generator
The generator outputs series whose values should follow the target normal distribution. It is a simple network of three linear layers with SELU activations.
```python
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.xfer = torch.nn.SELU()

    def forward(self, x):
        x = self.xfer(self.map1(x))
        x = self.xfer(self.map2(x))
        return self.xfer(self.map3(x))
```
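As a quick usage check (a hypothetical snippet, not from the original notebook), we can instantiate the generator with the sizes defined earlier and confirm that a batch of noise vectors maps to a batch of length-30 series:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.xfer = torch.nn.SELU()

    def forward(self, x):
        x = self.xfer(self.map1(x))
        x = self.xfer(self.map2(x))
        return self.xfer(self.map3(x))

G = Generator(20, 150, 30)   # g_input_size, g_hidden_size, g_output_size
out = G(torch.rand(8, 20))   # a batch of 8 noise vectors
print(out.shape)             # torch.Size([8, 30])
```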
The Discriminator
Another very simple stack of linear layers; the final sigmoid returns a probability between 0 (fake) and 1 (real).
```python
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        x = self.elu(self.map1(x))
        x = self.elu(self.map2(x))
        return torch.sigmoid(self.map3(x))
```
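The same kind of check works here (again a hypothetical snippet): the discriminator maps a batch of length-30 series to one score per series, and the sigmoid guarantees every score lies strictly between 0 and 1:

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        x = self.elu(self.map1(x))
        x = self.elu(self.map2(x))
        return torch.sigmoid(self.map3(x))

D = Discriminator(30, 75, 1)   # d_input_size, d_hidden_size, d_output_size
scores = D(torch.rand(8, 30))  # one probability per input series
print(scores.shape)            # torch.Size([8, 1])
```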
Building the Networks
We use BCE loss and the SGD optimizer.
```python
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)

criterion = nn.BCELoss()
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate)
```
Training
```python
def train_D_on_actual():
    real_data = actual_data(d_minibatch_size, d_input_size)
    real_decision = D(real_data)
    real_error = criterion(real_decision, torch.ones(d_minibatch_size, 1))  # ones = true
    real_error.backward()

def train_D_on_generated():
    noise = noise_data(d_minibatch_size, g_input_size)
    fake_data = G(noise)
    fake_decision = D(fake_data)
    fake_error = criterion(fake_decision, torch.zeros(d_minibatch_size, 1))  # zeros = fake
    fake_error.backward()

def train_G():
    noise = noise_data(g_minibatch_size, g_input_size)
    fake_data = G(noise)
    fake_decision = D(fake_data)
    error = criterion(fake_decision, torch.ones(g_minibatch_size, 1))
    error.backward()
    return error.item(), fake_data

losses = []
for epoch in range(num_epochs):
    D.zero_grad()
    train_D_on_actual()
    train_D_on_generated()
    d_optimizer.step()

    G.zero_grad()
    loss, generated = train_G()
    g_optimizer.step()

    losses.append(loss)
    if (epoch % print_interval) == (print_interval - 1):
        print("Epoch %6d. Loss %5.3f" % (epoch + 1, loss))

print("Training complete")
```
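For readers who want the mechanics in one place, here is a condensed single-iteration sketch of the same loop, using small hypothetical layer sizes via `nn.Sequential` so it runs instantly. Note that, as in the loop above, the fake-batch backward pass also deposits gradients in G (there is no `detach()`); `G.zero_grad()` clears them before the generator update:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal

torch.manual_seed(0)

# Tiny stand-ins for the Generator and Discriminator classes above
# (hypothetical sizes: noise 20 -> hidden 32 -> series 30 -> score 1).
G = nn.Sequential(nn.Linear(20, 32), nn.SELU(),
                  nn.Linear(32, 30), nn.SELU())
D = nn.Sequential(nn.Linear(30, 32), nn.ELU(),
                  nn.Linear(32, 1), nn.Sigmoid())
criterion = nn.BCELoss()
d_opt = optim.SGD(D.parameters(), lr=3e-3)
g_opt = optim.SGD(G.parameters(), lr=8e-3)

# Discriminator step: push real samples toward 1, fakes toward 0.
D.zero_grad()
real = Normal(3.0, 0.4).sample((15, 30))
criterion(D(real), torch.ones(15, 1)).backward()
criterion(D(G(torch.rand(15, 20))), torch.zeros(15, 1)).backward()
d_opt.step()

# Generator step: try to make D label the fakes as real (1.0).
G.zero_grad()
g_loss = criterion(D(G(torch.rand(10, 20))), torch.ones(10, 1))
g_loss.backward()
g_opt.step()
print(g_loss.item())
```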
Results
After training we plot some results:
```python
import matplotlib.pyplot as plt

def draw(data):
    plt.figure()
    d = data.tolist() if isinstance(data, torch.Tensor) else data
    plt.plot(d)
    plt.show()

d = torch.empty(generated.size(0), 53)
for i in range(0, d.size(0)):
    # detach: histc is not differentiable, and we only want the values
    d[i] = torch.histc(generated[i].detach(), min=0, max=5, bins=53)
draw(d.t())
```
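The `torch.histc` call deserves a quick illustration (a standalone example, not from the original code): with `min=0`, `max=5`, and `bins=b`, each bin covers a width of `5/b`, and every value in range is counted in exactly one bin:

```python
import torch

# Five values, one per unit-wide bin over [0, 5).
x = torch.tensor([0.5, 1.5, 2.5, 3.5, 4.5])
h = torch.histc(x, bins=5, min=0, max=5)
print(h)  # tensor([1., 1., 1., 1., 1.])
```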
Source Code
The code is available at https://github.com/rcorbish/pytorch-notebooks.