在预训练和微调大模型时,我们几乎无脑选择 AdamW 优化器。平常写代码,可能只是简单地调包 ,但感觉还是很有必要理解一下原理的,今天就写一篇这样的博客。
一、SGD
记得刚接触深度学习时,老师教的是 SGD (随机梯度下降)。用 SGD 训练模型,就像开一辆手动挡的车,你需要极其精细地调节学习率。刚开始要大火猛炒,中间要小火慢炖,而且面对复杂的 loss 曲面,SGD 经常震荡,或者卡在梯度接近 0 局部最优处动弹不得。这就会导致训练直接变成了无尽的调参,loss不下降或者直接起飞的情况经常出现。
但是现在用的都是Adam和AdamW,有一种自动挡的感觉。我在训练的时候其实只需要初始化一个默认的学习率,就能获得不错的效果,因为它能利用梯度的惯性快速冲出平坦区,还能自适应地给不同参数分配不同的学习率。
简单来说,我对于其的理解就是,只要模型架构没问题,能跑起来,用Adam基本都能收敛。
二、Adam
Adam的核心在于它维护了两个状态变量,用来指导参数更新。
假设我们要更新参数 $w$,当前的梯度是 $g$。
1. 第一个状态变量:$m_t$ (一阶动量 / 惯性)
它记录过去梯度的平均方向。这个主要是为了解决震荡问题,$\beta_1$ 通常设为 0.9。意思是今天的方向,保留 90% 昨天的惯性,加上 10% 今天的梯度。 \(m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t\)
2. 第二个状态变量:$v_t$ (二阶动量 / 能量)
它记录过去梯度的波动大小(平方)。$\beta_2$ 通常设为 0.999。这个主要来衡量参数的活跃度,比如说梯度大的时候,V就会由于g很大变的很大,反之就很小 \(v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2\)
3. Adam 更新公式
\(w_{new} = w_{old} - \eta \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}\) 说白了,它和sgd相比,多出了这两个参数,一个是分子 $m_t$:带惯性的方向。一个是分母 $\sqrt{v_t}$:自适应步长(Adam 的灵魂)。如果参数太活跃($v$ 大),除以大数,步子变小(压制)。如果参数太懒惰($v$ 小),除以小数,步子变大(鼓励)。
三、Adam 的Bug:L2 正则化失效
为了防止模型过拟合,我们通常使用 L2 正则化,目的是让权重 $w$ 尽量趋近于 0。
1. 传统的 L2 实现
在 SGD 中,L2 正则化是直接加在梯度上的: \(g_{final} = g_{original} + \lambda w\) 这没问题,梯度变大了,参数更新时就会多减去一点,实现了衰减。
2. Adam 遇到的问题
如果我们把这个加了料的梯度 $g_{final}$ 代入 Adam 公式,问题就来了:
\[w_{new} = w_{old} - \eta \cdot \frac{m(包含 \lambda w)}{\sqrt{v(包含 \lambda w)}}\]注意分母 $\sqrt{v}$。如果 $w$ 很大(需要大幅衰减),那么 $\lambda w$ 很大。这导致 $v$(梯度的平方)变得非常大。
- 分母变大了,更新步长反而变小了!
我们希望 $w$ 越大衰减越快,但 Adam 的自适应机制(除以 $\sqrt{v}$)抵消了 L2 的效果,导致大参数衰减不动。这就是 Adam 在大模型训练中泛化能力差的根源。
四、AdamW
AdamW 的思路非常简单粗暴:既然 Adam 会乱改梯度,那我就不把 L2 加到梯度里了
1. AdamW 更新逻辑
算梯度:只算 Loss 的梯度 $g$,完全不加 L2 正则项。 算 Adam:用纯净的 $g$ 去算 $m$ 和 $v$,得到 Adam 的更新量。 手动减去:在更新 $w$ 的最后一步,硬性减去 $\lambda w$。
2. AdamW 公式
\(w_{new} = w_{old} - \underbrace{\eta \cdot \frac{m_t}{\sqrt{v_t}}}_{\text{Adam自适应更新}} - \underbrace{\eta \cdot \lambda \cdot w_{old}}_{\text{硬性权重衰减}}\)
最后一项 $\eta \lambda w_{old}$。它没有被除以 $\sqrt{v}$。无论参数活跃与否,都会受到公平的衰减惩罚。
五、调参
最后一个问题是:我在训练的时候调参,调的是谁。 下面就是一个调包的代码,
torch.optim.AdamW(
params,
lr=0.001,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0.01
)
lr 对应公式:$\eta$ ,也就是基础步长。
betas 是一个包含两个数字的元组 (beta1, beta2)。betas[0]:对应公式中的 $\beta_1$betas[1]:对应公式中的 $\beta_2$
eps 是数值稳定项
weight_decay是$\lambda$,也就是权重衰减系数。Adam 中,它被加到了梯度里(这就是那个 Bug)。 在 AdamW 中,它是独立出来硬性减去的那个系数。
最常调的就是lr和 weight_decay。这两个对模型效果影响最大。几乎不调betas 和 eps,用默认值 (0.9, 0.999) 和 1e-8 就足够了。
目前来说,我对于adam和adamw的了解就止步于此了,很浅显,实际训练的时候,还有warmup和余弦退火等等trick,下一次再更新。






