A2(Systems and Parallelism)

A2 (Systems and Parallelism)

0) 总目标与要实现的 4 大模块

主题是 系统与并行:既要提升 单 GPU 训练速度,也要把训练 扩展到多 GPU

明确要求实现 4 件事:

  1. Benchmarking & profiling harness(基准测试与剖析框架)
  2. FlashAttention-2 Triton kernel(自注意力算子优化)
  3. Distributed Data Parallel training(分布式数据并行训练)
  4. Optimizer state sharding(优化器状态分片)

1) Part 1:Profiling & Benchmarking(单卡性能分析 + 混精度 + 显存剖析)

1.1 基准测试脚本(benchmarking_script)

写一个端到端 benchmark 脚本:可配置模型超参、生成随机数据、warmup 后测 forward / forward+backward,且每步 torch.cuda.synchronize() 保证测到真实 GPU 时间。

在多个模型规模上跑 timing,比较 forward vs backward 耗时、方差,以及“没有 warmup 会怎样”。

能力点:基准测试的正确姿势、warmup/同步、统计(均值/标准差)。

1.2 Nsight Systems 端到端剖析(nsys_profile)

nsys profile 分别 profile forward、backward、optimizer step,并回答:

能力点:把“理论 FLOPs”落到“实际 kernel 时间”,定位优化靶点。

1.3 Mixed Precision(混精度直觉 + 性能对比)

能力点:混精度为什么快、为什么 LN/归一化敏感,如何做工程切换与评估。

1.4 Memory profiling(显存剖析)

torch.cuda.memory._record_memory_history() 生成 snapshot,在 pytorch.org/memory_viz 里看:

能力点:训练显存构成、序列长度如何“吃显存”、如何从分配栈追根溯源。

2) Part 2:Optimizing Attention(FlashAttention-2 + torch.compile)

2.1 PyTorch attention 的规模测试(pytorch_attention)

写脚本 sweep head_dim 与 seq_len,测 forward/backward 时间与显存,占用到 OOM 的拐点,并做一次“attention 显存随 seq_len 变化”的手工 accounting,回答如何消除这类显存成本。

能力点:为什么 attention 的显存/IO 是瓶颈(\(O(L^2)\)),以及 OOM 的工程边界。

2.2 torch.compile(JIT 编译对比)

能力点:编译器能带来什么、不能带来什么;为什么长序列仍需要更底层 kernel。

2.3 Triton:FlashAttention-2 kernel(核心大题)

A2 给了 Triton 示例(weighted sum)教你 block pointer / tiling / forward+backward。

然后要求实现 FlashAttention-2:

flash_forward(15分)

flash_backward(5分)

flash_benchmarking(5分)

还有 FA2 leaderboard(加速竞赛)与可选 Triton backward。

能力点:算子级优化(IO-aware)、tiling、在线 softmax、重计算换显存、kernel fusion、性能评测。

3) Part 3:Distributed Data Parallel(从“最朴素”到“重叠通信+bucket”)

3.1 单机分布式通信基准(distributed_communication_single_node)

写脚本 benchmark all-reduce:后端/设备(Gloo+CPU vs NCCL+GPU)、数据大小(1MB → 1GB)、进程数(2/4/6),输出表格/图并评论因素如何交互。

3.2 朴素 DDP(naive_ddp + naive_ddp_benchmarking)

3.3 改进 DDP(flat / overlap / bucket)

能力点:通信开销来自哪里(calls overhead vs bandwidth)、如何 overlap、如何设计 bucket,如何用 profiler 证明。

4) Part 4:4D Parallelism 通信/内存核算(理论大题)

communication_accounting(10分):给一个 XXL 配置(忽略 attention),按 Scaling Book 的设定计算:

能力点:把“并行策略”落到量化的内存/通信/吞吐模型。

5) Part 5:Optimizer State Sharding(简化版 ZeRO-1 思路)

optimizer_state_sharding(15分)

实现一个“分片优化器”包装器:每个 rank 只维护自己那份参数的 optimizer state,step 后广播更新参数同步。要实现:__init__stepadd_param_group,并通过测试。

optimizer_state_sharding_accounting(5分)

写脚本对比“有/无 state sharding”:

能力点:优化器 state 为什么吃显存、ZeRO 的基本思想、工程实现与代价。

总结:A2 的任务类型谱系