梯度消失

AI

一、一句话定义

梯度消失:反向传播时,梯度随着网络层数加深而指数级衰减,逐渐趋近于 0,导致浅层(靠近输入层)的参数几乎不更新,模型学不到底层的特征。


二、具体是怎样发生的?

2.1 核心原因:也是连乘效应

和梯度爆炸一样,梯度消失也是链式法则连乘的结果。区别在于:

  • 梯度爆炸:乘数都 > 1 → 越乘越大
  • 梯度消失:乘数都 < 1 → 越乘越小

python

1
2
3
4
# 一个3层网络的梯度传播(和前一讲相同的形式)
∂loss/∂W₁ = ∂loss/∂h₃ · W₃ · ∂h₃/∂h₂ · W₂ · ∂h₂/∂h₁ · W₁ · ∂h₁/∂W₁
↑ ↑
从loss传来的初始梯度 每过一层就多乘一个数

如果每个乘数都小于 1,梯度就会指数级衰减到 0。

2.2 触发条件

条件 说明 示例
激活函数导数 < 1 sigmoid 导数 ≤ 0.25,tanh 导数 ≤ 1 sigmoid 饱和区导数接近 0
权重值 < 1 初始化太小或训练中权重变小 W = [[0.1, 0.2], [0.1, -0.2]]
网络很深 层数越多,连乘次数越多 10 层:0.25^10 ≈ 9.5e-7
激活函数进入饱和区 输入太大或太小,sigmoid/tanh 导数 ≈ 0 sigmoid(10) ≈ 1,导数 ≈ 0

2.3 数字演算(以 sigmoid 为例)

python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
sigmoid 函数特性:
- 输出范围:(0, 1)
- 最大值处导数:σ'(0) = 0.25
- 饱和区(|x|>5)导数:≈ 0.00...

假设:
- 每层激活函数的导数 ≈ 0.25(乐观估计)
- 初始梯度(从 loss 传下来)= 1

经过第 1 层后:梯度 = 1 × 0.25 = 0.25
经过第 2 层后:梯度 = 0.25 × 0.25 = 0.0625
经过第 3 层后:梯度 = 0.0625 × 0.25 = 0.015625
...
经过第 10 层后:梯度 = 1 × 0.25^10 ≈ 9.5 × 10⁻⁷(不到百万分之一)
经过第 20 层后:梯度 = 1 × 0.25^20 ≈ 9.1 × 10⁻¹³(几乎为 0)

# 指数衰减,同样可怕

注意:sigmoid 的实际导数往往小于 0.25(因为大多数输入不在 0 附近),所以实际衰减比这个例子更快。


三、梯度消失会导致什么后果?

3.1 浅层参数不更新

python

1
2
3
4
5
6
# 以 10 层网络为例
第 10 层(靠近输出):梯度 ≈ 1 → 参数正常更新 ✓
第 9 层:梯度 ≈ 0.25 → 参数更新变慢 ⚠️
第 8 层:梯度 ≈ 0.0625 → 几乎不动 ⚠️
第 7 层:梯度 ≈ 0.0156 → 基本不动 ❌
第 6 层及以下:梯度 ≈ 0.000... → 完全不动 ❌

3.2 模型“学不动”了

  • 深层(靠近输出):勉强能学到一些东西,但因为没有浅层特征支撑,效果很差
  • 浅层(靠近输入):几乎没被训练,仍然是随机初始化的状态
  • 整体效果:模型的 loss 下降极慢,或者卡在一个很高的值下不去

3.3 深层网络反而比浅层网络差

这是最反直觉的现象:

python

1
2
3
4
5
6
# 现象
20 层网络 < 10 层网络 < 5 层网络 (效果对比)

# 原因
不是过拟合,而是:20 层网络的浅层梯度已经消失了,有效层数可能只有 5 层左右
其余 15 层是“僵尸层”——既没学到东西,还增加了噪声

四、哪些场景特别容易发生梯度消失?

4.1 使用 sigmoid 或 tanh 激活函数

激活函数 导数范围 最大导数 风险
sigmoid (0, 0.25] 0.25 🔴 极高
tanh (0, 1] 1.0 🟡 中等
ReLU {0, 1} 1.0 🟢 低
LeakyReLU {0.01, 1} 1.0 🟢 低

sigmoid 的导数最大值只有 0.25,连乘几层就消失了。这也是为什么现代深度学习几乎不用 sigmoid 作为隐藏层激活函数。

4.2 权重初始化太小

python

1
2
3
4
5
# 错误的初始化方式
W = torch.randn(512, 512) * 0.01 # 权重太小了!
# 每层输出方差 ≈ 0.0001,激活函数输入极小
# sigmoid(0.01) ≈ 0.5025,导数 ≈ 0.25(还好)
# 但 tanh(0.01) ≈ 0.01,导数 ≈ 0.9999(较安全)

4.3 网络层数过深

即使是 ReLU,没有残差连接的话,深层网络梯度消失风险也很大。因为:

  • 如果某层激活函数输出为 0(ReLU 死亡),梯度就断了
  • 权重的随机性可能导致部分层有效梯度 < 1

4.4 RNN 处理长序列

RNN 的时间步展开和深度网络类似。

  • 梯度消失:RNN 记不住很久以前的信息(LSTM 部分解决)
  • 梯度爆炸:RNN 在某个时间步突然崩溃(梯度裁剪解决)

五、如何解决?(重点)

