博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
转:RNN(Recurrent Neural Networks)
阅读量:6204 次
发布时间:2019-06-21

本文共 8665 字,大约阅读时间需要 28 分钟。

本文主要参考的博客所写,所有的代码都是python实现。没有使用任何深度学习的工具,公式推导虽然枯燥,但是推导一遍之后对RNN的理解会更加的深入。看本文之前建议对传统的神经网络的基本知识已经了解,如果不了解的可以看此文:『』。

所有可执行代码:

文章目录 []

语言模型

熟悉NLP的应该会比较熟悉,就是将自然语言的一句话『概率化』。具体的,如果一个句子有m个词,那么这个句子生成的概率就是:

P(w1,...,wm)=mi=1P(wiw1,...,wi1)P(w1,...,wm)=∏i=1mP(wi∣w1,...,wi−1)

其实就是假设下一次词生成的概率和只和句子前面的词有关,例如句子『He went to buy some chocolate』生成的概率可以表示为:  P(他喜欢吃巧克力) = P(他喜欢吃) * P(巧克力|他喜欢吃) 。

数据预处理

训练模型总需要语料,这里语料是来自的reddit的评论数据,语料预处理会去掉一些低频词从而控制词典大小,低频词使用一个统一标识替换(这里是UNKNOWN_TOKEN),预处理之后每一个词都会使用一个唯一的编号替换;为了学出来哪些词常常作为句子开始和句子结束,引入SENTENCE_START和SENTENCE_END两个特殊字符。具体就看代码吧:

 

网络结构

和传统的nn不同,但是也很好理解,rnn的网络结构如下图:

rnn

A recurrent neural network and the unfolding in time of the computation involved in its forward computation.

不同之处就在于rnn是一个『循环网络』,并且有『状态』的概念。

如上图,t表示的是状态, xtxt 表示的状态t的输入, stst 表示状态t时隐层的输出, otot 表示输出。特别的地方在于,隐层的输入有两个来源,一个是当前的 xtxt 输入、一个是上一个状态隐层的输出 st1st−1 , W,U,VW,U,V 为参数。使用公式可以将上面结构表示为:

sty^t=tanh(Uxt+Wst1)=softmax(Vst)st=tanh⁡(Uxt+Wst−1)y^t=softmax(Vst)

 如果隐层节点个数为100,字典大小C=8000,参数的维度信息为:

xtotstUVWR8000R8000R100R100×8000R8000×100R100×100xt∈R8000ot∈R8000st∈R100U∈R100×8000V∈R8000×100W∈R100×100

初始化

参数的初始化有很多种方法,都初始化为0将会导致『symmetric calculations 』(我也不懂),如何初始化其实是和具体的激活函数有关系,我们这里使用的是tanh,一种推荐的方式是初始化为 [1n√,1n√][−1n,1n] ,其中n是前一层接入的链接数。更多信息请点击。

 

前向传播

类似传统的nn的方法,计算几个矩阵乘法即可:

预测函数可以写为:

 

损失函数

类似nn方法,使用交叉熵作为损失函数,如果有N个样本,损失函数可以写为:

L(y,o)=1NnNynlogonL(y,o)=−1N∑n∈Nynlog⁡on

下面两个函数用来计算损失:

 

BPTT学习参数

BPTT( Backpropagation Through Time)是一种非常直观的方法,和传统的BP类似,只不过传播的路径是个『循环』,并且路径上的参数是共享的。

损失是交叉熵,损失可以表示为:

Et(yt,y^t)E(y,y^)=ytlogy^t=tEt(yt,y^t)=tytlogy^tEt(yt,y^t)=−ytlog⁡y^tE(y,y^)=∑tEt(yt,y^t)=−∑tytlog⁡y^t

其中 ytyt 是真实值, (^yt)(^yt) 是预估值,将误差展开可以用图表示为:

rnn-bptt1

所以对所有误差求W的偏导数为:

EW=tEtW∂E∂W=∑t∂Et∂W

进一步可以将 EtEt 表示为:

E3V=E3y^3y^3V=E3y^3y^3z3z3V=(y^3y3)s3∂E3∂V=∂E3∂y^3∂y^3∂V=∂E3∂y^3∂y^3∂z3∂z3∂V=(y^3−y3)⊗s3

根据链式法则和RNN中W权值共享,可以得到:

E3W=k=03E3y^3y^3s3s3skskW∂E3∂W=∑k=03∂E3∂y^3∂y^3∂s3∂s3∂sk∂sk∂W

下图将这个过程表示的比较形象

rnn-bptt-with-gradients

BPTT更新梯度的代码:

 

梯度弥散现象

tanh和sigmoid函数和导数的取值返回如下图,可以看到导数取值是[0-1],用几次链式法则就会将梯度指数级别缩小,所以传播不了几层就会出现梯度非常弱。克服这个问题的LSTM是一种最近比较流行的解决方案。

tanh

Gradient Checking

梯度检验是非常有用的,检查的原理是一个点的『梯度』等于这个点的『斜率』,估算一个点的斜率可以通过求极限的方式:

Lθlimh0J(θ+h)J(θh)2h∂L∂θ≈limh→0J(θ+h)−J(θ−h)2h

通过比较『斜率』和『梯度』的值,我们就可以判断梯度计算的是否有问题。需要注意的是这个检验成本还是很高的,因为我们的参数个数是百万量级的。

梯度检验的代码:

 

SGD实现

这个公式应该非常熟悉:

W=WλΔWW=W−λΔW

其中 ΔWΔW 就是梯度,具体代码:

 

生成文本

生成过程其实就是模型的应用过程,只需要反复执行预测函数即可:

 

参考文献

转载地址:http://rcqca.baihongyu.com/

你可能感兴趣的文章
fastlane use_legacy_build_api true
查看>>
Git 学习笔记之 merge
查看>>
ajax跨域
查看>>
【三角函数】已知直角三角形的斜边长度和一个锐角角度,求另外两条直角边的长度...
查看>>
iOS端(腾讯Bugly)闪退异常上报扑获日志集成与使用指南
查看>>
在Markdown中输入数学公式
查看>>
萌新一手包App前后端开发日记(一)
查看>>
经典问题之「分支预测」
查看>>
好程序员web前端分享MVVM框架Vue实现原理
查看>>
Github入门详情教程
查看>>
Java之Set集合的"怪"
查看>>
快速入门Pytorch(1)--安装、张量以及梯度
查看>>
Hook技术之Hook Activity
查看>>
【国际专场】laravel多用户平台(SaaS, 如淘宝多用户商城)的搭建策略
查看>>
restTemplate使用和踩坑总结
查看>>
Laravel 5 4 实现前后台登录
查看>>
Volley 源码解析之网络请求
查看>>
从零开始撸一个Kotlin Demo
查看>>
01->选中UITableViewCell后,Cell中的UILabel的背景颜色变成透明色
查看>>
3章 RxJava操作符
查看>>