梯度消失
一、一句话定义
梯度消失:反向传播时,梯度随着网络层数加深而指数级衰减,逐渐趋近于 0,导致浅层(靠近输入层)的参数几乎不更新,模型学不到底层的特征。
二、具体是怎样发生的?
2.1 核心原因:也是连乘效应
和梯度爆炸一样,梯度消失也是链式法则连乘的结果。区别在于:
- 梯度爆炸:乘数都 > 1 → 越乘越大
- 梯度消失:乘数都 < 1 → 越乘越小
python
1 | # 一个3层网络的梯度传播(和前一讲相同的形式) |
如果每个乘数都小于 1,梯度就会指数级衰减到 0。
2.2 触发条件
| 条件 | 说明 | 示例 |
|---|---|---|
| 激活函数导数 < 1 | sigmoid 导数 ≤ 0.25,tanh 导数 ≤ 1 | sigmoid 饱和区导数接近 0 |
| 权重值 < 1 | 初始化太小或训练中权重变小 | W = [[0.1, 0.2], [0.1, -0.2]] |
| 网络很深 | 层数越多,连乘次数越多 | 10 层:0.25^10 ≈ 9.5e-7 |
| 激活函数进入饱和区 | 输入太大或太小,sigmoid/tanh 导数 ≈ 0 | sigmoid(10) ≈ 1,导数 ≈ 0 |
2.3 数字演算(以 sigmoid 为例)
python
1 | sigmoid 函数特性: |
注意:sigmoid 的实际导数往往小于 0.25(因为大多数输入不在 0 附近),所以实际衰减比这个例子更快。
三、梯度消失会导致什么后果?
3.1 浅层参数不更新
python
1 | # 以 10 层网络为例 |
3.2 模型“学不动”了
- 深层(靠近输出):勉强能学到一些东西,但因为没有浅层特征支撑,效果很差
- 浅层(靠近输入):几乎没被训练,仍然是随机初始化的状态
- 整体效果:模型的 loss 下降极慢,或者卡在一个很高的值下不去
3.3 深层网络反而比浅层网络差
这是最反直觉的现象:
python
1 | # 现象 |
四、哪些场景特别容易发生梯度消失?
4.1 使用 sigmoid 或 tanh 激活函数
| 激活函数 | 导数范围 | 最大导数 | 风险 |
|---|---|---|---|
| sigmoid | (0, 0.25] | 0.25 | 🔴 极高 |
| tanh | (0, 1] | 1.0 | 🟡 中等 |
| ReLU | {0, 1} | 1.0 | 🟢 低 |
| LeakyReLU | {0.01, 1} | 1.0 | 🟢 低 |
sigmoid 的导数最大值只有 0.25,连乘几层就消失了。这也是为什么现代深度学习几乎不用 sigmoid 作为隐藏层激活函数。
4.2 权重初始化太小
python
1 | # 错误的初始化方式 |
4.3 网络层数过深
即使是 ReLU,没有残差连接的话,深层网络梯度消失风险也很大。因为:
- 如果某层激活函数输出为 0(ReLU 死亡),梯度就断了
- 权重的随机性可能导致部分层有效梯度 < 1
4.4 RNN 处理长序列
RNN 的时间步展开和深度网络类似。
- 梯度消失:RNN 记不住很久以前的信息(LSTM 部分解决)
- 梯度爆炸:RNN 在某个时间步突然崩溃(梯度裁剪解决)
五、如何解决?(重点)
| 方法 | 原理 | 一句话评价 |
|---|---|---|
| ReLU / 变体激活函数 | 正区间导数为 1,不衰减梯度 | 🥇 最关键、最有效的改进 |
| 残差连接 | 梯度有一条“高速公路”直通浅层 | 🥇 Transformer 等架构的基础 |
| LayerNorm / BatchNorm | 稳定输入分布,防止进入饱和区 | 🥈 间接但很重要 |
| 合理初始化 | 让每层输出的方差保持为 1 | 🥈 治本,从源头防止 |
| LSTM / GRU | 门控机制让梯度可以长时间保持 | 🥇 RNN 场景的最佳方案 |
5.1 使用 ReLU 及其变体(最关键)
python
1 | # ReLU:最简单、最常用 |
ReLU 带来的问题:
- ReLU 死亡:如果输入始终为负,神经元输出恒为 0,梯度为 0,永远无法复活
- 解决方案:LeakyReLU、ELU、或更小的学习率 + 合理的初始化
5.2 残差连接(Add)
原理:让梯度有一条不经过任何权重乘法的直通路径。
python
1 | # 残差块 |
梯度传播时:
text
1 | ∂loss/∂x = ∂loss/∂output × ∂output/∂x |
效果:即使 ∂layer/∂x 的梯度消失为 0,∂loss/∂x 仍然有 ∂loss/∂output,不会消失。
5.3 使用 LayerNorm / BatchNorm
原理:把激活函数的输入稳定在非饱和区(比如均值 0、方差 1 附近),让导数保持在较大值。
python
1 | # 以 sigmoid 为例 |
5.4 合理的权重初始化(Xavier / Kaiming)
原理:让每层的输出方差保持为 1,输入信号不衰减。
python
1 | # Xavier 初始化(适合 tanh / sigmoid) |
5.5 使用 LSTM / GRU(RNN 专用)
原理:引入门控机制(遗忘门、输入门、输出门),让梯度可以选择性地通过,不会在每个时间步都衰减。
python
1 | # LSTM 的核心思想 |
六、如何判断发生了梯度消失?
6.1 明显的信号
| 现象 | 说明 |
|---|---|
| Loss 下降极慢 | 训练了很久,loss 还是很高 |
| 浅层参数几乎不变 | 打印浅层权重的变化,几乎为 0 |
| 梯度值极小 | 浅层梯度在 1e-6 以下,甚至为 0 |
| 深层效果好,浅层效果差 | 靠近输出的层学到了一些东西,但整体模型很差 |
| 模型不收敛 | loss 卡在一个值,怎么训练都不降 |
6.2 监控代码
python
1 | # 训练循环中监控各层梯度的分布 |
七、梯度消失 vs 梯度爆炸:对比总结
| 维度 | 梯度消失 | 梯度爆炸 |
|---|---|---|
| 现象 | 梯度 → 0 | 梯度 → ∞ |
| 原因 | 乘数 < 1 连乘 | 乘数 > 1 连乘 |
| 后果 | 浅层学不动,模型欠拟合 | Loss 变 NaN,模型崩溃 |
| 发生概率 | 更常见 | 相对少见 |
| 主要元凶 | sigmoid/tanh 的导数 < 1 | 权重初始化太大 |
| 高发场景 | 深层网络、RNN | RNN、初始化不当 |
| 直接解法 | 用 ReLU | 梯度裁剪 |
| 治本解法 | 残差连接 + 归一化 | 合理初始化 + 归一化 |