方法 原理 一句话评价
ReLU / 变体激活函数 正区间导数为 1,不衰减梯度 🥇 最关键、最有效的改进
残差连接 梯度有一条“高速公路”直通浅层 🥇 Transformer 等架构的基础
LayerNorm / BatchNorm 稳定输入分布,防止进入饱和区 🥈 间接但很重要
合理初始化 让每层输出的方差保持为 1 🥈 治本,从源头防止
LSTM / GRU 门控机制让梯度可以长时间保持 🥇 RNN 场景的最佳方案

5.1 使用 ReLU 及其变体(最关键)

python

1
2
3
4
5
6
7
8
9
10
11
12
# ReLU:最简单、最常用
def relu(x):
return max(0, x)
# 导数为 1(当 x > 0)或 0(当 x < 0)
# 正区间梯度不衰减!

# 变体:LeakyReLU(解决了 ReLU 死亡问题)
def leaky_relu(x, alpha=0.01):
return max(alpha * x, x)
# 负区间也有小梯度,神经元不会“死亡”

# 变体:ELU / GELU / Swish(各有优缺点)

ReLU 带来的问题

  • ReLU 死亡:如果输入始终为负,神经元输出恒为 0,梯度为 0,永远无法复活
  • 解决方案:LeakyReLU、ELU、或更小的学习率 + 合理的初始化

5.2 残差连接(Add)

原理:让梯度有一条不经过任何权重乘法的直通路径。

python

1
2
3
4
# 残差块
def residual_block(x):
output = layer(x) # 对 x 做变换
return x + output # 输出 = 输入 + 变换

梯度传播时

text

1
2
3
4
5
6
∂loss/∂x = ∂loss/∂output × ∂output/∂x
= ∂loss/∂output × (1 + ∂layer/∂x)
= ∂loss/∂output + ∂loss/∂output × ∂layer/∂x

这条路径不经过 layer 的权重!
梯度可以直接从 output 传到 x

效果:即使 ∂layer/∂x 的梯度消失为 0,∂loss/∂x 仍然有 ∂loss/∂output,不会消失。

5.3 使用 LayerNorm / BatchNorm

原理:把激活函数的输入稳定在非饱和区(比如均值 0、方差 1 附近),让导数保持在较大值。

python

1
2
3
# 以 sigmoid 为例
没有归一化时:x 可能很大(比如 10)→ sigmoid(10)≈1 → 导数≈0
有归一化后:x 被拉回 0 附近 → sigmoid(0)=0.5 → 导数≈0.25

5.4 合理的权重初始化(Xavier / Kaiming)

原理:让每层的输出方差保持为 1,输入信号不衰减。

python

1
2
3
4
5
6
7
# Xavier 初始化(适合 tanh / sigmoid)
nn.init.xavier_uniform_(W)
# 让 Var(output) = Var(input)

# Kaiming 初始化(适合 ReLU)
nn.init.kaiming_uniform_(W, nonlinearity='relu')
# 专门针对 ReLU 设计,考虑了 ReLU 会砍掉一半神经元

5.5 使用 LSTM / GRU(RNN 专用)

原理:引入门控机制(遗忘门、输入门、输出门),让梯度可以选择性地通过,不会在每个时间步都衰减。

python

1
2
3
4
5
# LSTM 的核心思想
# 有一条“细胞状态”高速公路,梯度可以在上面长时间保持
c_t = f_t * c_{t-1} + i_t * g_t
# f_t 是遗忘门(0~1),控制多少旧信息保留
# 如果 f_t 接近 1,梯度几乎不衰减

六、如何判断发生了梯度消失?

6.1 明显的信号

现象 说明
Loss 下降极慢 训练了很久,loss 还是很高
浅层参数几乎不变 打印浅层权重的变化,几乎为 0
梯度值极小 浅层梯度在 1e-6 以下,甚至为 0
深层效果好,浅层效果差 靠近输出的层学到了一些东西,但整体模型很差
模型不收敛 loss 卡在一个值,怎么训练都不降

6.2 监控代码

python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 训练循环中监控各层梯度的分布
for name, param in model.named_parameters():
if param.grad is not None:
grad_mean = param.grad.mean().item()
grad_std = param.grad.std().item()
print(f"{name}: mean={grad_mean:.2e}, std={grad_std:.2e}")

# 观察规律:
# 正常情况:浅层和深层的梯度量级相近
# 梯度消失:深层梯度 1e-3,浅层梯度 1e-8(差了 5 个数量级)

# 也可以可视化
import matplotlib.pyplot as plt

gradients = [p.grad.norm().item() for p in model.parameters()]
plt.plot(gradients)
plt.yscale('log') # 对数坐标,更容易看出衰减趋势
plt.title('梯度范数分布(左:浅层,右:深层)')
plt.show()

七、梯度消失 vs 梯度爆炸:对比总结

维度 梯度消失 梯度爆炸
现象 梯度 → 0 梯度 → ∞
原因 乘数 < 1 连乘 乘数 > 1 连乘
后果 浅层学不动,模型欠拟合 Loss 变 NaN,模型崩溃
发生概率 更常见 相对少见
主要元凶 sigmoid/tanh 的导数 < 1 权重初始化太大
高发场景 深层网络、RNN RNN、初始化不当
直接解法 用 ReLU 梯度裁剪
治本解法 残差连接 + 归一化 合理初始化 + 归一化
站内搜索
常搜: