PaLM: Scaling Language Modeling with Pathways

下载 PDF

PaLM: Scaling Language Modeling with Pathways - 中文验证版

英文原始依据卡片:palm_2022.md

状态:已翻译。

元数据

  • Slug: palm_2022
  • 年份: 2022
  • 会议: arXiv
  • 作者: Aakanksha Chowdhery et al.
  • 阅读状态: read complete
  • 计算范式: 超大规模密集 LLM 训练
  • 主要来源: PDF抽取文本

计算设置

论文明确列出训练系统:

  • 使用 Pathways 的 6144 个 TPU v4 chips。
  • 两个 TPU v4 Pods,每个有 3072 个 TPU v4 chips 和 768 hosts。
  • Model card 将硬件列为 TPU v4。

训练目标是一个 540B 参数 dense Transformer,在 780B-token 语料库上进行单次遍历训练。批次调度逐步增加:512 序列(约 1M tokens)→ 1024(约 2M tokens)→ 2048(约 4M tokens)。

瓶颈

瓶颈是 pod 规模的 dense training。论文将 pipeline parallelism 视为需要回避的瓶颈:pipeline 会为每个 micro-batch 重新加载权重,在填充和排空时产生 bubble,并增加软件复杂性。因此 PaLM 的算力问题是在不继承 pipeline 成本的前提下使用更多芯片。第二个瓶颈是跨 pod 通信:每个 host 对每步交换约 1.3 GB 梯度,所有 host 的总突发带宽约 81 Tbps。内存是第三个约束:PaLM 使用 rematerialization,因为存储所有中间激活会限制可行的 batch size。540B dense 参数的训练态内存下界约为 8.64 TB(16 字节/参数混合精度 Adam 规则),BF16 推理权重约 1.08 TB(不含 KV cache)。

方法适配

PaLM 通过以下方式适配 TPU v4 Pod-scale training:

  • 使用 Pathways 进行 pipeline-free multi-pod training。
  • 在每个 pod 内应用 12-way model parallelism 和 256-way fully sharded data parallelism。
  • 跨 pods 应用 2-way pod-level data parallelism。
  • 使用 parallel layers 架构改变,报告训练速度改善约 15%。
  • 论证 MFU 作为更干净的 utilization metric。

证据

  • 模型大小为 540B 参数。
  • Training throughput 为 238.3K tokens/sec(batch size 2048,约 4M tokens)。
  • 论文报告 46.2% model FLOPs utilization 和 57.8% hardware FLOPs utilization,优于所列先前大模型:GPT-3 21.3% MFU、Gopher 32.5%、Megatron-Turing NLG 530B 约 30%。
  • 双 pod 吞吐约为单 pod 的 1.95 倍,作者描述为完美弱扩展的 97%(因为 batch size 翻倍)。
  • Training compute 报告为 2527.2 zettaFLOPs。
  • 训练使用 6144 个 TPU v4 chips 运行 1200 小时,另有 3072 chips 运行 336 小时,包括停机时间和重复步骤。

历史影响

PaLM 展示了不使用 pipeline parallelism 的 TPU v4 pod 规模 dense LLM training。历史上它处于单 pod dense scaling 与后续将大量加速器池常态化的训练系统之间。其贡献既在模型本身,也同样在算力结构:pod 内的 model parallelism 和 fully sharded data parallelism、跨 pod 的弱扩展 data parallelism,以及对 rematerialization、batch size、网络突发的显式计账。它帮助 MFU 成为实用 systems metric。通过将所需模型 FLOPs 与额外重算 FLOPs 分离,PaLM 使不同模型和编译器策略的比较更加容易。

局限

  • Chinchilla 式分析表明,相对参数量 PaLM 可能训练不足:540B 参数仅见 780B tokens,而后续算力最优实践会将更多相同预算用于数据。
  • 一些 subcorpora 在 780B-token 规模附近开始重复,这限制了在相同数据配比下简单延长训练。
  • 用相同总算力训练更小模型需要更少芯片(增加 wall-clock 时间)或相同芯片配更大 batch,而 PaLM 540B 的 batch 已达 4M tokens,不清楚更大 batch 是否维持 sample efficiency。
  • 硬件和效率报告深入,但相对模型规模,deployment impact、fairness 和 toxicity 分析有限。

链接