
Fast Inference from Transformers via Speculative Decoding

Yaniv Leviathan *¹, Matan Kalman *¹, Yossi Matias ¹

* Equal contribution. ¹ Google Research, Mountain View, CA, USA. Correspondence to: Yaniv Leviathan <leviathan@google.com>.

Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s).

arXiv:2211.17192v2 [cs.LG] 18 May 2023

Abstract

Inference from large autoregressive models like Transformers is slow - decoding K tokens takes K serial runs of the model. In this work we introduce speculative decoding - an algorithm to sample from autoregressive models faster without any changes to the outputs, by computing several tokens in parallel. At the heart of our approach lie the observations that (1) hard language-modeling tasks often include easier subtasks that can be approximated well by more efficient models, and (2) using speculative execution and a novel sampling method, we can make exact decoding from the large models faster, by running them in parallel on the outputs of the approximation models, potentially generating several tokens concurrently, and without changing the distribution. Our method can accelerate existing off-the-shelf models without retraining or architecture changes. We demonstrate it on T5-XXL and show a 2X-3X acceleration compared to the standard T5X implementation, with identical outputs.

1. Introduction

Large autoregressive models, notably large Transformers (Vaswani et al., 2017), are much more capable than smaller models, as is evidenced countless times in recent years e.g., in the text or image domains, like GPT-3 (Brown et al., 2020), LaMDA (Thoppilan et al., 2022), Parti (Yu et al., 2022), and PaLM (Chowdhery et al., 2022). Unfortunately, a single decode step from these larger models is significantly slower than a step from their smaller counterparts, and making things worse, these steps are done serially - decoding K tokens takes K serial runs of the model.

Given the importance of large autoregressive models and specifically large Transformers, several approaches were developed to make inference from them faster. Some approaches aim to reduce the inference cost for all inputs equally (e.g. Hinton et al., 2015; Jaszczur et al., 2021; Hubara et al., 2016; So et al., 2021; Shazeer, 2019). Other approaches stem from the observation that not all inference steps are born alike - some require a very large model, while others can be approximated well by more efficient models. These adaptive computation methods (e.g. Han et al., 2021; Sukhbaatar et al., 2019; Schuster et al., 2021; Scardapane et al., 2020; Bapna et al., 2020; Elbayad et al., 2019; Schwartz et al., 2020) aim to use less compute resources for easier inference steps. While many of these solutions have proven extremely effective in practice, they usually require changing the model architecture, changing the training-procedure and re-training the models, and don’t maintain identical outputs.

The key observation above, that some inference steps are “harder” and some are “easier”, is also a key motivator for our work. We additionally observe that inference from large models is often not bottlenecked on arithmetic operations, but rather on memory bandwidth and communication, so additional computation resources might be available. Therefore we suggest increasing concurrency as a complementary approach to using an adaptive amount of computation. Specifically, we are able to accelerate inference without changing the model architectures, without changing the training-procedures or needing to re-train the models, and without changing the model output distribution. This is accomplished via speculative execution.

Speculative execution (Burton, 1985; Hennessy & Patterson, 2012) is an optimization technique, common in processors, where a task is performed in parallel to verifying if it’s actually needed - the payoff being increased concurrency. A well-known example of speculative execution is branch prediction. For speculative execution to be effective, we need an efficient mechanism to suggest tasks to execute that are likely to be needed. In this work, we generalize speculative execution to the stochastic setting - where a task might be needed with some probability. Applying this to decoding from autoregressive models like Transformers, we sample generations from more efficient approximation models as speculative prefixes for the slower target models. With a novel sampling method, speculative sampling, we maximize the probability of these speculative tasks to





be accepted, while guaranteeing that the outputs from our system have the same distribution as those from the target model alone. For example, the sentence in Figure 1, consisting of 38 tokens, was generated by our method with only 9 serial runs of a larger target model (97M parameters) thanks to a smaller and more efficient approximation model (6M parameters), while the probability of generating it is unchanged.

Figure 1. Our technique illustrated in the case of unconditional language modeling. Each line represents one iteration of the algorithm. The green tokens are the suggestions made by the approximation model (here, a GPT-like Transformer decoder with 6M parameters trained on lm1b with 8k tokens) that the target model (here, a GPT-like Transformer decoder with 97M parameters in the same setting) accepted, while the red and blue tokens are the rejected suggestions and their corrections, respectively. For example, in the first line the target model was run only once, and 5 tokens were generated.

We analyze our method in a variety of tasks and model sizes: unconditional generation from a 97M parameter GPT-like model trained on lm1b, English to German translation and news article summarization with an 11B parameter T5-XXL model, and a dialog task with a 137B parameter LaMDA model. We implement our method and compare actual walltimes for T5-XXL to those of the robust T5X implementation (Roberts et al., 2022), showing an out-of-the-box latency improvement of 2X-3X, without any change to the outputs (Section 4).

Our method is easy to employ in actual production settings, doesn’t require training new models, and doesn’t change the outputs. Therefore, in common situations where memory bandwidth is the bottleneck, and compute resources are available, it may be a good default to accelerate sampling from autoregressive models like Transformers.

To summarize, our main contributions are: (1) A generalization of speculative execution to the stochastic setting, with a novel sampling method we call speculative sampling, and (2) A decoding mechanism we call speculative decoding that can accelerate decoding from autoregressive models, without any change to the model architectures, training regimes and output distributions.

2. Speculative Decoding

2.1. Overview

Let $M_p$ be the target model, inference from which we’re trying to accelerate, and $p(x_t \mid x_{<t})$ the distribution we get from the model for a prefix $x_{<t}$. Let $M_q$ be a more efficient approximation model for the same task, and denote by $q(x_t \mid x_{<t})$ the distribution we get from the model for a prefix $x_{<t}$.¹ The core idea is to (1) use the more efficient model $M_q$ to generate $\gamma \in \mathbb{Z}^+$ completions (see Section 3.5 for how to optimally choose this parameter), then (2) use the target model $M_p$ to evaluate all of the guesses and their respective probabilities from $M_q$ in parallel, accepting all those that can lead to an identical distribution, and (3) sample an additional token from an adjusted distribution to fix the first one that was rejected, or to add an additional one if they are all accepted. That way, each parallel run of the target model $M_p$ will produce at least one new token (so the number of serial runs of the target model can never, even in the worst case, be larger than the simple autoregressive method), but it can potentially generate many new tokens, up to $\gamma + 1$, depending on how well $M_q$ approximates $M_p$.

¹ We’ll use $p(x)$ to denote $p(x_t \mid x_{<t})$ whenever the prefix $x_{<t}$ is clear from the context, and similarly for $q(x)$.

2.2. Standardized Sampling

First, note that while there are many methods and parameters of sampling, like argmax, top-k, nucleus, and setting a temperature, and popular implementations usually treat them differently at the logits level, they can all easily be cast into standard sampling from an adjusted probability distribution. For example, argmax sampling is equivalent to zeroing out non-max elements of the distribution and normalizing. We can therefore only deal with standard sampling from a



probability distribution, and cast all of the other types of sampling into that framework. Going forward we’ll assume that $p(x)$ and $q(x)$ are the distributions from $M_p$ and $M_q$ respectively, adjusted for the sampling method.

2.3. Speculative Sampling

To sample $x \sim p(x)$, we instead sample $x \sim q(x)$, keeping it if $q(x) \le p(x)$, and in case $q(x) > p(x)$ we reject the sample with probability $1 - \frac{p(x)}{q(x)}$ and sample $x$ again from an adjusted distribution $p'(x) = norm(\max(0, p(x) - q(x)))$ instead. It’s easy to show (see Appendix A.1) that for any distributions $p(x)$ and $q(x)$, and $x$ sampled in this way, indeed $x \sim p(x)$.

Given the distribution $q(x)$ obtained from running $M_q$ on a conditioning $prefix$, we can sample a token $x_1 \sim q(x)$. We then calculate the distribution $p(x)$ by running $M_p$ on $prefix$ while in parallel speculatively calculating the distribution of the next token $x_2$ by running $M_p$ on $prefix + [x_1]$. Once both computations complete, we proceed as per above: If $x_1$ is rejected, we discard the computation of $x_2$ and re-sample $x_1$ from an adjusted distribution, and if $x_1$ is accepted, we keep both tokens. Algorithm 1 generalizes this idea to sample between 1 and $\gamma + 1$ tokens at once.

Algorithm 1 SpeculativeDecodingStep
  Inputs: $M_p$, $M_q$, $prefix$.
  ▷ Sample $\gamma$ guesses $x_{1,\ldots,\gamma}$ from $M_q$ autoregressively.
  for $i = 1$ to $\gamma$ do
      $q_i(x) \leftarrow M_q(prefix + [x_1, \ldots, x_{i-1}])$
      $x_i \sim q_i(x)$
  end for
  ▷ Run $M_p$ in parallel.
  $p_1(x), \ldots, p_{\gamma+1}(x) \leftarrow M_p(prefix), \ldots, M_p(prefix + [x_1, \ldots, x_\gamma])$
  ▷ Determine the number of accepted guesses $n$.
  $r_1 \sim U(0, 1), \ldots, r_\gamma \sim U(0, 1)$
  $n \leftarrow \min(\{i - 1 \mid 1 \le i \le \gamma,\ r_i > \frac{p_i(x)}{q_i(x)}\} \cup \{\gamma\})$
  ▷ Adjust the distribution from $M_p$ if needed.
  $p'(x) \leftarrow p_{n+1}(x)$
  if $n < \gamma$ then
      $p'(x) \leftarrow norm(\max(0, p_{n+1}(x) - q_{n+1}(x)))$
  end if
  ▷ Return one token from $M_p$, and $n$ tokens from $M_q$.
  $t \sim p'(x)$
  return $prefix + [x_1, \ldots, x_n, t]$

3. Analysis

3.1. Number of Generated Tokens

Let’s analyze the reduction factor in the number of serial calls to the target model, or equivalently, the expected number of tokens produced by a single run of Algorithm 1.

Definition 3.1. The acceptance rate $\beta_{x_{<t}}$, given a prefix $x_{<t}$, is the probability of accepting $x_t \sim q(x_t \mid x_{<t})$ by speculative sampling, as per Section 2.3.²

² As before, we’ll omit the $x_{<t}$ subscript whenever the prefix is clear from the context.

$E(\beta)$ is then a natural measure of how well $M_q$ approximates $M_p$. If we make the simplifying assumption that the $\beta$s are i.i.d., and denote $\alpha = E(\beta)$, then the number of tokens produced by a single run of Algorithm 1 is a capped geometric variable, with success probability $1 - \alpha$ and cap $\gamma + 1$, and the expected number of tokens generated by Algorithm 1 satisfies Equation (1). See Figure 2.

$$E(\#\ \text{generated tokens}) = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha} \qquad (1)$$

[Figure 2: plot of E(tokens per iteration) (y-axis, 0 to 10) against $\alpha$ (x-axis, 0.50 to 0.90), with a baseline curve and one curve per value of $\gamma$, including $\gamma$ = 1, 3, 5 and 7.]

Figure 2. The expected number of tokens generated by Algorithm 1 as a function of $\alpha$ for various values of $\gamma$.

3.2. Calculating α

We’ll now derive a simple formula for calculating $\alpha$ given a prefix and the two models $M_p$ and $M_q$. We start by defining a natural divergence $D_{LK}$:

Definition 3.2. $D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x |q(x) - M(x)|$ where $M(x) = \frac{p(x) + q(x)}{2}$.

Lemma 3.3. $D_{LK}(p, q) = 1 - \sum_x \min(p(x), q(x))$

Proof. $D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x \frac{|p - q|}{2} = 1 - \sum_x \frac{p + q - |p - q|}{2} = 1 - \sum_x \min(p(x), q(x))$

From Lemma 3.3 we immediately get the following results:
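Algorithm 1 can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the names `speculative_decoding_step`, `target`, `draft`, `normalize` and `sample` are hypothetical, the two toy models ignore their prefix and return fixed distributions over a three-token vocabulary, and the γ + 1 target evaluations are computed in a simple loop where a real system would batch them in one parallel call.

```python
import random


def normalize(weights):
    """Scale non-negative weights so they sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


def sample(dist, rng):
    """Draw one token index from a probability vector."""
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


def speculative_decoding_step(target, draft, prefix, gamma, rng):
    """One run of Algorithm 1.

    `target` and `draft` are callables mapping a token list to a
    next-token probability vector (stand-ins for M_p and M_q).
    """
    # Sample gamma guesses from the draft model autoregressively.
    guesses, q_dists = [], []
    for _ in range(gamma):
        q = draft(prefix + guesses)
        q_dists.append(q)
        guesses.append(sample(q, rng))

    # Evaluate the target on all gamma + 1 prefixes
    # (sequential here; batched in parallel in a real system).
    p_dists = [target(prefix + guesses[:i]) for i in range(gamma + 1)]

    # n = number of accepted guesses: stop at the first guess x_i
    # for which r_i > p_i(x_i) / q_i(x_i).
    n = gamma
    for i, x in enumerate(guesses):
        if rng.random() > p_dists[i][x] / q_dists[i][x]:
            n = i
            break

    # Use the target distribution as is, or the adjusted one after a rejection.
    p_next = p_dists[n]
    if n < gamma:
        p_next = normalize(
            [max(0.0, p - q) for p, q in zip(p_dists[n], q_dists[n])]
        )

    # Return n tokens from the draft plus one token from the target.
    return prefix + guesses[:n] + [sample(p_next, rng)]


def target(tokens):
    """Toy target model M_p (assumption: ignores the prefix)."""
    return [0.6, 0.3, 0.1]


def draft(tokens):
    """Toy approximation model M_q (assumption: ignores the prefix)."""
    return [0.3, 0.4, 0.3]


rng = random.Random(0)
print(speculative_decoding_step(target, draft, [], gamma=3, rng=rng))
```

With these toy models, repeating the step many times and counting the first returned token should approach the target distribution [0.6, 0.3, 0.1], which is the guarantee stated in Section 2.3.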

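Equation (1) can be checked numerically. The sketch below compares the closed form with a Monte Carlo estimate under the same i.i.d. acceptance assumption made in Section 3.1; the function names `expected_tokens` and `simulate_tokens` are hypothetical, and α < 1 is assumed so that the denominator is non-zero.

```python
import random


def expected_tokens(alpha, gamma):
    """Closed form of Equation (1): (1 - alpha^(gamma+1)) / (1 - alpha)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def simulate_tokens(alpha, gamma, runs, rng):
    """Monte Carlo estimate, assuming each guess is accepted i.i.d. with probability alpha."""
    total = 0
    for _ in range(runs):
        accepted = 0
        # Accept guesses one by one until the first rejection or the cap gamma.
        while accepted < gamma and rng.random() < alpha:
            accepted += 1
        # Each run yields the accepted guesses plus one token from the target.
        total += accepted + 1
    return total / runs


rng = random.Random(0)
print(expected_tokens(0.8, 5))
print(simulate_tokens(0.8, 5, runs=50_000, rng=rng))
```

The two printed values should agree up to sampling noise, which shrinks as `runs` grows.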
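As a quick sanity check of Definition 3.2 and Lemma 3.3, the sketch below evaluates both sides of the lemma on a pair of toy distributions. The helper names `d_lk` and `overlap` and the example vectors are assumptions made for illustration only.

```python
def d_lk(p, q):
    """Definition 3.2: sum over x of |p(x) - M(x)|, with M = (p + q) / 2."""
    return sum(abs(pi - (pi + qi) / 2) for pi, qi in zip(p, q))


def overlap(p, q):
    """Sum over x of min(p(x), q(x)), the quantity on the right of Lemma 3.3."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


p = [0.6, 0.3, 0.1]  # toy target distribution
q = [0.3, 0.4, 0.3]  # toy approximation distribution

# Lemma 3.3 says these two numbers coincide (up to floating-point error).
print(d_lk(p, q), 1 - overlap(p, q))
```

Swapping the arguments, passing identical distributions, or passing distributions with disjoint support is a simple way to probe the symmetry and the two extreme values of the divergence.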

Corollary 3.4. $D_{LK}(p, q)$ is a symmetric divergence in $[0, 1]$.
$D_{LK}(p, q) = 0 \iff p = q$.
$D_{LK}(p, q) = 1 \iff p$ and $q$ have disjoint support.

Theorem 3.5. $\beta = 1 - D_{LK}(p, q)$

Proof. $$\beta = E_{x \sim q(x)} \begin{cases} 1 & q(x) \le p(x) \\ \frac{p(x)}{q(x)} & q(x) > p(x) \end{cases} = E_{x \sim q(x)} \min\left(1, \frac{p(x)}{q(x)}\right) = \sum_x \min(p(x), q(x))$$ By Lemma 3.3, this equals $1 - D_{LK}(p, q)$.

Finally we get:

Corollary 3.6. $\alpha = 1 - E(D_{LK}(p, q)) = E(\min(p, q))$

See Table 3 for empirically observed $\alpha$ values in our experiments.

3.3. Walltime Improvement

We’ve shown that with the i.i.d. assumption our algorithm reduces the number of calls to the target model by a factor of $\frac{1 - \alpha^{\gamma+1}}{1 - \alpha}$. Note that speculative execution in general, and our algorithm in particular, assume that we have enough compute resources to support the increased concurrency (Section 3.4). For the walltime analysis, we’ll assume that we can run $\gamma + 1$ concurrent evaluations of $M_p$ in parallel without increasing the walltime. To get the total walltime improvement, we now consider the cost of running the approximation model $M_q$.

Definition 3.7. Let $c$, the cost coefficient, be the ratio between the time for a single run of $M_q$ and the time for a single run of $M_p$.

Note that unlike $\alpha$ which is an intrinsic property of the models and the task, the value of $c$ depends on the hardware configuration and software implementation details. In our experiments where $M_q$ is typically a couple of orders of magnitude smaller than $M_p$, $c$ was always less than 0.05 and often negligibly close to 0.

Theorem 3.8. The expected improvement factor in total walltime by Algorithm 1 is $\frac{1 - \alpha^{\gamma+1}}{(1 - \alpha)(\gamma c + 1)}$.

Proof. Denote the cost of running a single step of $M_p$ by $T$. Now, each run of Algorithm 1 costs $Tc\gamma + T$ (for running the approximation model $M_q$ $\gamma$ times and running $M_p$ once) and according to Equation (1) produces $\frac{1 - \alpha^{\gamma+1}}{1 - \alpha}$ tokens on average. So the overall expected cost for producing a token with Algorithm 1 is $\frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma+1}} T$. Since the cost of producing a single token with the standard decoding algorithm is $T$, we get the desired result.

Note that Theorem 3.8 assumes long enough generations (for example, since we run $M_p$ at least once, the improvement factor is capped by the number of generated tokens).

Corollary 3.9. If $\alpha > c$, there exists $\gamma$ for which we’ll get an improvement, and the improvement factor will be at least $\frac{1 + \alpha}{1 + c}$.

Proof. If we get an improvement for $\gamma$, we’d also get an improvement for any $0 < \gamma^* < \gamma$, so for our method to yield an improvement, we can evaluate Theorem 3.8 for $\gamma = 1$, yielding $\frac{1 - \alpha^2}{(1 - \alpha)(c + 1)} = \frac{1 + \alpha}{1 + c}$.

3.4. Number of Arithmetic Operations

Algorithm 1 does $\gamma + 1$ runs of $M_p$ in parallel, so the number of concurrent arithmetic operations grows by a factor of $\gamma + 1$. Now, since Algorithm 1 produces at most $\gamma + 1$ tokens per run, the total number of arithmetic operations might be higher than that of the standard decoding algorithm. When we accept the sample from $M_q$ the increased concurrency is “free” and the total number of operations isn’t increased.³ When we reject a guess though, computation is wasted. Let’s now analyze the effect of our method on the total number of arithmetic operations.

³ Neglecting the cost of $M_q$.

Definition 3.10. Let $\hat{c}$ be the ratio of arithmetic operations per token of the approximation model $M_q$ to that of the target model $M_p$.

Theorem 3.11. The expected factor of increase in the number of total operations of Algorithm 1 is $\frac{(1 - \alpha)(\gamma \hat{c} + \gamma + 1)}{1 - \alpha^{\gamma+1}}$.

Proof. Denote by $\hat{T}$ the number of arithmetic operations done by a standard decoding baseline per token, i.e. the number of operations of a single run of $M_p$. Then a single iteration of Algorithm 1 costs $\hat{T}\hat{c}\gamma + \hat{T}(\gamma + 1)$ operations (for $\gamma$ runs of $M_q$ and $\gamma + 1$ parallel runs of $M_p$). Dividing by the expected number of tokens produced by Algorithm 1, i.e. Equation (1), and by $\hat{T}$, we get the desired result.

If $\alpha$ is low, the increase in the number of arithmetic operations is high, and vice-versa. Note that for Transformer decoders, the total number of arithmetic operations by Algorithm 1 (not counting runs of $M_q$) can be bounded from above by a single run of the same-size Transformer encoder.

Unlike the total number of arithmetic operations, the total number of memory accesses can go down with our method. Specifically, the target model’s weights and KV cache can be read once per execution of Algorithm 1, so the number of memory accesses for reading them shrinks by a factor of $\frac{1 - \alpha^{\gamma+1}}{1 - \alpha}$, according to Equation (1).

[Figure 3: plot of the optimal $\gamma$ (y-axis, 0 to 24) against $\alpha$ (x-axis, 0.50 to 0.90), with one curve each for c = 0.01, 0.02, 0.05 and 0.1.]

Figure 3. The optimal $\gamma$ as a function of $\alpha$ for various values of $c$.

[Figure 4: plot against $\alpha$ (x-axis, 0.50 to 0.90; y-axis, 0 to 10) of paired “Speed” and “Ops” curves for $\gamma$ = 1, 3, 5, 7 and 10.]

Figure 4. The speedup factor and the increase in number of arithmetic operations as a function of $\alpha$ for various values of $\gamma$.

3.5. Choosing γ

Given $c$ and $\alpha$ and assuming enough compute resources (see Section 3.4), the optimal $\gamma$ is the one maximizing the walltime improvement equation (Theorem 3.8): $\frac{1 - \alpha^{\gamma+1}}{(1 - \alpha)(\gamma c + 1)}$.
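The formulas of Theorem 3.8 and Theorem 3.11 are cheap to evaluate, so the maximizing integer γ can be found by brute force. The sketch below is one way to do this, not the authors' code: the function names are hypothetical, `max_gamma` is an arbitrary search cap introduced here, and α < 1 and c > 0 are assumed (with c = 0 the improvement keeps growing with γ, so the search simply returns the cap).

```python
def walltime_improvement(alpha, gamma, c):
    """Theorem 3.8: expected walltime improvement factor."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


def operations_increase(alpha, gamma, c_hat):
    """Theorem 3.11: expected factor of increase in total arithmetic operations."""
    return (1 - alpha) * (gamma * c_hat + gamma + 1) / (1 - alpha ** (gamma + 1))


def optimal_gamma(alpha, c, max_gamma=100):
    """Integer gamma in [1, max_gamma] that maximizes the Theorem 3.8 factor."""
    return max(
        range(1, max_gamma + 1),
        key=lambda g: walltime_improvement(alpha, g, c),
    )


alpha, c = 0.8, 0.05
best = optimal_gamma(alpha, c)
print(best, walltime_improvement(alpha, best, c))

# With c = c_hat = 0 the two formulas give the speed and operations
# factors for a chosen (alpha, gamma) pair.
print(walltime_improvement(0.8, 5, 0.0), operations_increase(0.8, 5, 0.0))
```

Setting γ = 1 in `walltime_improvement` gives (1 + α)/(1 + c), the lower bound of Corollary 3.9, which is a convenient spot check.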
Since $\gamma$ is an integer, it can be easily found numerically, see Figure 3.

Table 1 and Figure 4 illustrate the trade-off between inference speed and the total number of arithmetic operations for various values of $\alpha$ and $\gamma$, assuming $c = \hat{c} = 0$. Figure 5 shows a simplified trace diagram.

Table 1. The total number of arithmetic operations and the inference speed vs the baseline, for various values of $\gamma$ and $\alpha$, assuming $c = \hat{c} = 0$.

    α      γ     OPERATIONS    SPEED
    0.6    2     1.53X         1.96X
    0.7    3     1.58X         2.53X
    0.8    2     1.23X         2.44X
    0.8    5     1.63X         3.69X
    0.9    2     1.11X         2.71X
    0.9    10    1.60X         6.86X

3.6. Approximation Models

Speculative sampling, and therefore speculative decoding, guarantee an identical output distribution for any choice of approximation model $M_q$ without restriction (see Appendix A.1). In our experiments, we mostly tested existing off-the-shelf smaller Transformers as the approximation models. Further, we only tested approximation models of the same architecture as the target models $M_p$ and using the same probability standardization. In this setup, choosing $M_q$ to be around two orders of magnitude smaller than $M_p$ usually performed best, balancing $\alpha$ and $c$ (Theorem 3.8).

Another type of approximation models, negligible-cost models, are those for which $c \approx 0$, i.e. approximation models with a negligible cost relative to the target model. In this case, we get an expected walltime improvement of $\frac{1 - \alpha^{\gamma+1}}{1 - \alpha}$, which is bounded from above by $\frac{1}{1 - \alpha}$ (we approach equality if $\gamma$ is large). One interesting type of negligible-cost approximation models are n-gram models, where the evalu-
                                                                                                ation amounts to a table lookup. Interestingly, in empirical
Instead of picking a single value for γ based on α, since the
                                                                                                tests (Section 4.2) we get non zero αs even for these triv-
βs aren’t constant, we could get further improvement by pre-
                                                                                                ial n-gram models. For example, for the English-German
dicting the value of β and accordingly varying the value of γ
                                                                                                translation task, with Mp being T5-XXL 11B and Mq being
during the run of Algorithm 1. To get an upper bound on the
                                                                                                a trivial bigram model, we get α ≈ 0.2 which leads to an
additional improvement factor, assume we had an oracle for
                                                          1                                     inference speed improvement factor of 1.25X with γ = 3.
γ. We would then have E(# generated tokens) = 1−α            .
For typical values of c and α, and assuming unbounded com-                                      Other simple heuristics can be used as negligible-cost ap-
pute resources, the enhanced walltime improvement factor                                        proximation models. For example, in cases where long se-
can be up to ∼60% higher than the improvement factor with                                       quences are likely to repeat, such as for summarization tasks
a fixed γ. We leave exploring this for future work4 .                                           or chat-like interfaces 5 , an approximation model that simply
     4                                                                                             5
     The above bound assumes that we still run Mp to verify the or-                                  E.g. where a user and a language model iterate on content, like
acle’s predictions. If we skip those verifications the bound doesn’t                            text or code (“can you rewrite this story but change the ending”,
hold and we would get a substantial additional improvement.                                     “can you make this function also do X”).
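
As a sanity check, the entries of Table 1 can be recomputed directly. The sketch below assumes c = ĉ = 0 and uses the Theorem 3.8 walltime improvement factor for the speed column. For the operations column it uses the factor (1 − α)(γĉ + γ + 1) / (1 − α^(γ+1)), which is our reading of the paper's earlier analysis of arithmetic operations. The function names are illustrative. With c ≈ 0, the same speed formula also gives roughly the 1.25X quoted for the bigram model (α ≈ 0.2, γ = 3).

```python
def speed_factor(alpha: float, gamma: int, c: float = 0.0) -> float:
    """Walltime improvement factor (Theorem 3.8)."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


def operations_factor(alpha: float, gamma: int, c_hat: float = 0.0) -> float:
    """Assumed factor of increase in arithmetic operations (see lead-in)."""
    return (1 - alpha) * (gamma * c_hat + gamma + 1) / (1 - alpha ** (gamma + 1))


# (alpha, gamma) pairs listed in Table 1.
TABLE_1_SETTINGS = [(0.6, 2), (0.7, 3), (0.8, 2), (0.8, 5), (0.9, 2), (0.9, 10)]


def table_1_rows():
    """Return (alpha, gamma, operations, speed) tuples with c = c_hat = 0."""
    return [(a, g, operations_factor(a, g), speed_factor(a, g))
            for a, g in TABLE_1_SETTINGS]


if __name__ == "__main__":
    for a, g, ops, speed in table_1_rows():
        print(f"alpha={a} gamma={g} operations={ops:.2f}X speed={speed:.2f}X")
    # Negligible-cost bigram example: alpha ~ 0.2, gamma = 3, c ~ 0.
    print(f"bigram: {speed_factor(0.2, 3):.2f}X")
```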



[Figure 5 image: rows γ = 7, γ = 3, and Base; horizontal axis: wall time; legend: Mp encoder,
Mq encoder, Mp decoder, Mq decoder]

Figure 5. A simplified trace diagram for a full encoder-decoder Transformer stack. The top row
shows speculative decoding with γ = 7 so each of the calls to Mp (the purple blocks) is preceded
by 7 calls to Mq (the blue blocks). The yellow block on the left is the call to the encoder for
Mp and the orange block is the call to the encoder for Mq. Likewise the middle row shows
speculative decoding with γ = 3, and the bottom row shows standard decoding.
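
The structure of the trace in Figure 5 can be mimicked with a small Monte Carlo simulation. This is only a sketch under simplifying assumptions that are ours: each guess is accepted independently with probability α, one call to Mp costs one time unit, one call to Mq costs c units, and encoder calls are ignored. The function name `simulate_trace` is hypothetical.

```python
import random


def simulate_trace(num_tokens: int, alpha: float, gamma: int, c: float,
                   seed: int = 0):
    """Simulate speculative decoding iterations until num_tokens are produced.

    Returns (produced, mp_calls, mq_calls, walltime), where walltime is in
    units of one Mp call. Acceptance is modeled as i.i.d. with rate alpha,
    which is a simplification.
    """
    rng = random.Random(seed)
    produced = mp_calls = mq_calls = 0
    while produced < num_tokens:
        mq_calls += gamma          # gamma serial guesses from Mq
        mp_calls += 1              # one parallel verification call to Mp
        accepted = 0
        while accepted < gamma and rng.random() < alpha:
            accepted += 1
        produced += accepted + 1   # accepted guesses plus one token from Mp
    walltime = mp_calls + c * mq_calls
    return produced, mp_calls, mq_calls, walltime


if __name__ == "__main__":
    for gamma in (7, 3):
        produced, mp, mq, t = simulate_trace(100_000, alpha=0.8, gamma=gamma, c=0.05)
        print(f"gamma={gamma}: Mp calls={mp}, Mq calls={mq}, "
              f"speedup vs base={produced / t:.2f}X")
```

For long runs the simulated speedup should approach the Theorem 3.8 factor for the same α, γ, and c, since the baseline needs one Mp call per token.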


copies tokens from the context in case we find a matching prefix, might yield high values of α.
These parameter-less approximation models have the additional advantage of being even simpler
to deploy from a production standpoint.

Another type of approximation models that can be used by speculative decoding are
non-autoregressive models, like those of Stern et al. (2018). Then, instead of the
autoregressive loop in Algorithm 1 we’d just call the non-autoregressive model once.

A final example, interesting mostly from a theoretical perspective, is an approximation model
which chooses tokens at random, which guarantees some improvement (although very small) for all
models Mp.

4. Experiments

4.1. Empirical Walltime Improvement

We implement our algorithm and compare it to the implementation in the T5X codebase for
accelerating T5-XXL.

Setup. We test a standard encoder-decoder T5 version 1.1 model (Raffel et al., 2020) on two
tasks from the T5 paper: (1) English to German translation fine-tuned on WMT EnDe, and (2) text
summarization fine-tuned on CNN/DM. For both tasks, we use T5-XXL (11B) for Mp. For the
approximation model Mq we test several existing configurations, namely T5-large (800M), T5-base
(250M), and T5-small (77M) (Raffel et al., 2020). We use existing checkpoints for all models.
We measure walltime improvements with a batch size of 1 on a single TPU-v4 for both argmax
sampling (temp=0) and standard sampling (temp=1).

Results. Table 2 shows the empirical results from our method. We see that T5-small (77M), with
a good balance of c and α, provides the highest speedup out of the tested approximation models.
As expected we see that α increases with the size of the approximation model. Interestingly, α
and walltime improvement are higher for argmax sampling (temp=0). We observe speedups of 2.6X
(temp=1) and 3.4X (temp=0) on the translation task and slightly lower speedups of 2.3X (temp=1)
and 3.1X (temp=0) for the summarization task. These empirical results match well with the
theoretical predictions, with some variance due to implementation details (see Appendix A.3).

Table 2. Empirical results for speeding up inference from a T5-XXL 11B model.

    TASK     Mq            TEMP    γ    α       SPEED
    EnDe     T5-small ⋆    0       7    0.75    3.4X
    EnDe     T5-base       0       7    0.8     2.8X
    EnDe     T5-large      0       7    0.82    1.7X
    EnDe     T5-small ⋆    1       7    0.62    2.6X
    EnDe     T5-base       1       5    0.68    2.4X
    EnDe     T5-large      1       3    0.71    1.4X

    CNNDM    T5-small ⋆    0       5    0.65    3.1X
    CNNDM    T5-base       0       5    0.73    3.0X
    CNNDM    T5-large      0       3    0.74    2.2X
    CNNDM    T5-small ⋆    1       5    0.53    2.3X
    CNNDM    T5-base       1       3    0.55    2.2X
    CNNDM    T5-large      1       3    0.56    1.7X

4.2. Empirical α Values

While we only implemented our method for T5, we measured α values for various tasks, sampling
methods, target models Mp, and approximation models Mq. Specifically, we evaluated the
expectation from Corollary 3.6 on 10K tokens generated by Mp, for each of the settings below.

GPT-like (97M params). We test a decoder-only Transformer model on unconditional language
generation, trained on lm1b (Chelba et al., 2013). The model here is a GPT-like Transformer
decoder with GELU activations (Hendrycks & Gimpel, 2016). For Mq we experimented with a
Transformer decoder model with 6M parameters: dim 256, dim feed-forward 1024, 2 layers, 4
attention heads, as well as simple unigram and bigram models. Mp has 97M parameters: dim 768,
dim feed-forward 3072, 12 layers, 12 attention heads. We used BERT tokenization (Devlin et al.,
2019) with 8k tokens for all models.

LaMDA (137B params). We tested a decoder-only LaMDA model on a dialog task (Thoppilan et al.,
2022). We used existing checkpoints from LaMDA 137B as Mp and LaMDA 8B, LaMDA 2B, and LaMDA
100M for Mq.

See Section 4.1 for the setup of the T5-XXL (11B params) model.

Table 3 summarizes the α values for the tested cases. We observe that approximation models that
are a couple of orders of magnitude smaller than the target model tend to produce α values
between 0.5 and 0.9. Interestingly, we also note that for all models, the sharper the adjusted
distribution, the higher the α values. Finally, we note that even trivial unigram and bigram
approximations yield non-negligible α values. For example, for the case of English to German
translation, the bigram model has an α value of 0.2, and since c = 0 in this case, yields a
1.25X speed improvement, which is surprisingly high for this trivial approximation model (but
is still lower than the speedup we get from using T5-small as the approximation model).

Table 3. Empirical α values for various target models Mp, approximation models Mq, and sampling
settings. T=0 and T=1 denote argmax and standard sampling respectively⁶.

    Mp                Mq               SMPL    α
    GPT-like (97M)    Unigram          T=0     0.03
    GPT-like (97M)    Bigram           T=0     0.05
    GPT-like (97M)    GPT-like (6M)    T=0     0.88
    GPT-like (97M)    Unigram          T=1     0.03
    GPT-like (97M)    Bigram           T=1     0.05
    GPT-like (97M)    GPT-like (6M)    T=1     0.89

    T5-XXL (EnDe)     Unigram          T=0     0.08
    T5-XXL (EnDe)     Bigram           T=0     0.20
    T5-XXL (EnDe)     T5-small         T=0     0.75
    T5-XXL (EnDe)     T5-base          T=0     0.80
    T5-XXL (EnDe)     T5-large         T=0     0.82
    T5-XXL (EnDe)     Unigram          T=1     0.07
    T5-XXL (EnDe)     Bigram           T=1     0.19
    T5-XXL (EnDe)     T5-small         T=1     0.62
    T5-XXL (EnDe)     T5-base          T=1     0.68
    T5-XXL (EnDe)     T5-large         T=1     0.71

    T5-XXL (CNNDM)    Unigram          T=0     0.13
    T5-XXL (CNNDM)    Bigram           T=0     0.23
    T5-XXL (CNNDM)    T5-small         T=0     0.65
    T5-XXL (CNNDM)    T5-base          T=0     0.73
    T5-XXL (CNNDM)    T5-large         T=0     0.74
    T5-XXL (CNNDM)    Unigram          T=1     0.08
    T5-XXL (CNNDM)    Bigram           T=1     0.16
    T5-XXL (CNNDM)    T5-small         T=1     0.53
    T5-XXL (CNNDM)    T5-base          T=1     0.55
    T5-XXL (CNNDM)    T5-large         T=1     0.56

    LaMDA (137B)      LaMDA (100M)     T=0     0.61
    LaMDA (137B)      LaMDA (2B)       T=0     0.71
    LaMDA (137B)      LaMDA (8B)       T=0     0.75
    LaMDA (137B)      LaMDA (100M)     T=1     0.57
    LaMDA (137B)      LaMDA (2B)       T=1     0.71
    LaMDA (137B)      LaMDA (8B)       T=1     0.74

    ⁶ Note that the outputs from the LaMDA model always go through a Top40 filter. This has no
    effect on argmax, but does have some effect on standard sampling.

5. Related work

The efficiency of inference from large models was studied extensively (Dehghani et al., 2021).
Many approaches aim to speed up inference from large models in general, and autoregressive
models like Transformers in particular. Numerous techniques try to make inference more
efficient for all tokens, e.g. distillation (Hinton et al., 2015), sparsification (Jaszczur et
al., 2021), quantization (Hubara et al., 2016), and architecture modification (So et al., 2021;
Shazeer, 2019). Closer to our approach are adaptive computation methods which adapt the amount
of computation to problem difficulty (Han et al., 2021). Examples include attending to a subset
of the inputs (Sukhbaatar et al., 2019), and early exits (Schuster et al., 2021; Scardapane et
al., 2020; Bapna et al., 2020; Elbayad et al., 2019; Schwartz et al., 2020). Notably, Wisdom of
Committees (Schwartz et al., 2020) leverages off-the-shelf smaller models, but is an adaptive
computation approach, and so it uses a heuristic to determine when to stop, losing the
guarantee of identical outputs to those of the target models. In general, adaptive computation
methods usually learn, either within the model itself or with an auxiliary model, when a
computation shortcut can be taken. Usually, these methods save on both inference time and
arithmetic operations, but require a change of architecture, a change of training procedure and
training custom models or re-training of existing models. They usually also change the outputs
of the model. We note that while many of the methods above improve the memory to
arithmetic-operations ratio, in cases where the ratio remains high, these methods and our
speculative decoding method might be effective in tandem.

Two prior methods leverage speculative execution for speeding up decoding from autoregressive
models. Blockwise Parallel Decoding (Stern et al., 2018) decodes several tokens in parallel,
similarly to our work. However, it only supports greedy decoding (temperature=0) and not the
general stochastic setting, it requires additional training of a custom model, and focuses on
preserving down-stream task quality, instead of guaranteeing identical outputs. Shallow
Aggressive Decoding (SAD) (Sun et al., 2021) also decodes several tokens in parallel, similarly
to our work. Unlike our work, SAD only supports copying the input to the output, and not
general approximation models, making it only suitable for the cases where the inputs and
outputs are very similar like grammatical error correction. In addition, similarly to Blockwise
Parallel Decoding, SAD does not support the general stochastic sampling setting.

After we initially published our work, an independent implementation of speculative decoding
(Chen et al., 2023) showed similar 2X-2.5X improvements on Chinchilla 70B.

6. Discussion

We presented speculative sampling which enables efficient stochastic speculative execution -
i.e. speculative execution in the stochastic setting. We analyzed its impact on decoding from
autoregressive models like Transformers via speculative decoding and have shown that given
enough compute resources, we get meaningful 2X-3X speedups in practice vs T5X, a popular
optimized implementation.

One limitation of speculative execution in general, and of speculative decoding in particular,
is that latency is improved through increased concurrency at the cost of an increased number of
arithmetic operations. Thus, our method is not helpful for configurations where additional
computation resources are not available. However, in common cases where additional computation
resources are available (e.g. when memory bandwidth is the bottleneck) our method provides the
speedup with significant benefits: the model architecture doesn’t change, retraining isn’t
required, and most importantly, the output distribution is guaranteed to stay the same. Our
method is easy to implement, and can be used to speed up inference using out-of-the-box models
without developing and evaluating custom schemes.

There are several directions for follow-up research, importantly, further investigating the
compatibility of speculative decoding with beam search (see Appendix A.4). Also, while our
method yields substantial speedups with existing off-the-shelf approximation models, greater
improvements might be obtained via custom approximation models (Section 3.6), such as those
with custom architectures (e.g. custom sizes, non-autoregressive models, or various heuristics)
or with custom training procedures (e.g. standard distillation with soft targets from Mp, or
optimizing Mq for α directly). It could also be interesting to explore a hierarchical version
of the algorithm, where the approximation model is itself accelerated by an even faster model,
which could allow for more capable approximation models. In this work we fixed the
approximation model and the number of guesses γ throughout inference, but varying them during
inference could yield additional improvements (Section 3.5). In our experiments we always
performed the same standardization on the distributions generated by the approximation model as
the desired one for the target model (Section 2.2), but further improvements might be obtained
by applying different transformations. We tested speculative decoding only in the text
modality, but it might work well in other domains (e.g. images) which would be interesting to
experiment with.

Finally, we note that stochastic speculative execution and speculative sampling can be helpful
outside the scope of speculative decoding from autoregressive models. For example, given two
slow functions, f(x) and g(y) such that f(x) generates a distribution from which g’s input is
sampled, we could use our method to run f and g in parallel. This setup might arise e.g. in
physics simulations, or in reinforcement learning where f is a large model that produces a
distribution on actions, and g is the world simulation, which would be interesting to explore.

Acknowledgments

We would like to extend a special thank you to YaGuang Li for help with everything LaMDA
related and for calculating the LaMDA figures in the paper, and to Blake Hechtman for great
insights and help with XLA. We would also like to thank the reviewers for insightful comments,
as well as Asaf Aharoni, Reiner Pope, Sasha Goldshtein, Nadav Sherman, Eyal Segalis, Eyal Molad,
Dani Valevski, Daniel Wasserman, Valerie Nygaard, Danny Vainstein, the LaMDA and Theta Labs
teams at Google, and our families.

References

Bapna, A., Arivazhagan, N., and Firat, O. Controlling computation versus quality for neural
  sequence models. ArXiv, abs/2002.07106, 2020.

Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A.,
  Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T.,
  Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E.,
  Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A.,
  Sutskever, I., and Amodei, D. Language models are few-shot learners. In Proceedings of the
  34th International Conference on Neural Information Processing Systems, NIPS’20, Red Hook,
  NY, USA, 2020. Curran Associates Inc. ISBN 9781713829546.

Burton, F. W. Speculative computation, parallelism, and functional programming. IEEE
  Transactions on Computers, C-34(12):1190–1193, 1985. doi: 10.1109/TC.1985.6312218.
Chelba, C., Mikolov, T., Schuster, M., Ge, Q., Brants, T., Koehn, P. T., and Robinson, T. One
  billion word benchmark for measuring progress in statistical language modeling. In
  Interspeech, 2013.

Chen, C., Borgeaud, S., Irving, G., Lespiau, J.-B., Sifre, L., and Jumper, J. M. Accelerating
  large language model decoding with speculative sampling. ArXiv, abs/2302.01318, 2023.

Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung,
  H. W., Sutton, C., Gehrmann, S., Schuh, P., Shi, K., Tsvyashchenko, S., Maynez, J., Rao, A.,
  Barnes, P., Tay, Y., Shazeer, N. M., Prabhakaran, V., Reif, E., Du, N., Hutchinson, B. C.,
  Pope, R., Bradbury, J., Austin, J., Isard, M., Gur-Ari, G., Yin, P., Duke, T., Levskaya, A.,
  Ghemawat, S., Dev, S., Michalewski, H., García, X., Misra, V., Robinson, K., Fedus, L., Zhou,
  D., Ippolito, D., Luan, D., Lim, H., Zoph, B., Spiridonov, A., Sepassi, R., Dohan, D.,
  Agrawal, S., Omernick, M., Dai, A. M., Pillai, T. S., Pellat, M., Lewkowycz, A., Moreira, E.,
  Child, R., Polozov, O., Lee, K., Zhou, Z., Wang, X., Saeta, B., Díaz, M., Firat, O., Catasta,
  M., Wei, J., Meier-Hellstern, K. S., Eck, D., Dean, J., Petrov, S., and Fiedel, N. Palm:
  Scaling language modeling with pathways. ArXiv, abs/2204.02311, 2022.

Dehghani, M., Arnab, A., Beyer, L., Vaswani, A., and Tay, Y. The efficiency misnomer. ArXiv,
  abs/2110.12894, 2021.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. Bert: Pre-training of deep bidirectional
  transformers for language understanding. ArXiv, abs/1810.04805, 2019.

Elbayad, M., Gu, J., Grave, E., and Auli, M. Depth-adaptive transformer. ArXiv,
  abs/1910.10073, 2019.

Han, Y., Huang, G., Song, S., Yang, L., Wang, H., and Wang, Y. Dynamic neural networks: A
  survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 44:7436–7456, 2021.

Hendrycks, D. and Gimpel, K. Bridging nonlinearities and stochastic regularizers with gaussian
  error linear units. ArXiv, abs/1606.08415, 2016.

Hennessy, J. L. and Patterson, D. A. Computer Architecture: A Quantitative Approach. Morgan
  Kaufmann, Amsterdam, 5 edition, 2012. ISBN 978-0-12-383872-8.

Hinton, G. E., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. ArXiv,
  abs/1503.02531, 2015.

Hubara, I., Courbariaux, M., Soudry, D., El-Yaniv, R., and Bengio, Y. Quantized neural
  networks: Training neural networks with low precision weights and activations. ArXiv,
  abs/1609.07061, 2016.

Jaszczur, S., Chowdhery, A., Mohiuddin, A., Kaiser, L., Gajewski, W., Michalewski, H., and
  Kanerva, J. Sparse is enough in scaling transformers. In Neural Information Processing
  Systems, 2021.

Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and
  Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text transformer.
  The Journal of Machine Learning Research, 21(1):5485–5551, 2020.

Roberts, A., Chung, H. W., Levskaya, A., Mishra, G., Bradbury, J., Andor, D., Narang, S.,
  Lester, B., Gaffney, C., Mohiuddin, A., Hawthorne, C., Lewkowycz, A., Salcianu, A., van Zee,
  M., Austin, J., Goodman, S., Soares, L. B., Hu, H., Tsvyashchenko, S., Chowdhery, A.,
  Bastings, J., Bulian, J., García, X., Ni, J., Chen, A., Kenealy, K., Clark, J., Lee, S.,
  Garrette, D. H., Lee-Thorp, J., Raffel, C., Shazeer, N. M., Ritter, M., Bosma, M., Passos, A.,
  Maitin-Shepard, J. B., Fiedel, N., Omernick, M., Saeta, B., Sepassi, R., Spiridonov, A.,
  Newlan, J., and Gesmundo, A. Scaling up models and data with t5x and seqio. ArXiv,
  abs/2203.17189, 2022.

Scardapane, S., Scarpiniti, M., Baccarelli, E., and Uncini, A. Why should we add early exits to
  neural networks? Cognitive Computation, 12(5):954–966, 2020.

Schuster, T., Fisch, A., Jaakkola, T., and Barzilay, R. Consistent accelerated inference via
  confident adaptive transformers. In Conference on Empirical Methods in Natural Language
  Processing, 2021.

Schwartz, R., Stanovsky, G., Swayamdipta, S., Dodge, J., and Smith, N. A. The right tool for
  the job: Matching model and instance complexities. In Annual Meeting of the Association for
  Computational Linguistics, 2020.

Shazeer, N. M. Fast transformer decoding: One write-head is all you need. ArXiv,
  abs/1911.02150, 2019.

So, D. R., Mańke, W., Liu, H., Dai, Z., Shazeer, N. M., and Le, Q. V. Primer: Searching for
  efficient transformers for language modeling. ArXiv, abs/2109.08668, 2021.

Stern, M., Shazeer, N., and Uszkoreit, J. Blockwise parallel decoding for deep autoregressive
  models. Advances in Neural Information Processing Systems, 31, 2018.

Sukhbaatar, S., Grave, E., Bojanowski, P., and Joulin, A. Adaptive attention span in
  transformers. In Annual Meeting of the Association for Computational Linguistics, 2019.

Sun, X., Ge, T., Wei, F., and Wang, H. Instantaneous grammatical error correction with shallow
  aggressive decoding. ArXiv, abs/2106.04970, 2021.


Thoppilan, R., Freitas, D. D., Hall, J., Shazeer, N. M., Kulshreshtha, A., Cheng, H.-T.,
  Jin, A., Bos, T., Baker, L., Du, Y., Li, Y., Lee, H., Zheng, H., Ghafouri, A.,
  Menegali, M., Huang, Y., Krikun, M., Lepikhin, D., Qin, J., Chen, D., Xu, Y., Chen, Z.,
  Roberts, A., Bosma, M., Zhou, Y., Chang, C.-C., Krivokon, I. A., Rusch, W. J.,
  Pickett, M., Meier-Hellstern, K. S., Morris, M. R., Doshi, T., Santos, R. D., Duke, T.,
  Søraker, J. H., Zevenbergen, B., Prabhakaran, V., Díaz, M., Hutchinson, B., Olson, K.,
  Molina, A., Hoffman-John, E., Lee, J., Aroyo, L., Rajakumar, R., Butryna, A., Lamm, M.,
  Kuzmina, V. O., Fenton, J., Cohen, A., Bernstein, R., Kurzweil, R., Aguera-Arcas, B.,
  Cui, C., Croak, M., hsin Chi, E. H., and Le, Q. Lamda: Language models for dialog
  applications. ArXiv, abs/2201.08239, 2022.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł.,
  and Polosukhin, I. Attention is all you need. Advances in Neural Information
  Processing Systems, 30, 2017.

Yu, J., Xu, Y., Koh, J. Y., Luong, T., Baid, G., Wang, Z., Vasudevan, V., Ku, A., Yang, Y.,
  Ayan, B. K., Hutchinson, B. C., Han, W., Parekh, Z., Li, X., Zhang, H., Baldridge, J.,
  and Wu, Y. Scaling autoregressive models for content-rich text-to-image generation.
  ArXiv, abs/2206.10789, 2022.





A. Appendix
A.1. Correctness of Speculative Sampling
We will now show that for any distributions p(x) and q(x), the tokens sampled via speculative sampling from p(x) and q(x)
are distributed identically to those sampled from p(x) alone. Let β be the acceptance probability (Definition 3.1).
Note that as

$$p'(x) = \mathrm{norm}(\max(0, p(x) - q(x))) = \frac{p(x) - \min(q(x), p(x))}{\sum_{x'} \left(p(x') - \min(q(x'), p(x'))\right)} = \frac{p(x) - \min(q(x), p(x))}{1 - \beta},$$

the normalizing constant for the adjusted distribution $p'(x)$ is $1 - \beta$, where the last equation follows immediately from Lemma 3.3 and
Theorem 3.5.
Now:


$$P(x = x') = P(\text{guess accepted}, x = x') + P(\text{guess rejected}, x = x')$$

Where:

$$P(\text{guess accepted}, x = x') = q(x') \min\left(1, \frac{p(x')}{q(x')}\right) = \min(q(x'), p(x'))$$

And:

$$P(\text{guess rejected}, x = x') = (1 - \beta)\, p'(x') = p(x') - \min(q(x'), p(x'))$$

Overall:

$$P(x = x') = \min(p(x'), q(x')) + p(x') - \min(p(x'), q(x')) = p(x').$$

As desired. 
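
The argument above can be checked numerically on a small vocabulary. The following is a minimal sketch, not the paper's implementation: it assumes p and q are plain Python lists of probabilities over the same tokens, and the function name `speculative_output_distribution` is hypothetical.

```python
def speculative_output_distribution(p, q):
    """Exact output distribution of one speculative sampling step.

    p, q: lists of probabilities over the same vocabulary (each sums to 1).
    Names and data layout are illustrative assumptions, not from the paper.
    """
    # P(guess accepted, x = x') = q(x') * min(1, p(x')/q(x')) = min(q(x'), p(x'))
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    beta = sum(accepted)  # overall acceptance probability
    # Unnormalized adjusted distribution: p(x) - min(q(x), p(x)) = max(0, p(x) - q(x))
    residual = [pi - ai for pi, ai in zip(p, accepted)]
    one_minus_beta = 1.0 - beta  # normalizing constant claimed in the proof
    if abs(one_minus_beta) < 1e-12:
        # p and q coincide, so a guess is never rejected
        p_adjusted = [0.0] * len(p)
    else:
        p_adjusted = [r / one_minus_beta for r in residual]
    # P(guess rejected, x = x') = (1 - beta) * p'(x')
    rejected = [one_minus_beta * pa for pa in p_adjusted]
    return [a + r for a, r in zip(accepted, rejected)]


p = [0.5, 0.3, 0.2]
q = [0.2, 0.2, 0.6]
out = speculative_output_distribution(p, q)
print(all(abs(o - pi) < 1e-12 for o, pi in zip(out, p)))
```

The accepted and rejected contributions are kept separate on purpose, so each intermediate list corresponds to one line of the derivation.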

A.2. Speculative Sampling vs. Rejection Sampling
Rejection sampling is the following iterative sampling procedure that looks superficially similar to ours:

  1. Sample x ∼ q(x) and r ∼ U (0, 1).

  2. If $r < \frac{p(x)}{M q(x)}$, return x.

  3. Go to 1.

Where $M = \max_x \frac{p(x)}{q(x)}$. We could employ a non-iterative version of rejection sampling instead of speculative sampling:
specifically, go through steps 1 and 2 above once and, if the sample is rejected, sample from an unmodified p(x) directly. That would
be much less efficient than our method, though. Specifically, the expected acceptance probability here is

$$E_{x \sim q(x)} \frac{p(x)}{M q(x)} = \sum_x p(x) \min_{x'} \frac{q(x')}{p(x')} \le \sum_x p(x) \min\left(1, \frac{q(x)}{p(x)}\right) = \sum_x \min(p(x), q(x)) = \alpha,$$

which is (potentially much) lower than α, the expected acceptance probability in our method.
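
To make the gap concrete, the two acceptance probabilities can be compared on a toy example. This is an illustrative sketch under the assumption that q(x) > 0 wherever p(x) > 0 (so that M is finite); the function names are hypothetical.

```python
def rejection_accept_probability(p, q):
    """Expected acceptance of one rejection sampling round: sum_x p(x)/M = 1/M."""
    # M = max_x p(x)/q(x); assumes q(x) > 0 for every token (illustrative assumption)
    m = max(pi / qi for pi, qi in zip(p, q))
    return 1.0 / m


def speculative_accept_probability(p, q):
    """Expected acceptance of speculative sampling: alpha = sum_x min(p(x), q(x))."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


p = [0.5, 0.3, 0.2]
q = [0.2, 0.2, 0.6]
rej = rejection_accept_probability(p, q)    # M = 2.5, so 1/M = 0.4
spec = speculative_accept_probability(p, q)  # 0.2 + 0.2 + 0.2 = 0.6
print(rej, spec)
```

Rejection sampling is limited by the single worst ratio p(x)/q(x), while speculative sampling only loses the mass where q(x) exceeds p(x) token by token.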

A.3. Theoretical Predictions vs. Empirical Runtimes
Table 4 compares the expected runtime improvements based on Theorem 3.8 to the empirically measured runtimes from
Table 2. We estimated the values of c for the various models based on profiler traces. We can see that the theoretical
predictions mostly match the measured runtimes. The larger differences are due to: (1) optimization differences between our
implementation and the baseline, and (2) the simplifying assumption that the βs are i.i.d. being only an approximation (see
Section 3.1).



Table 4. Expected improvement factor (EXP) vs. empirically measured improvement factor (EMP).

    TASK     Mq          TEMP    γ    α       c       EXP    EMP
    EnDe     T5-small    0       7    0.75    0.02    3.2    3.4
    EnDe     T5-base     0       7    0.80    0.04    3.3    2.8
    EnDe     T5-large    0       7    0.82    0.11    2.5    1.7
    EnDe     T5-small    1       7    0.62    0.02    2.3    2.6
    EnDe     T5-base     1       5    0.68    0.04    2.4    2.4
    EnDe     T5-large    1       3    0.71    0.11    2.0    1.4
    CNNDM    T5-small    0       5    0.65    0.02    2.4    3.1
    CNNDM    T5-base     0       5    0.73    0.04    2.6    3.0
    CNNDM    T5-large    0       3    0.74    0.11    2.0    2.2
    CNNDM    T5-small    1       5    0.53    0.02    1.9    2.3
    CNNDM    T5-base     1       3    0.55    0.04    1.8    2.2
    CNNDM    T5-large    1       3    0.56    0.11    1.6    1.7
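
The EXP column can be approximately reproduced from α, γ, and c. The sketch below assumes the closed form of Theorem 3.8, (1 − α^(γ+1)) / ((1 − α)(γc + 1)), which is stated in the main text rather than in this appendix; the function name and the choice of rows are illustrative. Because the table reports α rounded to two decimals, the recomputed values may differ from EXP in the last digit.

```python
def expected_improvement_factor(alpha, gamma, c):
    """Expected walltime improvement, assuming the closed form of Theorem 3.8.

    alpha: expected acceptance rate, gamma: number of guessed tokens,
    c: cost of one run of Mq relative to one run of Mp.
    """
    return (1.0 - alpha ** (gamma + 1)) / ((1.0 - alpha) * (gamma * c + 1.0))


# (task, Mq, gamma, alpha, c, EXP as reported in Table 4)
rows = [
    ("EnDe", "T5-small", 7, 0.75, 0.02, 3.2),
    ("EnDe", "T5-large", 7, 0.82, 0.11, 2.5),
    ("CNNDM", "T5-small", 5, 0.65, 0.02, 2.4),
]
for task, mq, gamma, alpha, c, reported in rows:
    computed = expected_improvement_factor(alpha, gamma, c)
    print(f"{task} {mq}: computed {computed:.2f}, reported {reported}")
```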



A.4. Application to Beam Search
Our method can be applied, with some performance penalty, to beam search sampling. Given the original beam width w, we
can perform beam search with the approximation model Mq and beam width u ≥ w for γ steps. Then, we can use Mp to
check all of the candidates in parallel (costing a compute budget of (w + uγ) runs of Mp ). Finally, for each step, we can
accept the guesses of Mq as long as topw (Mp ) ⊆ topu (Mq ) to get identical results to regular beam search with Mp alone
(with a more elaborate procedure we could also accept cases where the candidates we got happen to have higher probabilities
than those of Mp alone). The analysis of our method in this setting is more involved and we leave it for future work.
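
The per-step acceptance condition can be sketched as a set-containment check. This is only an illustration of the condition top_w(Mp) ⊆ top_u(Mq), not a full beam search: it assumes the candidate scores of both models at a given step are already available as dictionaries, and all names are hypothetical.

```python
def top_k(scores, k):
    """Return the set of the k highest-scoring candidates in a {candidate: score} dict."""
    return set(sorted(scores, key=scores.get, reverse=True)[:k])


def can_accept_step(scores_p, scores_q, w, u):
    """True if the top-w candidates under Mp are all among the top-u candidates under Mq."""
    return top_k(scores_p, w) <= top_k(scores_q, u)


# Toy log-probabilities for four candidate continuations (made-up values).
scores_p = {"a": -0.1, "b": -1.0, "c": -2.0, "d": -3.0}
scores_q = {"a": -0.5, "c": -0.7, "b": -0.9, "d": -4.0}

print(can_accept_step(scores_p, scores_q, w=2, u=3))  # {a, b} vs {a, c, b}
print(can_accept_step(scores_p, scores_q, w=2, u=2))  # {a, b} vs {a, c}
```

A wider draft beam u makes the containment more likely to hold, at the cost of more candidates for Mp to check.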

A.5. Lenience
A strong property of Algorithm 1 is that the output distribution is guaranteed to remain unchanged. That said, if we’re
willing to allow some changes, with nice guarantees, we can get further inference speed improvements. To further motivate
this, note that when we train two models with identical architectures and sizes on the same dataset, the generated probability
distributions will not be identical, so some lenience might make sense. Note that the results in this paper except for this
section use the strictest version of Algorithm 1 and don’t allow lenience of any kind.
We could include a lenience parameter l ∈ [0, 1] and multiply q(x) by l before comparing with p(x) in Algorithm 1. This
still maintains the nice guarantee that no token can be sampled with probability greater than p(x)/l. This means, for example,
that with l = 1/10 no token can be sampled with more than 10X its ground truth probability, so we can guarantee that
extremely rare tokens will remain extremely rare (there is no guarantee on the minimum probability, so lenience could hurt
the diversity of the samples).
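Specifically, with a lenience factor l we have

$$\alpha = E_{x \sim q(x)} \begin{cases} 1 & l\,q(x) \le p(x) \\ \frac{p(x)}{l\,q(x)} & l\,q(x) > p(x) \end{cases} = E_{x \sim q(x)} \frac{p(x)}{\max(p(x), l\,q(x))} = \sum_x \frac{p(x)\,q(x)}{\max(p(x), l\,q(x))} = \frac{1}{l} \sum_x \min(p(x), l\,q(x)) = \sum_x \min\left(\frac{p(x)}{l}, q(x)\right).$$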
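
The first and last forms of this expression can be compared on a toy example. The sketch below is illustrative only: it assumes p and q are lists of strictly positive probabilities and l > 0, and the function names are hypothetical.

```python
def alpha_with_lenience(p, q, l):
    """Closed form: alpha = sum_x min(p(x)/l, q(x))."""
    return sum(min(pi / l, qi) for pi, qi in zip(p, q))


def alpha_as_expectation(p, q, l):
    """Same quantity as E_{x~q}[min(1, p(x) / (l*q(x)))]; assumes q(x) > 0."""
    return sum(qi * min(1.0, pi / (l * qi)) for pi, qi in zip(p, q))


p = [0.5, 0.3, 0.2]
q = [0.2, 0.2, 0.6]
for l in (1.0, 0.5):
    print(l, alpha_with_lenience(p, q, l), alpha_as_expectation(p, q, l))
```

With l = 1 this reduces to the strict acceptance rate Σ min(p(x), q(x)); smaller l can only increase it.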

Table 5 shows α values for different values of l when Mp is T5-XXL (11B) and Mq is T5-small (77M). With c = 0.015,
using lenience values of 1, 0.5, 0.3, and 0.1 (meaning that no token can be sampled with probability greater than 1X, 2X, 3X
and 10X of the ground truth) we get improvement factors of 2.5X, 3.1X, 3.6X, and 5X respectively.


Table 5. α values for various values of l with standard sampling where Mp is T5-XXL (11B) on the EnDe translation task.

    Mq                 l = 1    l = 0.5    l = 0.3    l = 0.1
    Unigram            0.07     0.10       0.11       0.16
    Bigram             0.19     0.23       0.25       0.32
    T5-small (77M)     0.62     0.71       0.76       0.84
    T5-base (250M)     0.68     0.80       0.83       0.90



Note that when using temperature = 0 (i.e. argmax sampling), we can no longer use lenience as above. Instead, we could
allow some lenience before standardizing the distributions. For example, we could accept the token x sampled from Mq in
case p(x) ≥ l · max(p), so that l = 1 accepts only the argmax of p. In this case, we measure similar empirical increases in α values to those with temperature = 1. For
example, when using lenience values of 1, 0.5, 0.3, and 0.1 with Mp T5-XXL and Mq T5-small for English-German translation,
we get α values of 0.75, 0.75, 0.8, and 0.87. Taking for example c = 0.015 and γ = 8, we get speed improvement factors of
3.3X, 3.3X, 3.9X, and 4.9X respectively. In this case, unlike in the standard sampling case shown in Table 5, a lenience
factor of 0.5 doesn't improve the speed-up.
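
The argmax case can be sketched as follows. This is an illustration under two assumptions: the acceptance condition is read as p(x) ≥ l · max(p), so that l = 1 reduces to exact argmax matching, and the speed factors use the closed form of Theorem 3.8 from the main text. All function names are hypothetical.

```python
def accept_greedy_with_lenience(p, x, l):
    """Accept draft token index x if p[x] is within a factor l of the largest p."""
    return p[x] >= l * max(p)


def expected_improvement_factor(alpha, gamma, c):
    """Closed form assumed from Theorem 3.8 of the main text."""
    return (1.0 - alpha ** (gamma + 1)) / ((1.0 - alpha) * (gamma * c + 1.0))


p = [0.5, 0.3, 0.2]
print(accept_greedy_with_lenience(p, 1, 1.0))  # 0.3 >= 0.5 fails
print(accept_greedy_with_lenience(p, 1, 0.5))  # 0.3 >= 0.25 holds

# alpha values reported in the text for l = 1, 0.5, 0.3, 0.1, with c = 0.015 and gamma = 8
for alpha in (0.75, 0.75, 0.8, 0.87):
    print(round(expected_improvement_factor(alpha, 8, 0.015), 1))
```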