梯度爆炸

AI

一、一句话定义

梯度爆炸:反向传播时,梯度随着网络层数加深而指数级增长,变成天文数字(如 10^8 甚至 inf),导致参数更新步长巨大,模型直接崩溃。


二、具体是怎样发生的?

2.1 核心原因:连乘效应

反向传播时,梯度从输出层往输入层传播,每经过一层都要乘以该层权重的某种组合。如果这些乘数都大于1,梯度就会指数级膨胀。

python

1
2
3
4
# 一个3层网络的梯度传播(简化示意)
∂loss/∂W₁ = ∂loss/∂h₃ · W₃ · ∂h₃/∂h₂ · W₂ · ∂h₂/∂h₁ · W₁ · ∂h₁/∂W₁
↑ ↑
从loss传来的初始梯度 每过一层就多乘一个W

2.2 触发条件

条件 说明 示例
权重值 > 1 初始化或训练中权重变得较大 W = [[2.0, 1.5], [1.5, 2.0]]
没有饱和激活函数 ReLU在正区间导数为1,不会压制梯度 ReLU(x) = max(0,x) 导数恒为1
网络很深 层数越多,连乘次数越多 10层:2^10=1024倍放大
重复使用权重 RNN在同一权重上循环乘多次 时间步100步 = 乘100次

2.3 数字演算

python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
假设:
- 每层权重的“尺度”约为 2
- 激活函数用 ReLU(导数恒为 1)
- 初始梯度(从 loss 传下来)= 1

经过第 1 层后:梯度 = 1 × 2 = 2
经过第 2 层后:梯度 = 2 × 2 = 4
经过第 3 层后:梯度 = 4 × 2 = 8
...
经过第 10 层后:梯度 = 1 × 2^10 = 1024
经过第 20 层后:梯度 = 1 × 2^20 ≈ 1,048,576(百万级)
经过第 30 层后:梯度 = 1 × 2^30 ≈ 1,073,741,824(十亿级)

# 指数增长,非常可怕

三、梯度爆炸会导致什么后果?

3.1 参数变成 NaN

python

1
2
3
4
5
6
7
# 正常的参数更新公式
w_new = w_old - learning_rate × gradient

# 假设 gradient = 1e9,learning_rate = 0.001
# 更新步长 = 1e9 × 0.001 = 1e6
# w_old 原本可能是 0.5,一步之后变成 1000000.5
# 下一轮前向传播,w × x 直接爆掉

3.2 Loss 变成 NaN(Not a Number)

因为参数数值巨大,前向传播时计算的中间值也会巨大。一旦超过浮点数表示范围(float32 上限约 3.4e38),就会变成 inf(无穷大)。inf 参与任何运算,结果都是 NaN

3.3 模型无法恢复

一旦参数中出现 NaN,后续所有计算都是 NaN。模型永远回不到正常状态,训练必须中断并从头重启


四、哪些场景特别容易发生梯度爆炸?

4.1 RNN / LSTM 处理长序列

RNN 在每个时间步使用同一个权重矩阵。如果这个矩阵的特征值大于 1,梯度经过 100 个时间步的反向传播,就会被放大 100 次方倍。

python

1
2
3
4
# RNN 的循环结构
h_t = tanh(W_h @ h_{t-1} + W_x @ x_t)
# 梯度从第100步传回第1步,需要乘100次 W_h
# 如果 W_h 的特征值 > 1 → 爆炸

这也是为什么后来有了 LSTM 和 GRU——它们用门控机制来缓解这个问题。

4.2 初始化权重过大

python

1
2
3
# 错误的初始化方式
W = torch.randn(512, 512) * 2 # 均值为0,标准差为2
# 每层输出范围放大2倍,梯度也随层数指数放大

4.3 层数过深且没有保护机制

Transformer 有 12~24 层,如果不用残差连接和层归一化,梯度爆炸风险很高。


五、如何解决?(重点)

方法 原理 一句话评价
梯度裁剪 限制梯度的最大值 最直接、最常用、几乎零成本
合理初始化 让权重尺度 ≈ 1/√fan_in 治本,从源头防止
LayerNorm/BatchNorm 稳定每层数值范围 间接防止,顺便解决其他问题
残差连接 梯度有一条“高速公路” 让梯度可以不经过权重乘法
降低学习率 减小更新步长 治标不治本,但有帮助

5.1 梯度裁剪(最推荐)

原理:计算所有梯度的整体范数(可以理解为“梯度的总长度”),如果超过设定阈值,就按比例缩小,但保持方向不变

python

1
2
3
4
5
6
7
8
9
10
11
12
# PyTorch 中的用法(只需一行代码)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 具体效果演示
原始梯度 = [0.5, 100, 0.3, 500]
计算整体范数 ≈ sqrt(0.5² + 100² + 0.3² + 500²) ≈ 506

max_norm = 1.0
缩放因子 = 1.0 / 506 ≈ 0.002

裁剪后 ≈ [0.001, 0.198, 0.0006, 0.988]
# 方向没变(比例关系保留),但整体大小被限制在1以内

优点

  • 直接有效,99% 的情况下能解决爆炸
  • 计算成本极低(只多一步范数计算)
  • 对训练速度影响微乎其微

缺点

  • 只是“压在合理范围”,没有解决爆炸的根源

5.2 合理的权重初始化

原理:让每层输出的方差保持为 1,避免信号和梯度指数放大。

python

1
2
3
4
5
# Xavier / Glorot 初始化(适合 tanh / sigmoid)
nn.init.xavier_uniform_(W)

# Kaiming / He 初始化(适合 ReLU)
nn.init.kaiming_uniform_(W, nonlinearity='relu')

效果:权重尺度约在 1 / √fan_in 左右,不会显著大于 1。

5.3 使用归一化层

LayerNorm(Transformer 用)和 BatchNorm(CNN 用)通过稳定每一层的数值范围,间接防止梯度爆炸。

5.4 使用残差连接

残差连接 output = x + F(x) 让梯度有一条直通路径∂loss/∂x 可以直接从输出传到输入,不经过任何权重矩阵乘法。这条路径上的梯度不会爆炸。


六、如何判断发生了梯度爆炸?

6.1 明显的信号

现象 说明
Loss 突然变成 NaN 最明显的信号,一眼就能看出来
Loss 在某一步突然飙升 比如从 2.3 跳到 1e8,然后变 NaN
模型参数出现 inf print(param) 能看到 inf
梯度值巨大 正常梯度范数在 0.1~10 之间,爆炸时可达 1e6 以上

6.2 监控代码

python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 训练循环中加入梯度监控
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
if grad_norm > 100:
print(f"⚠️ 梯度爆炸警告!{name}: {grad_norm}")
elif grad_norm > 10:
print(f"🔔 梯度偏大:{name}: {grad_norm}")

# 或者监控整体梯度范数
total_norm = 0
for p in model.parameters():
if p.grad is not None:
total_norm += p.grad.norm().item() ** 2
total_norm = total_norm ** 0.5
print(f"全局梯度范数: {total_norm}") # 正常 <10,爆炸时 >1000

七、总结:一张表说清楚

问题 答案
什么是梯度爆炸? 梯度随层数指数增长,变成天文数字
怎么发生的? 每层权重 > 1,连乘效应导致指数爆炸
后果是什么? 参数飞上天 → Loss 变 NaN → 模型崩溃
最直接的解法? 梯度裁剪(一行代码,立刻生效)
治本的解法? 合理初始化 + 归一化 + 残差连接
高发场景? RNN、长序列、深层网络、初始化太大
站内搜索
常搜: