Alex_McAvoy

想要成为渔夫的猎手

经典循环神经网络之 LSTM

References:

【概述】

长短期记忆(Long Short-Term Memory,LSTM)单元是 RNN 存储单元的一种变体,使用 LSTM 单元的 RNN 可以解决长序列数据训练过程中的梯度爆炸与梯度消失问题

在使用 LSTM 单元的 RNN 网络中,常规的存储单元被 LSTM 单元替代,其在 RNN 的存储单元的基础上,添加了三个门控函数,用于控制每一时间步信息的记忆与遗忘

  • 输入门(Input Gate):在每一时刻从输入层输入的信息,会先经过输入门,输入门的开关会决定该时刻是否会有信息输入到存储单元中
  • 输出门(Output Gate):输出门的开关会决定每一时刻的信息是否会从存储单元中输出
  • 遗忘门(Forget Gate):在每一时刻存储单元的值会经历一个遗忘过程,遗忘门的开关会决定是否遗忘存储单元中的信息

【LSTM 单元结构】

单个 LSTM 单元的内部结构如下图所示:

相比于 RNN 只有一个传递状态 $s_t$,LSTM 有两个传递状态,即单元状态 $c_t$ 和隐藏状态 $h_t$,其中,RNN 中的 $s_t$ 对应于 LSTM 中的 $c_t$

在 $t$ 时刻,符号假设如下:

  • $h_t$:$t$ 时刻的 LSTM 单元隐藏输出,即隐藏状态(Hidden State)
  • $c_t$:$t$ 时刻的 LSTM 单元存储单元的值,即单元状态(Cell State)
  • $x_t$:$t$ 时刻的输入数据

【前向传播】

单元输入与门控状态

LSTM 单元内部使用当前输入 $x_t$ 和上一时间步传递下来的 $h_{t-1}$ 拼接训练来得到四种状态:$Z$、$Z_i$、$Z_f$、$Z_o$

$Z$ 为 LSTM 单元的输入数据,从上一时刻到下一时刻隐藏层的连接权重矩阵为 $W$,偏置项为 $\mathbf{b}$,有:

即 $Z$ 是 $t$ 时刻的输入数据 $x_t$ 与上一时刻 $t-1$ 的隐藏状态 $h_{t-1}$ 经过向量拼接后,与权重矩阵 $W$ 点积后加上偏置项的值,经过 tanh 激活函数后,得到一个 $-1$ 到 $1$ 的值,作为 LSTM 单元的单元输入

输入门、遗忘门、输出门的输出分别为 $Z_i,Z_f,Z_o$,连接权重矩阵分别为 $W_i,W_f,W_o$,偏置项分别为 $\mathbf{b}_i,\mathbf{b}_f,\mathbf{b}_o$,有:

即 $t$ 时刻的输入数据 $x_t$ 与上一时刻 $t-1$ 的隐藏状态 $h_{t-1}$ 经过向量拼接后,与权重矩阵点积后加上偏置项的值,经过 sigmoid 激活函数后,得到一个 $0$ 到 $1$ 间的值,作为一种门控信号

三个阶段

在每个时间步,得到四个状态 $Z$、$Z_i$、$Z_f$、$Z_o$ 后,即进行 LSTM 单元内部的三个阶段

1)遗忘阶段:该阶段会针对上一时间步传进来的输入进行选择性忘记,即通过遗忘门 $Z_f$ 来控制上一时间步的单元状态 $c_{t-1}$

其中,$\odot$ 为哈达玛积(Hadamard Product),即对相同形状的矩阵中的对应元素相乘

2)选择记忆阶段:该阶段会将当前时间步的 LSTM 单元的输入进行选择性记忆,即通过输入门 $Z_i$ 来当前时间步控制 LSTM 单元的输入 $Z$

将遗忘阶段和选择记忆阶段的结果相加,即得到当前时间步的单元状态 $c_t$,即:

3)输出阶段:该阶段将决定哪些会被当成当前时间步的输出,即通过遗忘门 $Z_f$ 来控制由 1)、2)阶段得到的当前时间步的单元状态 $c_t$,同时,在经过遗忘门 $Z_f$ 处理前,还使用了 tanh 激活函数对 $c_t$ 进行放缩

最后,与普通的 RNN 类似,输出 $y_t$ 往往是通过 $h_t$ 变化得到

【梯度消失的解决】

在 RNN 中,求偏导过程中包含 $\prod\limits_{j=k+1}^t \frac{\partial s_j}{\partial s_{j-1}} = \prod\limits_{j=k+1}^t f’ \cdot W$ 项,根据隐藏层激活函数的导数 $f’$ 的取值,造成梯度消失或梯度爆炸问题

而在 LSTM 中,求偏导过程中同样包含类似于 $\prod\limits_{j=k+1}^t \frac{\partial s_j}{\partial s_{j-1}}$ 的项,即:

对于当前时间步的单元状态 $c_t$,有

将其展开,可得:

故有:

设 $P = \tanh’(x) \sigma(y)$,那么 $P$ 的函数图像如下所示

可以看得到,该函数的值基本上非 $0$ 即 $1$,那么有:

这就解决了 RNN 中的梯度消失问题

感谢您对我的支持,让我继续努力分享有用的技术与知识点!