A2 (Systems and Parallelism)
0) 总目标与要实现的 4 大模块
主题是 系统与并行:既要提升 单 GPU 训练速度,也要把训练 扩展到多 GPU。
明确要求实现 4 件事:
- Benchmarking & profiling harness(基准测试与剖析框架)
- FlashAttention-2 Triton kernel(自注意力算子优化)
- Distributed Data Parallel training(分布式数据并行训练)
- Optimizer state sharding(优化器状态分片)
1) Part 1:Profiling & Benchmarking(单卡性能分析 + 混精度 + 显存剖析)
1.1 基准测试脚本(benchmarking_script)
写一个端到端 benchmark 脚本:可配置模型超参、生成随机数据、warmup 后测 forward / forward+backward,且每步 torch.cuda.synchronize() 保证测到真实 GPU 时间。
在多个模型规模上跑 timing,比较 forward vs backward 耗时、方差,以及“没有 warmup 会怎样”。
能力点:基准测试的正确姿势、warmup/同步、统计(均值/标准差)。
1.2 Nsight Systems 端到端剖析(nsys_profile)
用 nsys profile 分别 profile forward、backward、optimizer step,并回答:
- forward 总耗时是否与 python timing 一致
- forward 最耗时的 CUDA kernel 是哪个、调用多少次;forward+backward 是否还是它最耗时
- 除 matmul 外哪些 kernel 也占显著时间
- 训练步(含 AdamW)时 matmul 占比变化
- attention 中 softmax vs matmul 的 runtime 与 FLOPs 差异对比
能力点:把“理论 FLOPs”落到“实际 kernel 时间”,定位优化靶点。
1.3 Mixed Precision(混精度直觉 + 性能对比)
- 先做一个“累加误差”小实验,比较 FP32/FP16 混合累加带来的精度差。
- 再分析 autocast 下 ToyModel 各处的 dtype(参数、fc、layernorm、logits、loss、grad)。
- 最后改 benchmark 脚本支持 BF16 mixed precision,并对不同模型规模比较 full vs mixed timing。
能力点:混精度为什么快、为什么 LN/归一化敏感,如何做工程切换与评估。
1.4 Memory profiling(显存剖析)
用 torch.cuda.memory._record_memory_history() 生成 snapshot,在 pytorch.org/memory_viz 里看:
- forward vs full training step 的显存时间线(截图)
- 不同 context length 的峰值显存表
- mixed precision 对峰值显存影响
- residual stream activation 张量大小计算(MB)
- 最大分配来自哪里(stack trace)
能力点:训练显存构成、序列长度如何“吃显存”、如何从分配栈追根溯源。
2) Part 2:Optimizing Attention(FlashAttention-2 + torch.compile)
2.1 PyTorch attention 的规模测试(pytorch_attention)
写脚本 sweep head_dim 与 seq_len,测 forward/backward 时间与显存,占用到 OOM 的拐点,并做一次“attention 显存随 seq_len 变化”的手工 accounting,回答如何消除这类显存成本。
能力点:为什么 attention 的显存/IO 是瓶颈(\(O(L^2)\)),以及 OOM 的工程边界。
2.2 torch.compile(JIT 编译对比)
- 把 attention module 编译后对比 uncompiled 的 forward/backward timing
- 把整个 Transformer model
torch.compile(model),比较整体 forward / train step 性能
能力点:编译器能带来什么、不能带来什么;为什么长序列仍需要更底层 kernel。
2.3 Triton:FlashAttention-2 kernel(核心大题)
A2 给了 Triton 示例(weighted sum)教你 block pointer / tiling / forward+backward。
然后要求实现 FlashAttention-2:
flash_forward(15分)
先写“纯 PyTorch 的 tiled FlashAttention-2 forward” autograd.Function(慢但用于 debug)
再写 Triton kernel 实现 fused forward,并封装成 autograd.Function,通过测试
加 causal masking(Triton 里做 index mask,masked 元素加 -1e6),并在 ctx 里保存 flag
flash_backward(5分)
- 用 PyTorch(不是 Triton)+
torch.compile 实现 FA2 backward(含 D 向量与 Eq 13-19 的流程)。
flash_benchmarking(5分)
- 用
triton.testing.do_bench 对比:你的 FA2(前向+后向)vs 常规 PyTorch attention;需要在 H100 上 sweep seq_len、d、dtype,并给表格。
还有 FA2 leaderboard(加速竞赛)与可选 Triton backward。
能力点:算子级优化(IO-aware)、tiling、在线 softmax、重计算换显存、kernel fusion、性能评测。
3) Part 3:Distributed Data Parallel(从“最朴素”到“重叠通信+bucket”)
3.1 单机分布式通信基准(distributed_communication_single_node)
写脚本 benchmark all-reduce:后端/设备(Gloo+CPU vs NCCL+GPU)、数据大小(1MB → 1GB)、进程数(2/4/6),输出表格/图并评论因素如何交互。
3.2 朴素 DDP(naive_ddp + naive_ddp_benchmarking)
- naive_ddp:每个参数梯度分别 all-reduce;用 toy model 验证与单进程训练权重一致。
- naive_ddp_benchmarking:在 1 node×2 GPU、XL 模型上测每步总时间与通信占比。
3.3 改进 DDP(flat / overlap / bucket)
- flatten all grads 一次 all-reduce:比较性能与通信占比。
- overlap(逐参数就绪即异步 all-reduce):用 hook + async_op,实现 DDP wrapper,并用 Nsight 截图证明 overlap。
- bucketed overlap(bucket_size_mb):把参数分桶(建议 reverse order),桶内就绪后异步 all-reduce;benchmark 不同 bucket size,并推导“通信开销模型 + 最优 bucket size”。
能力点:通信开销来自哪里(calls overhead vs bandwidth)、如何 overlap、如何设计 bucket,如何用 profiler 证明。
4) Part 4:4D Parallelism 通信/内存核算(理论大题)
communication_accounting(10分):给一个 XXL 配置(忽略 attention),按 Scaling Book 的设定计算:
- 单卡 FP32 权重+梯度+优化器 state + BF16 activations 的内存;折算需要多少 H100 80GB
- 引入 FSDP 分片后每卡内存表达式,算 NFSDP 的最小值使得内存 < TPU v5p 95GB
- 用给定带宽/算力,推 compute-bound 的 per-device batch size 和 overall batch size
- 讨论如何在不通信瓶颈下尽量减小 batch(技巧+引用/公式)
能力点:把“并行策略”落到量化的内存/通信/吞吐模型。
5) Part 5:Optimizer State Sharding(简化版 ZeRO-1 思路)
optimizer_state_sharding(15分)
实现一个“分片优化器”包装器:每个 rank 只维护自己那份参数的 optimizer state,step 后广播更新参数同步。要实现:__init__、step、add_param_group,并通过测试。
optimizer_state_sharding_accounting(5分)
写脚本对比“有/无 state sharding”:
- 初始化后、step 前、step 后的峰值显存与分解
- 每步训练速度变化
- 与 ZeRO stage 1 的差异(内存与通信量)
能力点:优化器 state 为什么吃显存、ZeRO 的基本思想、工程实现与代价。
总结:A2 的任务类型谱系
- 测量与剖析:timeit/benchmark、Nsight、memory_viz、mixed precision 对比
- 算子级优化:从 PyTorch attention → torch.compile → Triton FlashAttention-2(含 causal、backward、benchmark、leaderboard)
- 分布式训练工程:all-reduce 基准 → naive DDP → flatten → overlap → bucket + 理论建模
- 内存优化:optimizer state sharding + 量化分析 + ZeRO 对比
- 理论核算:4D 并行通信/内存/吞吐推导