您的位置 首页 PyTorch 教程

PyTorch DQN(深度Q网络)示例

PyTorch入门实战教程

深度Q网络介绍

深度Q网络是用深度学习来解决强化中Q学习的问题,可以先了解一下Q学习的过程是一个怎样的过程,实际上就是不断的试错,从试错的经验之中寻找最优解。

关于Q学习,我看到一个非常好的例子,另外知乎上面也有相关的讨论

其实早在13年的时候,deepmind出来了第一篇用深度学习来解决Q学习的问题的paper,那个时候deepmind还不够火,和一般的Q学习不同的是,由于12年Alex率先用CNN解决图像中的high level的语义的提取,deepmind也同时采用了CNN来直接对图像进行特征提取,而非传统的进行手工特征提取。

我想从代码的角度来看一下DQN是如何实现的。

代码理解

pytorcyh的代码在官网上是有的,我也贴出了自己添加了注释的代码,以及写一下自己的对于代码的理解。

作者调用了一个gym的库,这个库可以用作强化学习的训练样本,但是蛋疼的是,在用pycharm进行debug的时候,gym库总会报错,如果直接运行则不会,我想可能是因为gym库并不可以进行调试。

anyway,代码的总体流程是,调用gym,声明一个事件,在强化学习中被称为agent,这个agent会展示当前的状态,然后会接收一个action,输出下一个的状态以及这个action所得到的奖励,ok,至于这个agent采取了action之后所得到的奖励是如何计算的。

这个agent采取了这个action下一个状态是啥,gym已经给你们写好了。

在定义网络结构之前,作者实际上是把自己试错的状态存储了起来,存储的内容有,当前的state,采取action,以及nextstate,以及这个action相应的reward,而state并不是当前游戏的截屏,而是两帧之间的差值,reward是gym自己返回的。

至于为什么这样做?有点儿类似与用RNN解决slam的问题,为什么输入到网络中的是视频两帧之间的差值,而不是视频自己本身的内容,要给自己挖个坑。

网络训练

存储了这些状态之后就可以训练网络了,主体的网络结构如下:

网络输出的两个值,分别是对应不同的action,其实也不难理解,训练的网络最终能够产生的输出当然是决策是怎样的,不过这种自己不断的试错,并且把自己试错的数据保存下来,严格意义上来说真的是无监督学习?

anyway,作者用这些试错的数据进行训练。

网络的LOSS

不过,网络的loss怎么设计?

loss如上,实际上就是求取两个Q函数之间的差值,ok,前一个Q函数的自变量描述的是当前的状态s以及对应的行为a,后一个r Q描述的是当前的reward加上,在下一个state如何采取下一步行动能够让Q最大的项。

而这两项如何在代码中体现,实际上作者定义了两个网络,一个成为policy,另外一个为target网络。

优化的目标是policy net,target网络为定期对policy的copy,如下:

policy net输入state batch,并且将实际中的对应的action的那一列输出,action非0即1,所以policy_net输出的是batch_size的列向量。

在这段代码中,这个网络的输出就是Q函数的值,target_net网络输入的是next_state,并且因为不知道其实际的action是多少,所以取最大的,输出乘以一个gamma,并且加上当前状态的reward即可。

其实永远是policy_net更新在前,更新的方向是让两个网络的输出尽可能的接近,其实也不仅仅是这样,这中间还有一个reward变量,可是为什么target_net的更新要永远滞后,一种更加极端的情况是,如果把next_state输入到policy网络中呢?

文章来源:Yongjieshi

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部