第7章:指令微调
学习目标
- 理解指令数据的标准格式与”prompt mask”的作用。
- 掌握 SFT 的损失:只在”回答”段计算交叉熵。
- 跑通一个最小可用的指令模型训练 + 推理流程。
- 知道初步评估对齐质量的几种方法。
7.1 从”续写”到”听话”
预训练后的 GPT 是一个续写器:你给它一段开头,它接着往下写。但用户期望的是”问答”:
用户输入: 把这句话翻译成法语:I love programming.
模型输出: J'aime la programmation.
要让模型呈现出这种”指令—回答”的行为,最直接的办法就是用大量”指令—回答”对继续训练它,这就是 SFT(Supervised Fine-Tuning)。
7.2 指令数据的标准格式
业界常见的两种 prompt 模板:
Alpaca 模板:
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{response}
Phi/简洁式:
Instruction: {instruction}
Input: {input}
Output: {response}
不论选哪种,关键有三:
- 训练和推理用同一个模板——任何不一致都会显著降低效果;
- 明确分隔符(如
### Response:)让模型学会”在这里开始作答”; - 训练时记录每条样本的”prompt 部分长度”,以便构造 mask。
ALPACA_TEMPLATE = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instr}\n\n"
)
def format_prompt(example):
text = ALPACA_TEMPLATE.format(instr=example["instruction"])
if example.get("input"):
text += f"### Input:\n{example['input']}\n\n"
text += "### Response:\n"
return text # 不含答案的部分
def format_full(example):
return format_prompt(example) + example["output"]
7.3 关键:只在”回答”段计算损失
如果对整段(prompt + response)都算 loss,模型会被迫记忆 prompt 模板里反复出现的”### Instruction:“等套话——这是在浪费容量。
正确做法:构造一个 (B, T) 的 mask,prompt 部分为 0,response 部分为 1,只对 mask 为 1 的位置算交叉熵。
import torch
from torch.utils.data import Dataset
import tiktoken
class SFTDataset(Dataset):
def __init__(self, data, tokenizer, max_len=512, pad_id=50256):
self.samples = []
for ex in data:
prompt_ids = tokenizer.encode(format_prompt(ex))
full_ids = tokenizer.encode(format_full(ex))
full_ids = full_ids[:max_len]
input_ids = full_ids[:-1]
target_ids = full_ids[1:]
# mask: 0 表示忽略此位置的 loss
mask = [0] * (len(prompt_ids) - 1) + [1] * (len(target_ids) - len(prompt_ids) + 1)
mask = mask[:len(target_ids)]
# pad
pad_len = max_len - 1 - len(input_ids)
input_ids += [pad_id] * pad_len
target_ids += [pad_id] * pad_len
mask += [0] * pad_len
self.samples.append((
torch.tensor(input_ids),
torch.tensor(target_ids),
torch.tensor(mask, dtype=torch.float),
))
def __len__(self): return len(self.samples)
def __getitem__(self, i): return self.samples[i]
带 mask 的损失:
import torch.nn.functional as F
def loss_sft(model, x, y, mask):
logits = model(x) # (B, T, V)
log_probs = F.log_softmax(logits, dim=-1)
nll = -log_probs.gather(-1, y.unsqueeze(-1)).squeeze(-1) # (B, T)
nll = nll * mask
return nll.sum() / mask.sum().clamp(min=1)
7.4 训练循环
和预训练几乎一致,只是损失函数换成 loss_sft,并且学习率更小(典型 1e-5 ~ 5e-5):
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.0,
betas=(0.9, 0.95))
for epoch in range(num_epochs):
for x, y, m in loader:
x, y, m = x.to(device), y.to(device), m.to(device)
optimizer.zero_grad()
loss = loss_sft(model, x, y, m)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
经验值:
- 数据量:教学规模 1k
5k 条精选指令即可;工业级 SFT 一般 5w50w 条。 - epoch:2~3 个就够,太多会让模型过拟合到模板。
- batch_size × grad_accum:尽量凑出 64~256 条样本/step。
7.5 推理:拼 prompt → 生成 → 截断
@torch.no_grad()
def chat(model, tokenizer, instruction, input_=None,
max_new_tokens=200, temperature=0.7, top_k=50):
prompt = format_prompt({"instruction": instruction, "input": input_ or ""})
ids = torch.tensor([tokenizer.encode(prompt)], device=next(model.parameters()).device)
out_ids = generate(model, ids, max_new_tokens, temperature=temperature, top_k=top_k)
text = tokenizer.decode(out_ids[0].tolist())
# 取 ### Response: 之后的部分
return text.split("### Response:")[-1].strip()
在生成阶段需要决定何时停止。三种常用方案:
- 固定
max_new_tokens; - 模型自己生成
<|endoftext|>时停; - 检测到下一个
### Instruction:标签时停(防止模型”自问自答”)。
7.6 评估对齐质量
SFT 后的模型不再是简单的”困惑度越低越好”——困惑度低不代表”听话程度高”。常见评估方式:
(a) 自动指标(弱信号)
- 在留出测试集上算 response 段的困惑度;
- 用 BLEU / ROUGE 与参考答案做表面相似度比较(仅适合翻译/摘要类任务)。
(b) 模型即评委(LLM-as-a-judge)
- 用一个更强的模型(GPT-4 / Claude)比较两个回答的胜率;
- 注意要用同样的提示模板和打分标准,做位置交换以减小偏置。
(c) 人工评估(金标准)
- 抽样 50~100 条,从”指令遵循""正确性""有用性""安全性”四个维度打分;
- 教学场景里这是最快确认”模型是不是真的在听话”的办法。
7.7 SFT 的局限与下一步
仅靠 SFT,模型仍可能:
- 拒答太多(保守)或胡说太多(不安全);
- 偏好某种风格(啰嗦、爱列点);
- 对模板格式过敏,换个 prompt 风格就翻车。
工业上下一步通常是 偏好对齐(RLHF / DPO / KTO 等),用人类偏好数据进一步调整。这部分超出了本课程范围,但你已经具备了进入这个方向所需的全部底层知识。
检查清单
- 我能解释为什么 SFT 损失要在 prompt 段做 mask。
- 我能在白板上画出”原始 GPT → SFT 模型”的数据流差异。
- 我知道至少两种评估对齐质量的方法以及它们各自的盲点。
练习题
- 写一份 30 条左右的迷你指令数据集(中英文皆可),格式严格遵循 Alpaca 模板,训练并对比”有/无 prompt mask”两种 loss 的差异。
- 给推理函数加一个”遇到
### Instruction:立即停止”的逻辑。 - 用 LLM-as-a-judge 思路:让一个更强的模型给你的两个 SFT 版本(不同 epoch / 不同学习率)打胜率,记录结果。
📖 第7章补充材料 → — DPO损失推导、偏好数据生成、LLM-as-judge评估