GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding

Metadata

  • Slug: gshard_2020
  • Year: 2020
  • Venue: arXiv
  • Authors: Dmitry Lepikhin et al.
  • Reading status: read complete
  • Compute regime: Sparse and memory-efficient scaling
  • Primary sources: PDF, extracted text

Compute Setup

The paper reports experiments on TPU v3 clusters:

  • Up to 2048 TPU v3 cores.
  • 2D TPU cluster structure.
  • Expert counts tied to device counts such as 128, 512, and 2048 cores.

The main reported multilingual model is MoE(2048E, 36L), a 600B-parameter model trained on 2048 TPU v3 cores for 4 days. Table 3 reports 0.72 steps per second, a 4M-token batch, 22.4 TPU v3 core-years, and average BLEU 44.3. The paper also discusses a 1T-weight MoE(2048E, 60L) experiment with bfloat16 activations, but leaves it out of the main results for the sake of reproducibility: it was trainable only with careful manual diagnostics and ran into numerical stability issues.

Bottleneck

The bottleneck is scaling sparse giant models without hand-writing fragile model-parallel code. Models with hundreds of billions of parameters are orders of magnitude beyond one accelerator's memory capacity, but naive model parallelism makes the programmer coordinate communication across devices and can make graph size or compilation time grow with the number of partitions. The paper therefore treats manual partitioning as a software bottleneck alongside the memory limit.

MoE adds its own device-level bottleneck. Tokens must be routed to experts, expert buffers must be bounded, and the dispatch/combine steps require cross-device communication. AllToAll is the critical collective for moving token representations between the group-sharded and expert-sharded layouts. Attention remains memory-bandwidth bound at short sequence lengths, while dense feed-forward and projection layers map well to TPU matrix units.

Method Adaptation

GShard adapts MoE Transformers to TPU clusters by:

  • Adding lightweight sharding annotations.
  • Using an XLA SPMD partitioner.
  • Placing top-2 MoE layers in alternating feed-forward positions.
  • Grouping tokens for routing and expert capacity control.
  • Using auxiliary load-balancing loss and random second-expert routing.
  • Using AllToAll for dispatch and combine.
  • Relying on compiler rematerialization.

The sharding API separates model description from partitioning. Users can annotate tensors with replicate, split, or shard, and the XLA SPMD partitioner produces a single program for all devices. This avoids generating separate per-device graphs and keeps compilation scalable to thousands of partitions. The MoE model primarily switches between group/batch-style partitioning and expert partitioning.
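The difference between replicating and splitting a tensor can be illustrated with a toy sketch. The helpers below are hypothetical stand-ins written for this note, operating on nested Python lists; they only mimic the semantics of the replicate and split annotations and are not the GShard or XLA API.

```python
# Toy stand-ins for annotation semantics; NOT the GShard or XLA API.

def replicate(tensor, num_devices):
    """Every device holds a full copy of the tensor."""
    return [[row[:] for row in tensor] for _ in range(num_devices)]


def split(tensor, split_dimension, num_partitions):
    """Partition a 2-D list evenly along one dimension, one shard per device."""
    rows, cols = len(tensor), len(tensor[0])
    size = rows if split_dimension == 0 else cols
    if size % num_partitions != 0:
        raise ValueError("dimension must divide evenly in this sketch")
    step = size // num_partitions
    shards = []
    for p in range(num_partitions):
        lo, hi = p * step, (p + 1) * step
        if split_dimension == 0:
            shards.append([row[:] for row in tensor[lo:hi]])
        else:
            shards.append([row[lo:hi] for row in tensor])
    return shards


weights = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 matrix
row_shards = split(weights, 0, 2)   # each device holds 2 of the 4 rows
col_shards = split(weights, 1, 2)   # each device holds 2 of the 4 columns
copies = replicate(weights, 2)      # each device holds all 16 values
```

With two devices, a split halves per-device storage while a replicated tensor costs the full size on every device, which is why the large expert weights are the ones that get partitioned.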

The MoE layer uses top-2 routing. A trainable gating network chooses up to two experts for each token. Expert capacity is set from the number of tokens and experts, and a token that overflows both selected experts is dropped for that layer; its representation passes to the next layer through the residual connection. To keep load balanced, the gate uses local group dispatching, an auxiliary load-balancing loss, and random second-expert routing: the second expert is used with probability proportional to its gate weight, so a low-weight second choice is usually skipped to conserve expert capacity.
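The routing rules in this paragraph can be sketched for a single token group in plain Python. This is a minimal illustration under stated assumptions, not the paper's implementation: the function name `top2_gating`, the two-pass ordering, the `2 * g2` sampling threshold, and the exact form of the auxiliary loss (mean over experts of top-1 fraction times mean gate) are choices made for this sketch, and capacity is simply passed in as a parameter.

```python
import math
import random


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top2_gating(token_logits, capacity, rng):
    """Route each token of one group to at most two experts.

    Returns (assignments, aux_loss, load). assignments[s] is a list of
    (expert, weight) pairs; an empty list means the token was dropped.
    """
    num_tokens = len(token_logits)
    num_experts = len(token_logits[0])  # sketch assumes at least 2 experts
    load = [0] * num_experts        # tokens accepted per expert
    top1_count = [0] * num_experts  # how often each expert is first choice
    gate_sum = [0.0] * num_experts  # running sum of gate probabilities
    assignments = [[] for _ in range(num_tokens)]
    second_choices = []

    # Pass 1: first-choice experts get priority on capacity.
    for s, logits in enumerate(token_logits):
        gates = softmax(logits)
        for e, g in enumerate(gates):
            gate_sum[e] += g
        ranked = sorted(range(num_experts), key=lambda e: gates[e], reverse=True)
        e1, e2 = ranked[0], ranked[1]
        norm = gates[e1] + gates[e2]  # renormalise over the two picks
        top1_count[e1] += 1
        if load[e1] < capacity:
            load[e1] += 1
            assignments[s].append((e1, gates[e1] / norm))
        second_choices.append((e2, gates[e2] / norm))

    # Pass 2: random second-expert routing. g2 <= 0.5 after renormalising,
    # so 2 * g2 is a valid probability (an assumption of this sketch).
    for s, (e2, g2) in enumerate(second_choices):
        if rng.random() < 2.0 * g2 and load[e2] < capacity:
            load[e2] += 1
            assignments[s].append((e2, g2))

    # Auxiliary loss: large when top-1 choices and gate mass pile onto
    # the same few experts.
    aux_loss = sum(
        (top1_count[e] / num_tokens) * (gate_sum[e] / num_tokens)
        for e in range(num_experts)
    ) / num_experts
    return assignments, aux_loss, load


# Four tokens that all strongly prefer expert 0, with room for one token each.
logits = [[5.0, 0.0, -5.0] for _ in range(4)]
assignments, aux_loss, load = top2_gating(logits, capacity=1, rng=random.Random(0))
```

In this example only the first token reaches expert 0; the remaining tokens overflow their first choice and are either dropped or fall back to a rarely sampled second expert, which is the imbalance the auxiliary loss is meant to discourage.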

The implementation is largely an adaptation of communication patterns. Dispatch and combine are expressed as einsums, and the compiler inserts AllToAll to reshard tensors between the group-sharded and expert-sharded dimensions. Dense Transformer FFN and projection layers remain large matmuls for TPU utilization. When memory exceeds per-device limits, compiler rematerialization recomputes activations in the backward pass. The model thus spends extra compute to stay within memory and uses sparse activation to decouple parameter count from per-token compute.
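The layout change that AllToAll performs can be pictured as a transpose of per-device chunks. The simulation below is an illustrative sketch in plain Python with a hypothetical `all_to_all` helper; it models only which device ends up holding which chunk, not the actual TPU collective or its cost.

```python
def all_to_all(per_device_chunks):
    """Simulate AllToAll: device d sends its chunk e to device e.

    per_device_chunks[d][e] is what device d holds for destination e.
    result[e][d] is what device e received from device d.
    """
    n = len(per_device_chunks)
    return [[per_device_chunks[d][e] for d in range(n)] for e in range(n)]


# Group-sharded layout: each device holds its own token group,
# bucketed by the expert each token was routed to.
group_sharded = [
    [["g0->e0"], ["g0->e1"]],  # device 0 (token group 0)
    [["g1->e0"], ["g1->e1"]],  # device 1 (token group 1)
]

# Dispatch: after AllToAll, device e holds every group's tokens for expert e.
expert_sharded = all_to_all(group_sharded)

# Combine: a second AllToAll returns expert outputs to their home groups.
restored = all_to_all(expert_sharded)
```

Because the exchange is a transpose, applying it twice restores the original group-sharded layout, which matches the dispatch-then-combine structure of the MoE layer.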

Evidence

The main training table gives the headline result. MoE(2048E, 36L) uses 2048 cores, runs at 0.72 steps per second, uses a 4M-token batch, costs 22.4 TPU v3 core-years, trains for 4.0 days, and reaches average BLEU 44.3 with delta BLEU 13.5. The 512-expert, 36-layer model uses 512 cores, 1M-token batch, 15.5 TPU core-years, 11.0 days, and BLEU 43.7. The 128-expert, 36-layer model uses 128 cores, 1M-token batch, 6.1 TPU core-years, 17.3 days, and BLEU 39.0.

The dense comparison explains why conditional compute mattered. The table lists a dense T(96L) model on 2048 cores with a 4M-token batch at about 235.5 TPU core-years, about 42 days, and BLEU 36.9. The introduction also states that the best dense single Transformer baseline with 2.3B parameters achieved much lower quality while requiring 235.5 TPU v3 core-years. GShard's 600B sparse model therefore increases parameter count massively while reducing training cost relative to a dense scale-up path.
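The core-year figures quoted above follow from core count times training days. A quick arithmetic check, assuming a 365-day year and using the rounded day counts from this note (so small differences from the reported values are expected):

```python
DAYS_PER_YEAR = 365  # assumption for this check

# (cores, training days, reported TPU v3 core-years) as quoted in this note
runs = {
    "MoE(2048E, 36L)": (2048, 4.0, 22.4),
    "MoE(512E, 36L)": (512, 11.0, 15.5),
    "MoE(128E, 36L)": (128, 17.3, 6.1),
    "T(96L)": (2048, 42.0, 235.5),
}


def core_years(cores, days):
    return cores * days / DAYS_PER_YEAR


estimates = {name: core_years(c, d) for name, (c, d, _) in runs.items()}

for name, (_, _, reported) in runs.items():
    print(f"{name}: estimated {estimates[name]:.1f}, reported {reported}")

# Dense scale-up cost relative to the 600B sparse model, from reported values.
dense_to_sparse_ratio = runs["T(96L)"][2] / runs["MoE(2048E, 36L)"][2]
```

The estimates land within about 0.2 core-years of the reported numbers, and the reported values put the dense T(96L) run at roughly 10.5 times the cost of the 600B MoE run.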

The memory and runtime evidence is equally important. Per-device memory is roughly constant as the number of devices and experts increases: replicated weights have a fixed per-device size, while the MoE weights and activations are partitioned across devices by the SPMD layout. With fixed layer count, both weight memory and activation memory per device stay constant as expert count grows. Step time grows sublinearly: the paper reports only a 1.7x execution-time increase when scaling from 128 to 2048 devices/experts, a 16x expert increase.

Communication becomes the limiting residual cost. At 128 experts, the model achieves more than 70% of the estimated roofline performance; at 2048 experts it still reaches 48%. Dense feed-forward layers and Transformer projections achieve more than 85% of peak FLOPs, while attention achieves more than 30% because it is more memory-bandwidth bound. MoE dispatch and combine are AllToAll operations: when expert count grows 16x from 128 to 2048, their execution time increases about 3.75x and their share of MoE plus Transformer time rises from 16% to 36%. The communication microbenchmark shows that AllReduce cost is roughly constant with device count on TPU, while AllToAll cost scales roughly with the square root of the number of partitions: from 16 to 2048 partitions, a 128x increase, AllToAll time increases only 9x.
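The square-root scaling claim can be checked against the reported ratios with a few lines of arithmetic. This is only a consistency check on the numbers quoted above; the square-root model is the approximation described in the text, not an exact cost formula.

```python
import math

# Microbenchmark: 16 -> 2048 partitions.
partition_growth = 2048 / 16                        # 128x more partitions
sqrt_predicted_growth = math.sqrt(partition_growth)  # about 11.3x
reported_alltoall_growth = 9.0

# End-to-end model: 128 -> 2048 experts.
expert_growth = 2048 / 128                           # 16x more experts
sqrt_predicted_dispatch = math.sqrt(expert_growth)   # 4.0x
reported_dispatch_growth = 3.75
reported_step_time_growth = 1.7

print(f"AllToAll: sqrt model {sqrt_predicted_growth:.1f}x, "
      f"reported {reported_alltoall_growth}x")
print(f"Dispatch/combine: sqrt model {sqrt_predicted_dispatch:.1f}x, "
      f"reported {reported_dispatch_growth}x")
```

Both reported ratios sit slightly below the square-root prediction and far below linear growth, which is consistent with the overall step time rising only 1.7x for a 16x increase in experts.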

Historical Effect

GShard moved MoE from hand-built cluster experiments to compiler-supported TPU-scale sparse Transformers. It made 600B-class conditional-compute models operational and showed that model capacity, memory layout, and communication collectives could be managed through lightweight annotations and compiler SPMD partitioning.

Historically, it bridges the original sparsely gated MoE idea and later simpler expert systems such as Switch Transformer. It also extends the Mesh TensorFlow lesson: layout is part of the method, not a secondary implementation detail.

Limits

  • The number of experts cannot scale indefinitely, because per-expert capacity shrinks as experts are added and must remain at least one token.
  • The 1T bfloat16-activation model was trainable only with careful manual diagnostics and had numerical stability issues, so it was excluded from the main reproducible results.
  • AllToAll communication becomes a growing fraction of runtime.
  • The system depends on compiler and mesh assumptions.
  • Expert count was tied to device count for simplicity in the experiments, although the paper says this is not a requirement.
  • Sparse parameter count is not the same as dense compute per token; comparisons must track activated computation, memory, communication, and quality.

Links