Fast Inference from Transformers via Speculative Decoding

下载 PDF

Fast Inference from Transformers via Speculative Decoding - 中文验证版

英文原始依据卡片:speculative_decoding_2023.md

状态:已翻译。

元数据

计算设置

论文明确声明了其 wall-clock 测量的硬件:在单个 TPU-v4 上 batch size 1。已实现的延迟实验使用 T5 version 1.1 encoder-decoder 检查点,T5-XXL 11B 作为目标模型,较小的 T5 检查点作为逼近模型:T5-large 800M、T5-base 250M 和 T5-small 77M。测试任务是微调于 WMT En-De 的英德翻译和 CNN/DailyMail 摘要。

来源还报告了更广泛的模型规模探测但没有 wall-clock 硬件:一个 GPT 式 97M decoder 目标配合 6M 逼近模型在 LM1B 上,以及 LaMDA 137B decoder 目标配合 LaMDA 8B、2B 和 100M 逼近模型在对话上。这些测量的是接受率,而非设备特定延迟。

瓶颈

瓶颈是来自大型目标模型的串行自回归解码。标准生成每步目标模型得到一个 token,因此延迟与串行目标传递的次数相关联。论文强调,大模型推理通常不瓶颈于算术操作,而瓶颈于内存带宽和通信。这在 batch size 1 时尤其如此,其中每个解码步骤重新读取模型权重和 KV 缓存,算术重用量有限。

投机解码从内存/延迟侧攻击问题。它可能花费更多的总算术,因为一次算法迭代并行评估目标的多个位置,同时也运行逼近模型。但如果硬件可以并发评估目标位置而不增加 wall time,它就减少了串行目标模型调用,并且可以减少对目标权重和 KV 缓存的内存访问(每个生成 token)。因此该方法最适合具有空闲并行计算但存在串行延迟或内存带宽瓶颈的推理场景。

方法适配

该方法使用一个小型逼近模型 Mq 来提议 gamma 个未来 tokens,然后并行在提议的前缀位置上运行大型目标模型 Mp。一个接受/拒绝过程接受尽可能多的 draft tokens,只要它们在分布上是有效的,并在需要时采样一个纠正 token。重要的算法约束是精确性:输出遵循与标准目标模型解码相同的分布,包括随机采样,而不仅仅是贪心解码。

计算适配由两个量控制。接受率 alpha 衡量 draft 与目标的一致性程度,成本系数 c 衡量一步 draft 相对于一步目标的 wall-clock 成本。论文报告对于比目标小几个数量级的逼近模型,c 始终低于 0.05,通常接近零。理论上的 wall-time 改善取决于 alphagammac;该方法假设 gamma + 1 个并发目标评估可以并行运行。该假设使硬件并行性成为算法的一部分,而不仅是实现细节。

draft 大小的选择是内存/延迟的权衡。T5-small 便宜且提供强劲加速,尽管接受率低于更大的 draft。T5-large 接受率略高但成本更大,并且可能产生更低的实测加速。这是核心的推理教训:最佳 draft 不一定是最高精度的小模型;而是平衡接受率与 draft 开销和目标并行性的模型。

证据

主要 wall-clock 表支持了计算论点。在单个 TPU-v4 上 batch size 1,WMT En-De 使用 T5-XXL 和 T5-small 在温度 0 下达到 3.4 倍加速(gamma = 7alpha = 0.75、draft 成本约 0.02)。在温度 1 下,同一任务/模型对达到 2.6 倍,alpha = 0.62。CNN/DailyMail 使用 T5-small 在温度 0 下达到 3.1 倍,温度 1 下 2.3 倍。T5-base 也表现良好,在 En-De 上达到 2.8 倍和 2.4 倍,在 CNN/DailyMail 上达到 3.0 倍和 2.2 倍,取决于温度。

该表也显示了为什么内存和 draft 开销很重要。T5-large 在若干设置中接受率高于 T5-small,但其更大的成本系数(约 0.11)将加速降为 En-De 温度 0 下的 1.7 倍、En-De 温度 1 下的 1.4 倍、CNN/DailyMail 温度 0 下的 2.2 倍和 CNN/DailyMail 温度 1 下的 1.7 倍。更小的 draft 更快,因为更好逼近的成本不值得。

更广泛的接受率实验显示了通用性。对于 GPT 式 97M 目标,一个 6M GPT 式 draft 给出 alpha 约 0.88 到 0.89,而 unigram 和 bigram draft 则弱得多但非零。对于 T5-XXL,T5-small/base/large 逼近模型通常在不同任务和温度间产生约 0.53 到 0.82 的 alpha。LaMDA 137B 被包含作为用于接受分析的大规模目标,但非论文中的 wall-clock 基准。

历史影响

投机解码建立了一条清晰的无损 LLM 延迟降低路径:使用廉价模型创建可能的未来 tokens,并用昂贵模型在一次并行传递中验证其中许多。其历史重要性在于它将分布正确性与串行解码分离。早期 blockwise 或自适应计算想法通常改变输出、仅支持贪心解码或需要架构变更;本文使精确的随机版本在现有检查点下变得实用。

在推理系统层面,论文帮助将问题从"每个 token 多少 FLOPs?"转变为"每个 token 多少串行目标模型传递和内存读取?"该框架在现代 LLM 推理栈中仍然核心,特别是对于 batch-1 或低延迟工作负载,其中内存带宽、KV 缓存流量和解码串行化占主导。

局限

该方法并非免费计算。它可能增加总算术量,特别是在接受率低时,因为被拒绝的 draft tokens 仍然需要目标验证。它还假设有足够的硬件并发性可以并行运行 gamma + 1 个目标位置而不增加 wall time。如果推理系统已经在大 batch 下完全计算饱和,同一算法可能提供较小的延迟增益或没有增益。

加速高度依赖于任务、温度和 draft。论文显示 argmax 解码的加速高于随机温度 1 解码,因为更锐利的分布提高了接受率。它还显示更大的 draft 不是自动更好的。最后,仅 T5 wall-clock 结果被实现;GPT 式和 LaMDA 案例是接受率分析,因此设备特定加速不应在没有新测量的情况下外推。

链接

  • 计算范式:history/compute_regimes/efficient_edge_inference/README.md
  • 来源 PDF 和抽取文本见上方元数据。
  • Queue 状态:read_complete
  • 方法索引:speculative_decoding
  • 对照更新:compute bottlenecks