Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

下载 PDF

Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity - 中文验证版

英文原文卡片:switch_transformer_2021.md

状态:已翻译。

元数据

  • Slug: switch_transformer_2021
  • 年份: 2021
  • 会议: JMLR
  • 作者: William Fedus、Barret Zoph、Noam Shazeer
  • 阅读状态: read complete
  • 计算范式: 稀疏化与内存高效扩展
  • 主要来源: PDF抽取文本

计算设置

论文明确聚焦于 TPU 架构。关键的基准表指出所有 MoE 和 Switch 模型在相同计算量、32 核和相同的 TPUv3 硬件上训练。图 5 同样指出速度对比中的所有模型在 32 个 TPUv3 核心上训练且每样本 FLOPs 相等。

对于万亿级模型,抽取文本给出了模型大小和并行性设计,但没有清晰的精确设备数量。按照项目规则,大规模运行的硬件推断为与论文其余部分及其 Mesh TensorFlow 实现一致的 TPU-v3 式分布式训练。论文指出 Switch 模型组合了数据、模型和专家并行性,并建议在较小专家数量的范式下每个核心一个专家,但 Switch-XXL 和 Switch-C 的精确核心数在抽取文本中并未说明。

瓶颈

先前的 MoE 系统因 routing 复杂性、通信成本和训练不稳定性而难以采用。稀疏 routing 产生专家溢出、all-to-all 通信和 bfloat16 数值问题。论文将 Switch 框定为对三个障碍的回应:模型复杂性、训练难度和通信成本。

设备瓶颈是固定形状的加速器执行与动态 routing 之间的矛盾。TPU 编译期望静态大小的张量,但 router 向每个专家发送数量取决于数据的 token。Switch 因此需要一个固定的专家容量:每 batch token 数除以专家数,再乘以容量因子。如果过多 token 选择了同一个专家,溢出的 token 将跳过专家层。提高容量减少丢弃,但增加了计算和 all-to-all 通信。

方法适配

Switch Transformer 简化了稀疏 routing:

  • 使用 top-1 路由替代 top-2 路由。
  • 将每个 token 发送到一个专家。
  • 使用固定的专家容量和辅助负载均衡损失。
  • 将 router 输入和 softmax 转为 float32,然后将 dispatch/combine 转回 bfloat16。
  • 保持密集的 Transformer 结构,同时用稀疏专家模块替换前馈模块。

核心适配是 top-1 路由。早期的 MoE Transformer 常使用 top-2 路由,将每个 token 发送到两个专家。Switch 仅将每个 token 发送到其最高概率的专家。论文列出了三个好处:router 计算更少,由于每个 token 只路由一次,专家容量需求大约减半,以及更低的通信/实现复杂性。模型仍通过所选专家的门概率保持 router 可微,并添加辅助负载均衡损失以防止 token 坍缩到少数专家上。

该实现自然地映射到专家并行性。在 Mesh TensorFlow 伪代码中,输入重塑为 num cores × tokens per core × model dimension。router 在 num cores、tokens、experts 和 expert capacity 维度上创建 dispatch 和 combine 张量。一次 all-to-all 将 token 发送到持有专家的核心,专家前馈计算运行,另一次 all-to-all 返回输出。Top-1 路由同时减少了专家路径和路由流量。

选择性精度是一个硬件特定的稳定性修复。纯 bfloat16 在 TPU 上很快,但 router softmax 在数值上敏感。Switch 在 router 函数内将 router 输入/logits 转为 float32,计算 softmax 和 dispatch/combine 决策,然后将 dispatch/combine 张量转回 bfloat16,使昂贵的跨设备张量不保持 float32。这使通信路径接近 bfloat16 速度,同时为 routing 决策提供 float32 的稳定性。

证据

受控基准表是首要的计算证据。在 32 个 TPUv3 核心上,T5-Base 在 100k 步后达到质量 -1.731,且在测量到的 100k 步中未达到 -1.50 阈值。T5-Large 在 131.1 小时内以 470 examples/sec 达到阈值。Switch-Base 具有 128 个专家、容量因子 1.0,在 100k 步后达到质量 -1.561,62.8 小时内达到阈值,且 1000 examples/sec。可比照的 top-2 MoE-Base 需要 80.1 小时,860 examples/sec。

图 5 给出了标题性的加速:一个 64 专家的 Switch-Base 在 32 个 TPUv3 核心上以相同的每样本 FLOPs 用七分之一的时间达到与 T5-Base 相同的质量。图 4 显示将专家从 2 扩展到 256 在保持每样本计算预算相等的情况下改善了困惑度;256 专家点具有 14.7B 参数,而 T5-Base 点为 223M 参数。来源还指出一个 Switch 模型在 60k 步时即达到 T5-Base 在 450k 步时的质量,对应 7.5 倍步时间加速。

选择性精度经直接测试。32 专家 Switch-Base 在 float32 下质量 -1.718、速度 1160 examples/sec。纯 bfloat16 达到 1390 examples/sec 但发散,质量列为 -3.780。选择性精度达到 -1.716,1390 examples/sec,匹配 float32 训练动态的同时保持类 bfloat16 的速度。

多语言和大模型证据扩展了规模。在多语言预训练中,Switch 显示相对 mT5-Base 平均 5 倍加速,91% 的语言达到至少 4 倍加速。表 9 列出 Switch-XXL 具有 395B 参数和每序列 6.3T FLOPs,Switch-C 具有 1571B 参数和每序列 890B FLOPs。Switch-C 使用 2048 个专家、15 层且专家频率为 1;Switch-XXL 使用 64 个专家、24 层且专家频率为 1/2。在 500k 步时,Switch-XXL 报告负对数困惑度 -1.008,Switch-C 为 -1.043,两者均优于表中 T5-XXL 的 -1.095。

历史影响

Switch 使稀疏专家 Transformer 更简单、更可训练,帮助 top-1 路由成为实用的 MoE 默认方案。它是 GShard 之后的一个关键稀疏扩展分支,因为它在保留不按比例增加每 token 活跃 FLOPs 便可扩展参数量这一理念的同时,减少了 routing 和通信表面积。

历史影响在于实际采用。GShard 证明了编译器支持的 MoE 可以扩展到数千亿参数;Switch 使该层更易于理解、基准测试和使用对 bfloat16 友好的选择性精度来稳定训练。

局限

  • 最大模型可能不稳定。
  • 精调行为取决于每 token FLOPs 与参数量之间尚未充分理解的权衡。
  • 部署仍然困难。
  • 蒸馏仅保留部分质量收益。
  • 抽取文本未清晰列出 Switch-XXL 或 Switch-C 的精确设备数量,因此它们的硬件规模在本卡片中是推断的,而非来源明确报告的。
  • 固定的专家容量意味着 token 可能溢出并跳过专家计算;更大的容量因子减少溢出但增加了计算和通信。
  • 论文报告 Switch-C 未观察到不稳定,但指出 Switch-XXL 尽管采用了所提出的稳定技术仍然不稳定。

链接