梯度爆炸
一、一句话定义
梯度爆炸:反向传播时,梯度随着网络层数加深而指数级增长,变成天文数字(如 10^8 甚至 inf),导致参数更新步长巨大,模型直接崩溃。
二、具体是怎样发生的?
2.1 核心原因:连乘效应
反向传播时,梯度从输出层往输入层传播,每经过一层都要乘以该层权重的某种组合。如果这些乘数都大于1,梯度就会指数级膨胀。
python
1 | # 一个3层网络的梯度传播(简化示意) |
2.2 触发条件
| 条件 | 说明 | 示例 |
|---|---|---|
| 权重值 > 1 | 初始化或训练中权重变得较大 | W = [[2.0, 1.5], [1.5, 2.0]] |
| 没有饱和激活函数 | ReLU在正区间导数为1,不会压制梯度 | ReLU(x) = max(0,x) 导数恒为1 |
| 网络很深 | 层数越多,连乘次数越多 | 10层:2^10=1024倍放大 |
| 重复使用权重 | RNN在同一权重上循环乘多次 | 时间步100步 = 乘100次 |
2.3 数字演算
python
1 | 假设: |
三、梯度爆炸会导致什么后果?
3.1 参数变成 NaN
python
1 | # 正常的参数更新公式 |
3.2 Loss 变成 NaN(Not a Number)
因为参数数值巨大,前向传播时计算的中间值也会巨大。一旦超过浮点数表示范围(float32 上限约 3.4e38),就会变成 inf(无穷大)。inf 参与任何运算,结果都是 NaN。
3.3 模型无法恢复
一旦参数中出现 NaN,后续所有计算都是 NaN。模型永远回不到正常状态,训练必须中断并从头重启。
四、哪些场景特别容易发生梯度爆炸?
4.1 RNN / LSTM 处理长序列
RNN 在每个时间步使用同一个权重矩阵。如果这个矩阵的特征值大于 1,梯度经过 100 个时间步的反向传播,就会被放大 100 次方倍。
python
1 | # RNN 的循环结构 |
这也是为什么后来有了 LSTM 和 GRU——它们用门控机制来缓解这个问题。
4.2 初始化权重过大
python
1 | # 错误的初始化方式 |
4.3 层数过深且没有保护机制
Transformer 有 12~24 层,如果不用残差连接和层归一化,梯度爆炸风险很高。
五、如何解决?(重点)
| 方法 | 原理 | 一句话评价 |
|---|---|---|
| 梯度裁剪 | 限制梯度的最大值 | 最直接、最常用、几乎零成本 |
| 合理初始化 | 让权重尺度 ≈ 1/√fan_in | 治本,从源头防止 |
| LayerNorm/BatchNorm | 稳定每层数值范围 | 间接防止,顺便解决其他问题 |
| 残差连接 | 梯度有一条“高速公路” | 让梯度可以不经过权重乘法 |
| 降低学习率 | 减小更新步长 | 治标不治本,但有帮助 |
5.1 梯度裁剪(最推荐)
原理:计算所有梯度的整体范数(可以理解为“梯度的总长度”),如果超过设定阈值,就按比例缩小,但保持方向不变。
python
1 | # PyTorch 中的用法(只需一行代码) |
优点:
- 直接有效,99% 的情况下能解决爆炸
- 计算成本极低(只多一步范数计算)
- 对训练速度影响微乎其微
缺点:
- 只是“压在合理范围”,没有解决爆炸的根源
5.2 合理的权重初始化
原理:让每层输出的方差保持为 1,避免信号和梯度指数放大。
python
1 | # Xavier / Glorot 初始化(适合 tanh / sigmoid) |
效果:权重尺度约在 1 / √fan_in 左右,不会显著大于 1。
5.3 使用归一化层
LayerNorm(Transformer 用)和 BatchNorm(CNN 用)通过稳定每一层的数值范围,间接防止梯度爆炸。
5.4 使用残差连接
残差连接 output = x + F(x) 让梯度有一条直通路径:∂loss/∂x 可以直接从输出传到输入,不经过任何权重矩阵乘法。这条路径上的梯度不会爆炸。
六、如何判断发生了梯度爆炸?
6.1 明显的信号
| 现象 | 说明 |
|---|---|
Loss 突然变成 NaN |
最明显的信号,一眼就能看出来 |
| Loss 在某一步突然飙升 | 比如从 2.3 跳到 1e8,然后变 NaN |
模型参数出现 inf |
print(param) 能看到 inf |
| 梯度值巨大 | 正常梯度范数在 0.1~10 之间,爆炸时可达 1e6 以上 |
6.2 监控代码
python
1 | # 训练循环中加入梯度监控 |
七、总结:一张表说清楚
| 问题 | 答案 |
|---|---|
| 什么是梯度爆炸? | 梯度随层数指数增长,变成天文数字 |
| 怎么发生的? | 每层权重 > 1,连乘效应导致指数爆炸 |
| 后果是什么? | 参数飞上天 → Loss 变 NaN → 模型崩溃 |
| 最直接的解法? | 梯度裁剪(一行代码,立刻生效) |
| 治本的解法? | 合理初始化 + 归一化 + 残差连接 |
| 高发场景? | RNN、长序列、深层网络、初始化太大 |