附录D:训练循环增强
把”能跑”的训练循环升级成”能跑得稳、跑得快、跑得好”。本章覆盖三个标准技巧:学习率热身 + 余弦衰减、梯度裁剪、参数分组的 weight decay。
D.1 学习率热身 + 余弦衰减
直接用大学习率从 step 0 开始训练,前几百步极容易”炸”——loss 先降一点然后突然爆涨到 NaN。标准应对:
- 热身(warmup):前
n_warmup步把学习率从 0 线性升到目标值; - 余弦衰减:之后按余弦曲线缓慢降到一个最小值(通常是峰值的 10%)。
import math
def lr_at(step, peak_lr, n_warmup, n_total, min_ratio=0.1):
if step < n_warmup:
return peak_lr * step / max(1, n_warmup)
progress = (step - n_warmup) / max(1, n_total - n_warmup)
progress = min(1.0, progress)
cosine = 0.5 * (1 + math.cos(math.pi * progress))
return peak_lr * (min_ratio + (1 - min_ratio) * cosine)
在训练循环里:
for step, (x, y) in enumerate(loader):
lr = lr_at(step, peak_lr=3e-4, n_warmup=500, n_total=20_000)
for g in optimizer.param_groups:
g["lr"] = lr
...
经验:热身步数大约取总步数的 1%
3%,或固定 2002000 步。
D.2 梯度裁剪
少数 batch 的梯度异常大(数值溢出、坏样本等)会让参数被一脚踹飞。全局梯度范数裁剪是最简单也最有效的防御:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
放在 loss.backward() 之后、optimizer.step() 之前。max_norm=1.0 是 GPT 系列的常用值。
D.3 weight decay 的正确分组
AdamW 的 weight decay 是对所有参数都生效的。但不应该给以下参数加 decay:
- 所有
LayerNorm的scale和shift; - 所有
bias; - 所有
Embedding表(这点有争议,本课程不衰减它们)。
否则模型会在初期被无谓地往 0 拉,训练变慢。
def make_param_groups(model, weight_decay=0.1):
decay, no_decay = [], []
for name, p in model.named_parameters():
if not p.requires_grad: continue
if p.ndim < 2 or name.endswith(".bias") or "norm" in name.lower() or "emb" in name.lower():
no_decay.append(p)
else:
decay.append(p)
return [
{"params": decay, "weight_decay": weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(make_param_groups(model, 0.1), lr=3e-4, betas=(0.9, 0.95))
D.4 梯度累积
显存不够装下你想要的”等效 batch”时,把一个大 batch 切成 N 份,前向反向 N 次,最后只调用一次 optimizer.step():
ACCUM = 8
for step, (x, y) in enumerate(loader):
x, y = x.to(device), y.to(device)
loss = compute_loss(model, x, y) / ACCUM
loss.backward()
if (step + 1) % ACCUM == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
注意 loss / ACCUM——这是为了让平均梯度的尺度和”真正的大 batch”一致。
D.5 检查点 (Checkpointing)
长时间训练务必每隔 N 步存一次:
if step % 1000 == 0:
torch.save({
"step": step,
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
}, f"ckpt_{step}.pt")
恢复时记得把 optimizer.state_dict() 也加载回来——否则 Adam 的一阶/二阶动量重置,loss 会”回弹”几百步才稳定。
← 附录 A · 返回目录 · 附录 E · LoRA →