Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
Metadata
- Slug: switch_transformer_2021
- Year: 2021
- Venue: JMLR
- Authors: William Fedus, Barret Zoph, Noam Shazeer
- Reading status: read complete
- Compute regime: Sparse and memory-efficient scaling
- Primary sources: PDF, extracted text
Compute Setup
The paper explicitly targets TPU architectures. Its main benchmark table states that all MoE and Switch models are trained with the same amount of computation on the same hardware, 32 TPUv3 cores. Figure 5 likewise states that all models in the speed comparison are trained on 32 TPUv3 cores with equal FLOPs per example.
For trillion-scale models, the extracted text gives the model sizes and parallelism design but not a clear exact device count. Under the project rule, the large-run hardware is inferred as TPU-v3-style distributed training consistent with the rest of the paper and its Mesh TensorFlow implementation. The paper says Switch models combine data, model, and expert parallelism, and recommends one expert per core for smaller expert-count regimes, but the exact core count for Switch-XXL and Switch-C is not stated in the extracted text.
Bottleneck
Prior MoE systems were difficult to adopt because of routing complexity, communication cost, and training instability. Sparse routing creates expert overflow, all-to-all communication, and bfloat16 numerical issues. The paper frames Switch as a response to three blockers: model complexity, training difficulty, and communication costs.
The device bottleneck is fixed-shape accelerator execution combined with dynamic routing. TPU compilation requires statically sized tensors, but the router sends a data-dependent number of tokens to each expert. Switch therefore fixes each expert's capacity: expert capacity = (tokens per batch / number of experts) × capacity factor. If more tokens choose an expert than its capacity allows, the overflow tokens skip the expert computation and pass to the next layer through the residual connection. Raising the capacity factor reduces dropped tokens but increases computation and all-to-all communication.
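The capacity rule and its overflow behavior can be sketched in plain Python. This is a minimal illustration, not the paper's Mesh TensorFlow code: the helper names `expert_capacity` and `dispatch_with_capacity` are hypothetical, and both rounding capacity up with `ceil` and admitting tokens in batch order are assumptions of the sketch.

```python
import math
from collections import Counter


def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    # Fixed buffer size per expert: (tokens / experts) * capacity factor.
    # Rounding up with ceil is an assumption of this sketch.
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)


def dispatch_with_capacity(assignments, num_experts, capacity_factor):
    """Split token indices into (kept, dropped) given each token's top-1 expert.

    Tokens are admitted in batch order (an assumption); once an expert's
    buffer is full, later tokens for that expert overflow and skip the
    expert computation, passing on only through the residual connection.
    """
    capacity = expert_capacity(len(assignments), num_experts, capacity_factor)
    load = Counter()
    kept, dropped = [], []
    for token, expert in enumerate(assignments):
        if load[expert] < capacity:
            load[expert] += 1
            kept.append(token)
        else:
            dropped.append(token)
    return kept, dropped
```

With 8 tokens, 4 experts, and capacity factor 1.0, each expert holds 2 tokens, so `dispatch_with_capacity([0, 0, 0, 0, 1, 1, 2, 3], 4, 1.0)` keeps tokens 0, 1, 4, 5, 6, 7 and drops tokens 2 and 3. Raising the capacity factor to 2.0 keeps all eight, at the cost of larger buffers that are mostly empty for the less popular experts.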
Method Adaptation
Switch Transformer simplifies sparse routing:
- Use top-1 routing instead of top-2 routing.
- Send each token to one expert.
- Use fixed expert capacity and an auxiliary load-balancing loss.
- Cast router inputs and softmax to float32, then recast dispatch/combine to bfloat16.
- Keep the dense Transformer structure while replacing feed-forward blocks with sparse expert blocks.
The central adaptation is top-1 routing. Earlier MoE Transformers often used top-2 routing, sending each token to two experts; Switch sends each token only to its highest-probability expert. The paper lists three benefits: less router computation, an expert capacity that can be roughly halved because each token is routed only once, and a simpler implementation with less communication. The router stays differentiable because the selected expert's output is scaled by that expert's gate probability, and an auxiliary load-balancing loss keeps tokens from collapsing onto a few experts.
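Top-1 routing and the balancing term can be sketched in plain Python. The loss follows the form described in the Switch paper, alpha * N * sum_i f_i * P_i, where N is the number of experts, f_i is the fraction of tokens dispatched to expert i, and P_i is the mean router probability for expert i. The function names and the default alpha = 0.01 are assumptions of this sketch.

```python
import math


def softmax(logits):
    # Numerically stable softmax over one token's router logits.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def switch_route(router_logits):
    """Top-1 routing for a batch of tokens.

    Returns per-token probabilities, the chosen expert, and its gate value.
    The layer output would be gate * expert(x), which is what keeps the
    router differentiable despite the hard argmax choice.
    """
    probs = [softmax(row) for row in router_logits]
    choices = [max(range(len(p)), key=p.__getitem__) for p in probs]
    gates = [p[c] for p, c in zip(probs, choices)]
    return probs, choices, gates


def load_balancing_loss(probs, choices, alpha=0.01):
    # alpha * N * sum_i f_i * P_i, with f_i the fraction of tokens routed to
    # expert i and P_i the mean router probability assigned to expert i.
    num_tokens, num_experts = len(probs), len(probs[0])
    f = [0.0] * num_experts
    for c in choices:
        f[c] += 1.0 / num_tokens
    p_mean = [sum(p[i] for p in probs) / num_tokens for i in range(num_experts)]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, p_mean))
```

For two tokens and two experts, logits [[2, 0], [0, 2]] split the tokens evenly and give a loss of alpha (0.01); logits [[2, 0], [2, 0]] send both tokens to expert 0 and raise the loss to about 0.0176, so the gradient of this term pushes the router back toward balance.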
The implementation maps naturally onto expert parallelism. In the Mesh TensorFlow pseudocode, inputs are reshaped to [num_cores, tokens_per_core, d_model]. The router produces dispatch and combine tensors of shape [num_cores, tokens_per_core, num_experts, expert_capacity]. A first all-to-all sends each token to the core that owns its expert, the expert feed-forward computation runs, and a second all-to-all returns the outputs. Top-1 routing reduces both the number of expert paths and the volume of routed traffic.
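The roles of the dispatch and combine tensors can be mimicked with nested lists. This sketch drops the leading num_cores axis (a single device group), uses the hypothetical helpers `make_dispatch_combine` and `moe_layer`, and only marks in comments where the two all-to-all collectives would occur; it illustrates the tensor roles and is not the paper's Mesh TensorFlow code.

```python
def make_dispatch_combine(choices, gates, num_experts, capacity):
    """Build dispatch (0/1) and combine (gate-weighted) tensors of shape
    [tokens][num_experts][capacity]. Overflow tokens get all-zero entries."""
    num_tokens = len(choices)

    def zeros():
        return [[[0.0] * capacity for _ in range(num_experts)] for _ in range(num_tokens)]

    dispatch, combine = zeros(), zeros()
    fill = [0] * num_experts  # next free slot per expert
    for t, (e, g) in enumerate(zip(choices, gates)):
        if fill[e] < capacity:
            dispatch[t][e][fill[e]] = 1.0
            combine[t][e][fill[e]] = g
            fill[e] += 1
    return dispatch, combine


def moe_layer(x, choices, gates, experts, capacity):
    num_tokens, d_model, num_experts = len(x), len(x[0]), len(experts)
    dispatch, combine = make_dispatch_combine(choices, gates, num_experts, capacity)
    # Gather tokens into fixed-size expert buffers [experts][capacity][d_model].
    # On a TPU mesh, these buffers are what the first all-to-all exchanges.
    buffers = [[[sum(dispatch[t][e][c] * x[t][k] for t in range(num_tokens))
                 for k in range(d_model)]
                for c in range(capacity)]
               for e in range(num_experts)]
    # Each expert processes its whole buffer; empty slots are just zeros.
    outputs = [[experts[e](buffers[e][c]) for c in range(capacity)]
               for e in range(num_experts)]
    # The second all-to-all would return outputs; combine scales by gate value.
    return [[sum(combine[t][e][c] * outputs[e][c][k]
                 for e in range(num_experts) for c in range(capacity))
             for k in range(d_model)]
            for t in range(num_tokens)]
```

With three tokens, two toy experts, and capacity 1, the third token overflows expert 0 and gets an all-zero layer output; in the full model the residual connection outside this layer would still carry its representation forward.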
Selective precision is a hardware-specific stability fix. Pure bfloat16 is fast on TPUs, but the router softmax is numerically sensitive. Switch casts router inputs/logits to float32 inside the router function, computes softmax and dispatch/combine decisions, then recasts dispatch/combine tensors to bfloat16 so expensive cross-device tensors do not remain float32. That keeps the communication path close to bfloat16 speed while giving the routing decision float32 stability.
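Why a bfloat16 router is fragile can be illustrated by emulating bfloat16 in plain Python with `struct`, rounding a float32 bit pattern to its top 16 bits. The helpers `to_bfloat16`, `pure_bf16_softmax`, and `selective_precision_router` are hypothetical; the pure-bfloat16 path rounds every intermediate, which exaggerates real hardware behavior, and Python's float64 stands in for float32 on the selective path.

```python
import math
import struct


def to_bfloat16(x):
    """Round a float to the nearest bfloat16 value (round-half-to-even),
    returned as a Python float. NaN and overflow handling are omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round on the 16 bits being dropped
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def pure_bf16_softmax(logits):
    # Every intermediate is rounded to bfloat16, including the running sum.
    q = to_bfloat16
    m = max(logits)
    exps = [q(math.exp(q(v - m))) for v in logits]
    total = 0.0
    for e in exps:
        total = q(total + e)
    return [q(e / total) for e in exps]


def selective_precision_router(logits):
    # Router math in full precision (float64 standing in for float32).
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    expert = max(range(len(probs)), key=probs.__getitem__)
    # Only the result is recast to bfloat16 before the cross-device path.
    return expert, to_bfloat16(probs[expert]), probs
```

With 512 equal logits, `pure_bf16_softmax` returns probabilities that sum to 2.0: the bfloat16 running sum stalls at 256 because 257 is not representable and rounds back to 256, so each probability comes out as 1/256 instead of 1/512. The selective path sums to 1.0, and only the gate that would enter dispatch/combine is rounded.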
Evidence
The controlled benchmark table is the first compute evidence. Quality is negative log perplexity (higher is better), and the speed threshold is a quality of -1.50. On 32 TPUv3 cores, T5-Base reaches -1.731 after 100k steps and does not hit the threshold in the measured 100k steps. T5-Large reaches the threshold in 131.1 hours at 470 examples/sec. Switch-Base with 128 experts and capacity factor 1.0 scores -1.561 after 100k steps, reaches the threshold in 62.8 hours, and runs at 1000 examples/sec. The comparable top-2 MoE-Base takes 80.1 hours to reach the threshold, at 860 examples/sec.
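The benchmark numbers reduce to a few ratios. These ratios are derived here from the table values quoted above; the dictionary names and helper functions are illustrative, not taken from the paper.

```python
# Values copied from the benchmark table as summarized above (32 TPUv3 cores).
hours_to_threshold = {"T5-Base": None,  # threshold not reached in 100k steps
                      "T5-Large": 131.1, "MoE-Base": 80.1, "Switch-Base": 62.8}
examples_per_sec = {"T5-Large": 470, "MoE-Base": 860, "Switch-Base": 1000}


def time_speedup(baseline, model):
    # Ratio of wall-clock hours to threshold: >1 means `model` gets there faster.
    return hours_to_threshold[baseline] / hours_to_threshold[model]


def throughput_ratio(baseline, model):
    # Ratio of examples/sec: >1 means `model` processes more examples per second.
    return examples_per_sec[model] / examples_per_sec[baseline]


for base in ("T5-Large", "MoE-Base"):
    print(f"Switch-Base vs {base}: "
          f"{time_speedup(base, 'Switch-Base'):.2f}x faster to threshold, "
          f"{throughput_ratio(base, 'Switch-Base'):.2f}x throughput")
```

Against T5-Large, Switch-Base reaches the threshold about 2.09x sooner at about 2.13x the throughput; against the top-2 MoE-Base, about 1.28x sooner at about 1.16x the throughput.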
Figure 5 gives the headline speedup: a 64-expert Switch-Base reaches the same quality as T5-Base in one-seventh the time on 32 TPUv3 cores with equal FLOPs per example. Figure 4 shows that scaling experts from 2 through 256 improves perplexity while keeping the per-example compute budget equal; the 256-expert point has 14.7B parameters compared with the 223M-parameter T5-Base point. The source also notes a 7.5x speedup measured in training steps: a Switch model reaches at step 60k the quality T5-Base reaches at step 450k.
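The equal-FLOPs comparison rests on a simple count: adding experts multiplies feed-forward parameters, while each token still passes through one expert. A sketch assuming a two-matrix feed-forward block without biases, a router that is a single d_model x num_experts matrix, and T5-Base-like dimensions (d_model = 768, d_ff = 3072). It counts one sparse layer only, so it does not reproduce the 14.7B total, which also depends on depth and on how many layers are sparse.

```python
def ffn_params(d_model, d_ff):
    # Two weight matrices, d_model x d_ff and d_ff x d_model; biases ignored.
    return 2 * d_model * d_ff


def switch_ffn_params(d_model, d_ff, num_experts):
    """Return (total, active-per-token) parameters of one Switch FFN layer.

    Assumes the router is a single d_model x num_experts matrix and that
    top-1 routing activates exactly one expert per token.
    """
    router = d_model * num_experts
    total = num_experts * ffn_params(d_model, d_ff) + router
    active = ffn_params(d_model, d_ff) + router
    return total, active


# T5-Base-like dimensions, used here only for illustration.
for experts in (2, 64, 256):
    total, active = switch_ffn_params(768, 3072, experts)
    print(f"{experts:>3} experts: {total / 1e6:8.1f}M params, "
          f"{active / 1e6:.2f}M active per token")
```

Total parameters grow linearly with the number of experts, while the active count grows only by the router's small d_model x num_experts term, which is why experts can be added at a roughly fixed per-example compute budget.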
Selective precision is directly tested. A 32-expert Switch-Base in float32 has quality -1.718 and speed 1160 examples/sec. Pure bfloat16 reaches 1390 examples/sec but diverges, with quality listed as -3.780. Selective precision reaches -1.716 at 1390 examples/sec, matching float32 training dynamics while preserving bfloat16-like speed.
The multilingual and large-model evidence extends the scale. In multilingual pretraining, Switch shows a mean 5x speedup over mT5-Base, and 91% of languages achieve at least 4x speedup. Table 9 lists Switch-XXL at 395B parameters and 6.3T FLOPs per sequence, and Switch-C at 1571B parameters and 890B FLOPs per sequence. Switch-C uses 2048 experts, 15 layers, and expert frequency 1; Switch-XXL uses 64 experts, 24 layers, and expert frequency 1/2. At 500k steps, Switch-XXL reports negative log perplexity -1.008 and Switch-C -1.043, both better than T5-XXL at -1.095 in the table.
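The Table 9 figures show parameters and compute moving in opposite directions. This sketch computes the ratios from the values quoted above and converts the scores to perplexities, assuming the reported quality is negative log perplexity in natural-log units, so perplexity = exp(-quality); the variable names are illustrative.

```python
import math

# Table 9 values as quoted above: parameters (billions), FLOPs per sequence
# (billions), and negative log perplexity at 500k steps.
models = {
    "Switch-XXL": {"params_b": 395, "flops_b": 6300, "neg_log_ppl": -1.008},
    "Switch-C": {"params_b": 1571, "flops_b": 890, "neg_log_ppl": -1.043},
}
t5_xxl_neg_log_ppl = -1.095


def perplexity(neg_log_ppl):
    # Assumes natural-log units: perplexity = exp(-(negative log perplexity)).
    return math.exp(-neg_log_ppl)


xxl, c = models["Switch-XXL"], models["Switch-C"]
print(f"Switch-C has {c['params_b'] / xxl['params_b']:.1f}x the parameters "
      f"but {xxl['flops_b'] / c['flops_b']:.1f}x fewer FLOPs per sequence")
for name, m in models.items():
    print(f"{name}: perplexity {perplexity(m['neg_log_ppl']):.3f}")
print(f"T5-XXL: perplexity {perplexity(t5_xxl_neg_log_ppl):.3f}")
```

Under that assumption the quoted scores correspond to perplexities of about 2.74 (Switch-XXL), 2.84 (Switch-C), and 2.99 (T5-XXL), while Switch-C carries about 4x the parameters of Switch-XXL at about one-seventh of its FLOPs per sequence.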
Historical Effect
Switch made sparse expert Transformers simpler and more trainable, helping top-1 routing become the practical MoE default. It is a key sparse-scaling branch after GShard because it reduces the routing and communication surface while preserving the idea of increasing parameter count without proportional activated FLOPs per token.
The historical effect is practical adoption. GShard proved compiler-supported MoE could scale to hundreds of billions of parameters; Switch made the layer easier to understand, benchmark, and stabilize with bfloat16-friendly selective precision.
Limits
- Largest models can be unstable.
- Fine-tuning behavior depends on poorly understood tradeoffs between FLOPs per token and parameter count.
- Deployment remains hard.
- Distillation preserves only part of the quality gain.
- The extracted text does not clearly state the exact device count for Switch-XXL or Switch-C, so this card infers their hardware scale rather than reporting it from the paper.
- Fixed expert capacity means tokens can overflow and skip expert computation; larger capacity factors reduce overflow but increase compute and communication.
- The paper reports no observed instability for Switch-C, but says Switch-XXL remained unstable despite the proposed stability techniques.
Links
- Parent regime: compute spine
- Related card: GShard 2020
- Method index: moe
- Ledger updates: compute bottlenecks