10分钟快速入门 PyTorch (9) – LSTM 词性判断

PyTorch入门实战教程

在上一节中,我们介绍了一下自然语言处理里面最基本的单边和双边的 n gram 模型,用 word embedding和n gram 模型对一句话中的某个词做预测,下面我们将使用LSTM来做判别每个词的词性,因为同一个单词有着不同的词性,比如book可以表示名词,也可以表示动词,所以我们需要训练一下网络来得到词性的判断。

LSTM 词性判断

LSTM的网络结构在之前已经介绍过了,如果忘了的同学可以去前面看看。我们首先介绍一下如何做每个词词性的判断。

首先,我们定义好一个LSTM网络,然后给出一个句子,每个句子都有很多个词构成,每个词可以用一个词向量表示,这样一句话就可以形成一个序列,我们将这个序列依次传入LSTM,然后就可以得到与序列等长的输出,每个输出都表示的是一种词性,比如名词,动词之类的,还是一种分类问题,每个单词都属于几种词性中的一种。

我们可以思考一下为什么LSTM在这个问题里面起着重要的作用。如果我们完全孤立的对一个词做词性的判断这样我们需要特别高维的词向量,但是对于LSTM,它有着一个记忆的特性,这样我们就能够通过这个单词前面记忆的一些词语来对其做一个判断,比如前面如果是my,那么他紧跟的词有很大可能就是一个名词,这样就能够充分的利用上文来做这个问题。

同时我们还可以通过引入字符来增强表达,什么意思呢?也就是说一个单词有一些前缀和后缀,比如-ly这种后缀很大可能是一个副词,这样我们就能够在字符水平得到一个词性判断的更好结果。

具体怎么做呢?还是用LSTM。每个单词有不同的字母组成,比如 apple 由a p p l e构成,我们同样给这些字符词向量,这样形成了一个长度为5的序列,然后传入另外一个LSTM网络,只取最后输出的状态层作为它的一种字符表达,我们并不需要关心到底提取出来的字符表达是什么样的,在learning的过程中这些都是会被更新的参数,使得最终我们能够正确预测。

原理看着挺让人烦的,这个时候看代码反而更快,所以如果前面的原理你没有理解清楚,那么看看代码,说不行你就恍然大悟了。

Code

准备数据

这是一个简单的训练数据,两句话,每句话的每个单词的词性由后面给出。

接着我们需要给这些单词和词性一个编码

这样每个单词就用一个数字表示,每种词性也用一个数字表示,这些之前都接触过。

同时我们需要将从a到z的字符也编码。

字符LSTM

接着我们定义字符水平的LSTM

看看上面的代码,首先定义好embedding和lstm,接着传入n个字符,然后通过nn.Embedding得到词向量,接着传入LSTM网络,得到状态输出h,然后通过h[1]得到我们想要的hidden state。

这样我们对于每个单词,通过CharLSTM就能够得到相应的字符表示。

词性LSTM

接着我们来完成我们的目标,分析每个单词的词性,首先定义好LSTM网络


看着有点复杂,我们慢慢来解释。首先n_word 和 n_dim来定义单词的词向量维度,n_char和char_dim来定义字符的词向量维度,char_hidden表示CharLSTM输出的维度,n_hidden表示每个单词作为序列输入的LSTM输出维度,最后n_tag表示输出的词性的种类。

接着开始前向传播,不仅要传入一个编码之后的句子,同时还需要传入原本的单词,因为需要对字符做一个LSTM,所以传入的参数多了一个word_data表示一个句子的所有单词。

然后就是将每个单词传入CharLSTM,得到的结果和单词的词向量拼在一起形成一个新的输入,将输入传入LSTM里面,得到输出,最后接一个全连接层,将输出维数定义为label的数目。

这就是基本的思路,我就不具体解释每句话的含义了,留给大家自己看看,特别要注意里面有一些unsqueeze和squeeze是因为LSTM的输入要求要带上batch_size,torch.cat里面0和1分别表示沿着行和列来拼接。

结果

经过300个epoch,loss降到了0.2左右

最后我们来预测一下 Everybody ate the apple 这句话每个词的词性,一共有3种词性,DET,NN,V。最后得到的结果为

一共有4行,每行里面取最大的,那么第一个词的词性就是NN,第二个词是V,第三个词是DET,第四个词是NN。这个是相符的。

以上我们介绍了RNN在图像处理以及自然语言处理上的应用,RNN还有更多的应用,比如做image captioning,机器翻译等等,感兴趣的同学可以自己在github上找一找。

下一章将是本次教程的倒数第二个部分,Generative Adversarial Networks,生成对抗网络。

本文代码已经上传到了github上。

文章来源:知乎专栏

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部