您的位置 首页 PyTorch 教程

[莫烦 PyTorch 系列教程] 4.2 – RNN 循环神经网络 (分类 Classification)

PyTorch入门实战教程

循环神经网络让神经网络有了记忆, 对于序列话的数据,循环神经网络能达到更好的效果. 如果你对循环神经网络还没有特别了解, 请观看几分钟的短动画, RNN 动画简介(如下) 和 LSTM(如下) 动画简介 能让你生动理解 RNN. 接着我们就一步一步做一个分析手写数字的 RNN 吧.

RNN 简介

LSTM 简介

MNIST手写数据

黑色的地方的值都是0, 白色的地方值大于0.

同样, 我们除了训练数据, 还给一些测试数据, 测试看看它有没有训练好.

RNN模型

和以前一样, 我们用一个 class 来建立 RNN 模型. 这个 RNN 整体流程是

  1. (input0, state0) -> LSTM -> (output0, state1) ;
  2. (input1, state1) -> LSTM -> (output1, state2) ;
  3. (inputN, stateN)-> LSTM -> (outputN, stateN 1) ;
  4. outputN -> Linear -> prediction . 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.

训练

我们将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入, 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了. 下面的代码省略了计算 accuracy 的部分, 你可以在我的 github 中看到全部代码.

最后我们再来取10个数据, 看看预测的值到底对不对:

所以这也就是在我 github 代码 中的每一步的意义啦.

文章来源:莫烦

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部