A5 补充(Instruction Tuning & RLHF / DPO)
一份“通用对话助手对齐”的完整流水线:先做 多基准零样本评测,再做 指令微调(SFT),然后做 偏好对齐(DPO),并在多个维度(能力 + 安全 + 对话质量)上评测与分析“alignment tax”。
0) 总览:要实现什么、要跑什么、代码结构
**你需要实现(from scratch)**三类组件:
- 多个数据集的 zero-shot prompting baseline
- SFT 指令微调(instruction-response pairs)
- DPO(pairwise preference data)
并且给了专门的 GitHub repo(含 prompts、tests、data、脚本配置),要求你通过指定测试:test_data.py/test_sft.py/test_dpo.py/test_metrics.py。
1) Part 2:零样本评测 Baselines(四个任务集:知识/推理/对话质量/安全)
补充用 同一个 system prompt做通用助手 baseline(强调 helpful+safe),并利用“看到 # Query: 就停止生成”的约定来截断输出。
1.1 MMLU baseline(mmlu_baseline,4分)
任务包括:
- (a) 输出解析器:把模型生成解析成选项字母 A/B/C/D;不能解析返回 None,并通过测试
test_parse_mmlu_response。
- (b) 评测脚本:加载 MMLU、格式化 prompt(强制回答格式“正确答案是 _”)、生成、算指标、把样本/生成/分数序列化落盘。
- (c) 解析失败分析:统计解析失败数并展示例子。
- (d) 吞吐:估计 examples/sec。
- (e) baseline 指标:报告零样本性能。
- (f) 错误分析:随机抽 10 个错题,总结错误类型。
1.2 GSM8K baseline(gsm8k_baseline,4分)
任务几乎同 MMLU,但解析为数字:
- (a) 数字解析器:取生成文本里“最后一个数字”为答案;不可解析返回 None;通过测试
test_parse_gsm8k_response。
- (b) 评测脚本 + 落盘、(c) 解析失败分析、(d) 吞吐、(e) 指标、(f) 错误分析。
1.3 AlpacaEval baseline(alpaca_eval_baseline,4分)
这是“对话质量/指令遵循”评测:
- (a) 生成脚本:对 AlpacaEval 每条 instruction 生成 output,并按 AlpacaEval 需要的 JSON array 格式存盘(字段:instruction/output/generator/dataset)。
- (b) 吞吐:examples/sec。
- (c) 自动评审计算 winrate:用 Llama 3.3 70B Instruct 当 annotator,对比 GPT-4 Turbo,报告 winrate 与 length-controlled winrate(但需要两张 80GB GPU)。
- (d) 误差分析:抽 10 个被判输的例子,分析原因、是否同意自动评审。
1.4 SimpleSafetyTests baseline(sst_baseline,4分)
这是“安全/拒答能力”评测:
- (a) 生成脚本:对 SST prompts 生成 output,按 JSONL 存盘(prompts_final/output)。
- (b) 吞吐:examples/sec。
- (c) 自动安全评审:用 Llama 3.3 70B Instruct 判断 safe/unsafe,报告 safe 比例(同样需要大 GPU)。
- (d) 误差分析:抽 10 个 unsafe 样本,总结什么情况下会不安全、是否同意评审。
2) Part 3:Instruction Fine-Tuning(SFT 指令微调)
2.1 看数据(look_at_sft,4分)
- 从提供的 instruction tuning 训练集随机看 10 条:识别包含哪些传统 NLP 任务(QA/情感/写作等),评价 prompt/response 质量,用具体例子支撑。
2.2 实现数据加载(data_loading,3分)
这块很工程:把 prompt/response 变成“可 packed 的 LM token 流”。
- (a) Dataset 类:读 jsonl,按 Alpaca 模板拼成字符串、tokenize、拼接成一长串 token,再切成固定长度 seq_length 的 non-overlapping chunks;返回
input_ids 与 labels(next-token)。并通过测试(要实现 adapter get_packed_sft_dataset)。
- (b) batch iterator/DataLoader:按 batch size 产出 batches(可 shuffle),迭代一次构成一个 epoch;通过测试(adapter
run_iterate_batches)。
2.3 写 SFT 训练脚本(sft_script,4分)
- 写一个可配置训练脚本:超参可控、支持 gradient accumulation 扩大有效 batch、周期性 log train/val(可用 wandb)。
2.4 跑一次指令微调(sft,6分,算力很重)
- 在给定 instruction tuning 数据上微调 Llama 3.1 8B(建议 1 epoch、ctx=512、有效 batch=32、cosine+warmup 等),保存模型与 tokenizer;提交 learning curve 与最终 val loss。
3) Part 4:评测指令微调后的模型(四个基准 + 误差分析)
对每个基准都要求:写脚本、测吞吐、报指标、做 10 个错例分析,并对比 zero-shot baseline。
- MMLU(mmlu_sft,4分)
- GSM8K(gsm8k_sft,4分)
- AlpacaEval(alpaca_eval_sft,4分;含 winrate 计算,需要大 GPU)
- SimpleSafetyTests(sst_sft,4分;含安全评审,需要大 GPU)
4) Part 4.5:Red-teaming(红队测试,4分)
- (a) 提出三种额外的滥用方式(不重复题面示例)。
- (b) 亲自尝试用模型完成三类恶意应用,写过程、是否成功、用了哪些策略、定性结论。
5) Part 5:DPO(偏好对齐,替代 PPO/RLHF 的简化路线)
先讲 RLHF 经典流程,再给出 DPO 目标推导(Equation 3),然后实现 DPO loss 与训练。
5.2 看偏好数据(look_at_hh,2分)
- 实现加载 Anthropic HH(4 个 jsonl.gz 合并),并做预处理:
- 忽略 multi-turn
- 拆成 instruction +(chosen, rejected)
- 记录来源文件(helpful/harmless 的不同子集)
- 抽样看 helpful 与 harmless 各 3 条:分析 chosen vs rejected 的差异,是否同意标注。
5.3 实现 DPO loss(dpo_loss,2分)
- 写 per-instance DPO loss(式 3):需要同时计算 \(\pi_\theta\) 与 \(\pi_\text{ref}\) 对 chosen/rejected 的条件 logprob 比值;注意两模型可能在不同 device;用 Alpaca 模板格式化并在 response 末尾加 EOS。
- 通过测试
test_per_instance_dpo_loss(adapter per_instance_dpo)。
5.4 DPO 训练 + 多基准“对齐税”分析(dpo_training,4分)
- 实现 DPO 训练 loop:因为显存吃紧,要求 不做 batch,用 gradient accumulation;优化器用 RMSprop(不是 AdamW),建议 2 GPU:ref 模型一张、train 模型一张。
- 跑 1 epoch,保存验证准确率最高的模型(这里的“验证准确率”指:chosen logprob > rejected logprob 的比例)。
- 评测:
- AlpacaEval winrate(对比 SFT 前后)
- SimpleSafetyTests safe 比例(对比 SFT)
- 再评 MMLU + GSM8K,观察是否出现 alignment tax(能力下降)。
一句话总结:A5 补充的“任务谱系”
- 评测工程:为 MMLU/GSM8K/AlpacaEval/SST 写 prompt、生成脚本、解析器、落盘、吞吐与错误分析。
- SFT 工程:看数据 → packed 数据加载 → 可配置训练脚本(含梯度累积) → 训练并保存 → 四基准再评测。
- 安全评测与红队:自动安全评审 + 人工 red-teaming。
- 偏好对齐(DPO):看 HH 数据 → 实现 DPO loss → DPO 训练 → 对话质量/安全提升 vs 能力下降(alignment tax)全套评测。