使用成熟的Tensorflow、PyTorch框架去实现递归神经网络(RNN),已经极大降低了技术的使用门槛。
但是,对于初学者,这还是远远不够的。知其然,更需知其所以然。
要避免低级错误,打好理论基础,然后使用RNN去解决更多实际的问题的话。
那么,有一个有趣的问题可以思考一下:
不使用Tensorflow等框架,只有Numpy的话,你该如何构建RNN?
没有头绪也不用担心。这里便有一项教程:使用Numpy从头构建用于NLP领域的RNN。
可以带你行进一遍RNN的构建流程。
初始化参数
与传统的神经网络不同,RNN具有3个权重参数,即:
输入权重(input weights),内部状态权重(internal state weights)和输出权重(output weights)
首先用随机数值初始化上述三个参数。
之后,将词嵌入维度(word_embedding dimension)和输出维度(output dimension)分别初始化为100和80。
输出维度是词汇表中存在的唯一词向量的总数。
, (output_dim,hidden_dim))
变量prev_memory指的是internal_state(这些是先前序列的内存)。
其他参数也给予了初始化数值。
input_weight梯度,internal_state_weight梯度和output_weight梯度分别命名为dU,dW和dV。
变量bptt_truncate表示网络在反向传播时必须回溯的时间戳数,这样做是为了克服梯度消失的问题。
前向传播
输出和输入向量
例如有一句话为:I like to play.,则假设在词汇表中:
I被映射到索引2,like对应索引45,to对应索引10、**对应索引64而标点符号.** 对应索引1。
为了展示从输入到输出的情况,我们先随机初始化每个单词的词嵌入。
输入已经完成,接下来需要考虑输出。
在本项目中,RNN单元接受输入后,输出的是下一个最可能出现的单词。
用于训练RNN,在给定第t+1个词作为输出的时候将第t个词作为输入,例如:在RNN单元输出字为“like”的时候给定的输入字为“I”.
现在输入是嵌入向量的形式,而计算损失函数(Loss)所需的输出格式是独热编码(One-Hot)矢量。
这是对输入字符串中除第一个单词以外的每个单词进行的操作,因为该神经网络学习只学习的是一个示例句子,而初始输入是该句子的第一个单词。
RNN的黑箱计算
现在有了权重参数,也知道输入和输出,于是可以开始前向传播的计算。
训练神经网络需要以下计算:
其中:
U代表输入权重、W代表内部状态权重,V代表输出权重。
输入权重乘以input(x),内部状态权重乘以前一层的激活(prev_memory)。
层与层之间使用的激活函数用的是tanh。
计算损失函数
之后损失函数使用的是交叉熵损失函数,由下式给出:
最重要的是,我们需要在上面的代码中看到第5行。
正如所知,ground_truth output(y)的形式是[0,0,….,1,…0]和predicted_output(y^hat)是[0.34,0.03,……,0.45]的形式,我们需要损失是单个值来从它推断总损失。
为此,使用sum函数来获得特定时间戳下y和y^hat向量中每个值的误差之和。
total_loss是整个模型(包括所有时间戳)的损失。
反向传播
反向传播的链式法则:
如上图所示:
Cost代表误差,它表示的是y^hat到y的差值。
由于Cost是的函数输出,因此激活a所反映的变化由dCost/da表示。
实际上,这意味着从激活节点的角度来看这个变化(误差)值。
类似地,a相对于z的变化表示为da/dz,z相对于w的变化表示为dw/dz。
最终,我们关心的是权重的变化(误差)有多大。
而由于权重与Cost之间没有直接关系,因此期间各个相对的变化值可以直接相乘(如上式所示)。
RNN的反向传播
由于RNN中存在三个权重,因此我们需要三个梯度。input_weights(dLoss / dU),internal_state_weights(dLoss / dW)和output_weights(dLoss / dV)的梯度。
这三个梯度的链可以表示如下:
所述dLoss/dy_unactivated代码如下:
计算两个梯度函数,一个是multiplication_backward,另一个是additional_backward。
在multiplication_backward的情况下,返回2个参数,一个是相对于权重的梯度(dLoss / dV),另一个是链梯度(chain gradient),该链梯度将成为计算另一个权重梯度的链的一部分。
在addition_backward的情况下,在计算导数时,加法函数(ht_unactivated)中各个组件的导数为1。例如:dh_unactivated / dU_frd=1(h_unactivated = U_frd + W_frd),且dU_frd / dU_frd的导数为1。
所以,计算梯度只需要这两个函数。multiplication_backward函数用于包含向量点积的方程,addition_backward用于包含两个向量相加的方程。
至此,已经分析并理解了RNN的反向传播,目前它是在单个时间戳上实现它的功能,之后可以将其用于计算所有时间戳上的梯度。
如下面的代码所示,forward_params_t是一个列表,其中包含特定时间步长的网络的前向参数。
变量ds是至关重要的部分,因为此行代码考虑了先前时间戳的隐藏状态,这将有助于提取在反向传播时所需的信息。
对于RNN,由于存在梯度消失的问题,所以采用的是截断的反向传播,而不是使用原始的。
在此技术中,当前单元将只查看k个时间戳,而不是只看一次时间戳,其中k表示要回溯的先前单元的数量。
权重更新
一旦使用反向传播计算了梯度,则更新权重势在必行,而这些是通过批量梯度下降法
训练序列
完成了上述所有步骤,就可以开始训练神经网络了。
用于训练的学习率是静态的,还可以使用逐步衰减等更改学习率的动态方法。
)
恭喜你!你现在已经实现从头建立递归神经网络了!
那么,是时候了,继续向LSTM和GRU等的高级架构前进吧。
<section data-brushtype="text" style="padding-right: 0em;padding-left: 0em;white-space: normal;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-family: "Helvetica Neue", Helvetica, "Hiragino Sans GB", "Microsoft YaHei", Arial, sans-serif;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section><pre style="padding-right: 0em;padding-left: 0em;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">干货 | 算法工程师超实用技术路线图</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">13个算法工程师必须掌握的PyTorch Tricks</span><span style="letter-spacing: 0.544px;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;"></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">吴恩达上新:生成对抗网络(GAN)专项课程</span><br /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">拿到2021灰飞烟灭算法岗offer的大佬们是啥样的<span style="font-size: 14px;">?</span><br /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">你一定从未看过如此通俗易懂的YOLO系列解读 (下)</section></section></section></section></section></section></section></section></section>
本篇文章来源于: 深度学习这件小事
本文为原创文章,版权归知行编程网所有,欢迎分享本文,转载请保留出处!
内容反馈