References:
【概述】
长短期记忆(Long Short-Term Memory,LSTM)单元是 RNN 存储单元的一种变体,使用 LSTM 单元的 RNN 可以解决长序列数据训练过程中的梯度爆炸与梯度消失问题
在使用 LSTM 单元的 RNN 网络中,常规的存储单元被 LSTM 单元替代,其在 RNN 的存储单元的基础上,添加了三个门控函数,用于控制每一时间步信息的记忆与遗忘
- 输入门(Input Gate):在每一时刻从输入层输入的信息,会先经过输入门,输入门的开关会决定该时刻是否会有信息输入到存储单元中
- 输出门(Output Gate):输出门的开关会决定每一时刻的信息是否会从存储单元中输出
- 遗忘门(Forget Gate):在每一时刻存储单元的值会经历一个遗忘过程,遗忘门的开关会决定是否遗忘存储单元中的信息
【LSTM 单元结构】
单个 LSTM 单元的内部结构如下图所示:
相比于 RNN 只有一个传递状态 $s_t$,LSTM 有两个传递状态,即单元状态 $c_t$ 和隐藏状态 $h_t$,其中,RNN 中的 $s_t$ 对应于 LSTM 中的 $c_t$
在 $t$ 时刻,符号假设如下:
- $h_t$:$t$ 时刻的 LSTM 单元隐藏输出,即隐藏状态(Hidden State)
- $c_t$:$t$ 时刻的 LSTM 单元存储单元的值,即单元状态(Cell State)
- $x_t$:$t$ 时刻的输入数据
【前向传播】
单元输入与门控状态
LSTM 单元内部使用当前输入 $x_t$ 和上一时间步传递下来的 $h_{t-1}$ 拼接训练来得到四种状态:$Z$、$Z_i$、$Z_f$、$Z_o$
$Z$ 为 LSTM 单元的输入数据,从上一时刻到下一时刻隐藏层的连接权重矩阵为 $W$,偏置项为 $\mathbf{b}$,有:
即 $Z$ 是 $t$ 时刻的输入数据 $x_t$ 与上一时刻 $t-1$ 的隐藏状态 $h_{t-1}$ 经过向量拼接后,与权重矩阵 $W$ 点积后加上偏置项的值,经过 tanh 激活函数后,得到一个 $-1$ 到 $1$ 的值,作为 LSTM 单元的单元输入
输入门、遗忘门、输出门的输出分别为 $Z_i,Z_f,Z_o$,连接权重矩阵分别为 $W_i,W_f,W_o$,偏置项分别为 $\mathbf{b}_i,\mathbf{b}_f,\mathbf{b}_o$,有:
即 $t$ 时刻的输入数据 $x_t$ 与上一时刻 $t-1$ 的隐藏状态 $h_{t-1}$ 经过向量拼接后,与权重矩阵点积后加上偏置项的值,经过 sigmoid 激活函数后,得到一个 $0$ 到 $1$ 间的值,作为一种门控信号
三个阶段
在每个时间步,得到四个状态 $Z$、$Z_i$、$Z_f$、$Z_o$ 后,即进行 LSTM 单元内部的三个阶段
1)遗忘阶段:该阶段会针对上一时间步传进来的输入进行选择性忘记,即通过遗忘门 $Z_f$ 来控制上一时间步的单元状态 $c_{t-1}$
其中,$\odot$ 为哈达玛积(Hadamard Product),即对相同形状的矩阵中的对应元素相乘
2)选择记忆阶段:该阶段会将当前时间步的 LSTM 单元的输入进行选择性记忆,即通过输入门 $Z_i$ 来当前时间步控制 LSTM 单元的输入 $Z$
将遗忘阶段和选择记忆阶段的结果相加,即得到当前时间步的单元状态 $c_t$,即:
3)输出阶段:该阶段将决定哪些会被当成当前时间步的输出,即通过遗忘门 $Z_f$ 来控制由 1)、2)阶段得到的当前时间步的单元状态 $c_t$,同时,在经过遗忘门 $Z_f$ 处理前,还使用了 tanh 激活函数对 $c_t$ 进行放缩
最后,与普通的 RNN 类似,输出 $y_t$ 往往是通过 $h_t$ 变化得到
【梯度消失的解决】
在 RNN 中,求偏导过程中包含 $\prod\limits_{j=k+1}^t \frac{\partial s_j}{\partial s_{j-1}} = \prod\limits_{j=k+1}^t f’ \cdot W$ 项,根据隐藏层激活函数的导数 $f’$ 的取值,造成梯度消失或梯度爆炸问题
而在 LSTM 中,求偏导过程中同样包含类似于 $\prod\limits_{j=k+1}^t \frac{\partial s_j}{\partial s_{j-1}}$ 的项,即:
对于当前时间步的单元状态 $c_t$,有
将其展开,可得:
故有:
设 $P = \tanh’(x) \sigma(y)$,那么 $P$ 的函数图像如下所示
可以看得到,该函数的值基本上非 $0$ 即 $1$,那么有:
这就解决了 RNN 中的梯度消失问题