PaLM: Scaling Language Modeling with Pathways
Metadata
- Slug: palm_2022
- Year: 2022
- Venue: arXiv
- Authors: Aakanksha Chowdhery et al.
- Reading status: read complete
- Compute regime: Hyperscale dense LLM training
- Primary sources: PDF, extracted text
Compute Setup
The paper is unusually explicit about the machine. PaLM 540B was trained on 6144 TPU v4 chips using Pathways. The system comprised two TPU v4 Pods, each with 3072 TPU v4 chips attached to 768 hosts, so the full training job spanned 1536 hosts across the two pods. The model card also lists TPU v4 for both training and deployment. This matters because the paper is not only a model-scaling report; it also shows that dense decoder-only training can cross a single-pod boundary without pipeline parallelism.
The training target was a 540B-parameter dense Transformer trained on a single pass over a 780B-token corpus. For the 540B model, the batch size increased over the course of training: 512 sequences (about 1M tokens) until step 50k, 1024 sequences (about 2M tokens) until step 115k, then 2048 sequences (about 4M tokens) for the rest of the run. Wall-clock training is reported as 1200 hours on 6144 TPU v4 chips plus 336 hours on 3072 TPU v4 chips, including downtime and repeated steps. The training compute table reports 2527.2 zettaFLOPs for PaLM 540B.
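The batch and compute figures above can be checked with a few lines of arithmetic. This is a rough sketch: the 2048-token sequence length is an assumption that matches the stated "512 sequences, about 1M tokens", and 6ND is the standard dense-Transformer approximation rather than the paper's exact FLOP accounting.

```python
# Back-of-envelope arithmetic for PaLM 540B's batch schedule and compute budget.
# Assumption: 2048-token sequences, consistent with 512 sequences ~ 1M tokens.
SEQ_LEN = 2048

# (sequences per batch, last step at that size); None means "until the end of training"
BATCH_SCHEDULE = [(512, 50_000), (1024, 115_000), (2048, None)]


def tokens_per_batch(sequences: int, seq_len: int = SEQ_LEN) -> int:
    """Tokens consumed by one optimizer step."""
    return sequences * seq_len


def dense_training_flops(params: float, tokens: float) -> float:
    """Common 6*N*D estimate: ~2ND forward FLOPs plus ~4ND backward FLOPs."""
    return 6 * params * tokens


for seqs, last_step in BATCH_SCHEDULE:
    until = f"step {last_step:,}" if last_step is not None else "end of training"
    print(f"{seqs:>5} sequences -> {tokens_per_batch(seqs):,} tokens/step until {until}")

flops = dense_training_flops(540e9, 780e9)
print(f"6ND estimate: {flops / 1e21:.1f} zettaFLOPs")

# Reported wall-clock converted to chip-hours (includes downtime and repeated steps).
chip_hours = 1200 * 6144 + 336 * 3072
print(f"chip-hours: {chip_hours:,}")
```

Under these assumptions, 6ND lands on 2527.2 zettaFLOPs, matching the reported figure, and the reported wall-clock corresponds to about 8.4 million chip-hours.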
Bottleneck
The immediate bottleneck is dense training at pod scale. A 540B dense model has to shard weights and optimizer state while still presenting large matrix multiplications to the hardware. The paper's system discussion treats pipeline parallelism as a bottleneck rather than a solution for this case: pipeline training splits a batch into micro-batches, creates bubbles while filling and draining the pipeline, reloads weights for each micro-batch, and adds software complexity. PaLM's compute problem is therefore to use more chips without inheriting those pipeline costs.
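To make the pipeline-bubble cost concrete, here is the standard idle-fraction estimate for a GPipe-style synchronous schedule. The formula is a general approximation for pipeline parallelism, not a number PaLM reports, and it ignores communication and uneven stage costs.

```python
def pipeline_bubble_fraction(stages: int, microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule: (p - 1) / (m + p - 1).

    With p stages and m micro-batches, a step spans m + p - 1 slots, and each
    stage is idle for p - 1 of them while the pipeline fills and drains.
    """
    if stages < 1 or microbatches < 1:
        raise ValueError("stages and microbatches must be positive")
    return (stages - 1) / (microbatches + stages - 1)


# More micro-batches shrink the bubble, but each micro-batch is smaller, so the
# per-stage matmuls shrink and stage weights are re-read once per micro-batch.
for m in (4, 16, 64):
    print(f"p=8 stages, m={m:>2} micro-batches -> bubble {pipeline_bubble_fraction(8, m):.1%}")
```

The tension is the one the paragraph describes: shrinking the bubble requires many micro-batches, which in turn means smaller matrix multiplications and more weight reloads.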
The second bottleneck is cross-pod communication. Within each pod the model is sharded, but across pods the training uses pod-level data parallelism. After each pod computes gradients on half the batch, the corresponding hosts exchange gradient shards over the datacenter network. The paper quantifies this as about 1.3 GB exchanged by each host pair per training step, producing an aggregate burst of 81 Tbps across all hosts. That bursty all-host exchange is the systems constraint behind Pathways' networking work.
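A quick consistency check on the 1.3 GB figure: if each of a pod's 768 hosts owns an equal shard of a full gradient, the shard size follows from the parameter count. The bf16 gradient precision and the "all hosts send at once" reading of the 81 Tbps aggregate are assumptions for this sketch, not details stated on this card.

```python
PARAMS = 540e9
HOSTS_PER_POD = 768
BYTES_PER_GRAD_ELEMENT = 2  # assumption: bf16 gradients (not stated on this card)

# Each host in one pod exchanges its shard with the matching host in the other pod.
shard_bytes = PARAMS * BYTES_PER_GRAD_ELEMENT / HOSTS_PER_POD
shard_gib = shard_bytes / 2**30
print(f"per-host gradient shard: {shard_bytes / 1e9:.2f} GB = {shard_gib:.2f} GiB")

# Assumption: all 1536 hosts send their shard concurrently and the 81 Tbps
# figure is the total rate of that burst.
AGGREGATE_BITS_PER_SEC = 81e12
total_bits = 2 * HOSTS_PER_POD * shard_bytes * 8
burst_seconds = total_bits / AGGREGATE_BITS_PER_SEC
print(f"implied burst window: {burst_seconds:.2f} s")
```

Under the bf16 assumption, each host's shard is about 1.31 GiB, close to the reported ~1.3 GB, and the whole exchange would take roughly a fifth of a second: a short all-host spike rather than steady traffic.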
Memory is a third constraint. PaLM uses rematerialization because storing all intermediate activations would cap the feasible batch size. The paper is explicit about the trade: recomputing activations spends extra hardware FLOPs to save memory, so the practical target is tokens per second, not simply high raw FLOP utilization.
The rough training-state lower bound for 540B dense parameters is about 8.64 TB under the 16-bytes-per-parameter mixed-precision Adam rule, before activations, temporary buffers, and sharding overhead. BF16 inference weights alone would be about 1.08 TB before KV cache. The card does not expose enough per-layer key/value-head detail to compute an exact cache size from the local summary, but the correct inference bound would be weight memory plus KV cache, not the parameter file alone.
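The two memory figures follow directly from bytes-per-parameter arithmetic; the sketch below also shows the shape of the weights-plus-KV-cache bound. The KV-cache formula is the generic one for attention with a given number of key/value heads, and the layer, head, and length values in the usage are hypothetical placeholders, since this card does not record PaLM's per-layer key/value layout.

```python
def training_state_bytes(params: float, bytes_per_param: int = 16) -> float:
    """Mixed-precision Adam rule of thumb, commonly broken down as 2 (bf16 weights)
    + 2 (bf16 grads) + 4 (fp32 master weights) + 4 + 4 (fp32 Adam moments)."""
    return params * bytes_per_param


def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, batch: int, bytes_per_elem: int = 2) -> int:
    """Keys and values (the factor of 2) cached per layer, KV head, and token."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem


def inference_bytes(params: float, kv_bytes: float, bytes_per_param: int = 2) -> float:
    """Lower bound for serving: bf16 weights plus the KV cache."""
    return params * bytes_per_param + kv_bytes


PARAMS = 540e9
print(f"training state: {training_state_bytes(PARAMS) / 1e12:.2f} TB")
print(f"bf16 weights:   {PARAMS * 2 / 1e12:.2f} TB")

# Hypothetical config for illustration only; these are NOT PaLM 540B's values.
kv = kv_cache_bytes(layers=96, kv_heads=8, head_dim=128, seq_len=2048, batch=8)
print(f"KV cache (hypothetical config): {kv / 2**30:.1f} GiB")
print(f"inference bound: {inference_bytes(PARAMS, kv) / 1e12:.3f} TB")
```

Filling in the real per-layer key/value head count and head dimension would turn the last two lines into the exact inference bound the paragraph describes.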
Method Adaptation
PaLM adapts the dense Transformer to the TPU v4 topology with a two-level parallelism plan. Within a pod, the pod logically holds a full copy of the model, but each weight tensor is partitioned over its 3072 chips using 12-way model parallelism and 256-way fully sharded data parallelism (12 × 256 = 3072). Across pods, Pathways runs two-way data parallelism: a single Python client dispatches half the batch to each pod, both pods run the forward and backward passes, and then the pods exchange gradients and apply the updates in parallel.
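The plan reduces to simple mesh arithmetic plus an ordinary data-parallel update. The toy below uses a one-parameter least-squares model as a hypothetical stand-in for the Transformer: two "pods" each take half the global batch, average their gradients, and apply the same update, which matches a single full-batch step when the halves are equal in size. It illustrates the synchronization pattern only, not Pathways' actual dispatch or sharding code.

```python
# Mesh arithmetic for the two-level plan.
PODS, DATA_PER_POD, MODEL_PER_POD = 2, 256, 12
CHIPS_PER_POD = DATA_PER_POD * MODEL_PER_POD
assert CHIPS_PER_POD == 3072 and PODS * CHIPS_PER_POD == 6144

# Every weight tensor is split over all chips in a pod, so each chip holds ~1/3072 of it.
params_per_chip = 540e9 / CHIPS_PER_POD
print(f"params per chip: {params_per_chip / 1e6:.1f}M")


def mean_sq_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def two_pod_step(w: float, xs: list[float], ys: list[float], lr: float) -> float:
    """Toy cross-pod data parallelism on a 1-D least-squares model."""
    if len(xs) % 2:
        raise ValueError("global batch must split evenly across two pods")
    half = len(xs) // 2
    # Each "pod" runs forward and backward on its half of the global batch.
    g_pod0 = mean_sq_grad(w, xs[:half], ys[:half])
    g_pod1 = mean_sq_grad(w, xs[half:], ys[half:])
    # The cross-pod exchange leaves both pods with the same averaged gradient...
    g = (g_pod0 + g_pod1) / 2
    # ...so both apply the identical update and their weights stay in sync.
    return w - lr * g


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
w_dp = two_pod_step(0.5, xs, ys, lr=0.01)
w_full = 0.5 - 0.01 * mean_sq_grad(0.5, xs, ys)
print(w_dp, w_full)
```

Because the averaged gradient equals the full-batch gradient, doubling the global batch across pods changes where the work runs, not what update is applied.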
The architecture was also adjusted for hardware efficiency. PaLM uses a parallel formulation of each Transformer block, y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x)), instead of the standard serialized form y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))); because the attention and MLP input matrix multiplications can then be fused, the paper reports roughly 15% faster training at large scale. The batch size is deliberately increased later in training, partly because larger batches produce larger matrix multiplication dimensions and therefore better TPU efficiency. The deterministic data pipeline, the JAX/XLA/T5X stack, compiler optimizations, and parallel layers are all part of the compute story, not incidental implementation detail.
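PaLM's parallel block computes y = x + Attention(LayerNorm(x)) + MLP(LayerNorm(x)) instead of the serialized y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))). The pure-Python sketch below shows why that helps: because both branches read the same normalized input, their input projections can be stacked into a single matrix multiplication. The branch functions are hypothetical toy stand-ins for attention and the MLP, and the two formulations are not mathematically equivalent, so this is an architectural change, not only a kernel optimization.

```python
import math


def layer_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def matvec(rows: list[list[float]], x: list[float]) -> list[float]:
    return [sum(w * v for w, v in zip(row, x)) for row in rows]


def add(*vecs: list[float]) -> list[float]:
    return [sum(vals) for vals in zip(*vecs)]


# Hypothetical stand-ins: real attention and MLP sublayers are far more involved.
# Here each branch is an input projection followed by a toy nonlinearity.
def attn_branch(h: list[float]) -> list[float]:
    return [math.tanh(v) for v in h]


def mlp_branch(h: list[float]) -> list[float]:
    return [max(0.0, v) for v in h]


def serial_block(x, w_attn, w_mlp):
    # y = x + MLP(LN(x + Attn(LN(x)))): the MLP input depends on the attention output.
    x1 = add(x, attn_branch(matvec(w_attn, layer_norm(x))))
    return add(x1, mlp_branch(matvec(w_mlp, layer_norm(x1))))


def parallel_block(x, w_attn, w_mlp):
    # y = x + Attn(LN(x)) + MLP(LN(x)): both branches read the same LN(x),
    # so their input projections can run as one matmul over stacked weights.
    h = layer_norm(x)
    fused = matvec(w_attn + w_mlp, h)  # list concatenation stacks the rows
    k = len(w_attn)
    return add(x, attn_branch(fused[:k]), mlp_branch(fused[k:]))


x = [0.5, -1.0, 2.0]
w_attn = [[0.1, 0.2, 0.3], [0.0, -0.1, 0.2], [0.3, 0.1, -0.2]]
w_mlp = [[0.2, -0.3, 0.1], [0.1, 0.1, 0.1], [-0.2, 0.4, 0.0]]
print("serial:  ", serial_block(x, w_attn, w_mlp))
print("parallel:", parallel_block(x, w_attn, w_mlp))
```

On an accelerator, the fused projection means one larger matmul per block instead of two dependent ones, which is the kind of shape change TPUs reward.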
The paper also argues for model FLOPs utilization (MFU) as the cleaner metric. Hardware FLOPs utilization (HFU) depends on implementation choices such as rematerialization, whereas MFU compares observed token throughput with the theoretical maximum throughput the system could reach at peak FLOPs if it spent them only on the model's forward and backward passes. That distinction lets PaLM report high throughput without counting recomputation FLOPs as useful model work.
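The MFU/HFU split can be written as two ratios against the same peak. In this sketch the per-token FLOP counts are inputs; the usage assumes full rematerialization that re-runs the forward pass (about 2N extra FLOPs per token on top of the ~6N needed for forward and backward), with a hypothetical model size, throughput, chip count, and per-chip peak.

```python
def utilization(tokens_per_sec: float, model_flops_per_token: float,
                recompute_flops_per_token: float, n_chips: int,
                peak_flops_per_chip: float) -> tuple[float, float]:
    """Return (MFU, HFU).

    MFU counts only the FLOPs the model mathematically needs per token.
    HFU also counts FLOPs actually executed for rematerialization.
    """
    peak = n_chips * peak_flops_per_chip
    mfu = tokens_per_sec * model_flops_per_token / peak
    hfu = tokens_per_sec * (model_flops_per_token + recompute_flops_per_token) / peak
    return mfu, hfu


# Hypothetical example: a 1B-parameter model with full forward recomputation.
N = 1e9
mfu, hfu = utilization(tokens_per_sec=5e5, model_flops_per_token=6 * N,
                       recompute_flops_per_token=2 * N,
                       n_chips=64, peak_flops_per_chip=100e12)
print(f"MFU={mfu:.3f} HFU={hfu:.3f} HFU/MFU={hfu / mfu:.3f}")
```

The same throughput yields a higher HFU than MFU, which is exactly why HFU can flatter a system that recomputes heavily.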
Evidence
The strongest compute evidence is the throughput/utilization accounting. PaLM 540B reaches 238.3K tokens/sec at batch size 2048, or about 4M tokens. The paper reports 46.2% MFU and 57.8% hardware FLOPs utilization, comparing favorably with the listed prior large models: GPT-3 at 21.3% MFU, Gopher at 32.5%, and Megatron-Turing NLG 530B at about 30%. With two pods, throughput is about 1.95x single-pod throughput, which the authors describe as 97% of perfect weak scaling because the batch size doubles.
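The headline MFU can be reproduced approximately from the throughput number. The peak-FLOPs figure and the architecture values below are assumptions brought in for this check (they are not recorded on this card); with them, the model-FLOPs count that includes attention lands at about 46.2%, while the plain 6N count comes out slightly lower. Treat this as a consistency check, not the paper's own accounting code.

```python
PEAK_BF16_FLOPS_PER_CHIP = 275e12  # assumption: published TPU v4 bf16 peak
CHIPS = 6144
TOKENS_PER_SEC = 238.3e3

# Architecture values assumed for this check (not listed on this card):
# parameter count, layers, attention heads, head dimension, sequence length.
N = 540.35e9
L, H, Q, T = 118, 48, 256, 2048


def model_flops_per_token(n, layers=0, heads=0, head_dim=0, seq_len=0):
    """6N for weight matmuls, plus 12*L*H*Q*T for attention score/value matmuls."""
    return 6 * n + 12 * layers * heads * head_dim * seq_len


peak = CHIPS * PEAK_BF16_FLOPS_PER_CHIP
mfu_dense = TOKENS_PER_SEC * model_flops_per_token(N) / peak
mfu_full = TOKENS_PER_SEC * model_flops_per_token(N, L, H, Q, T) / peak
print(f"MFU, 6N only: {mfu_dense:.1%}; with attention term: {mfu_full:.1%}")

# Other ratios quoted in the evidence paragraph.
weak_scaling = 1.95 / 2           # two-pod throughput relative to a perfect 2x
remat_overhead = 57.8 / 46.2 - 1  # extra executed FLOPs relative to model FLOPs
print(f"weak scaling: {weak_scaling:.1%}; HFU over MFU: +{remat_overhead:.0%}")
```

The last line suggests that roughly a quarter of the model FLOPs were executed again as extra work, presumably mostly rematerialization.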
The model-quality evidence shows that this hardware scale was not merely a utilization exercise. The 540B model is evaluated at the 780B-token checkpoint across language, reasoning, multilingual, and code tasks. The paper also includes an internal counterfactual discussion: a 62B model trained for many more tokens could improve on some aggregate curves for lower total FLOPs, but training that smaller model on the same number of chips would require very large batches, and PaLM 540B was already at a 4M-token batch. This connects the benchmark frontier directly to accelerator utilization and sample-efficiency constraints.
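The batch-size side of the counterfactual can be put in rough numbers. This sketch assumes the smaller model keeps the same chips busy with the same FLOPs per optimizer step and uses the 6ND approximation; neither the equal-budget framing nor the resulting figures come from the paper.

```python
PALM_FLOPS = 2527.2e21            # reported training compute for PaLM 540B
BIG_N, SMALL_N = 540e9, 62e9
FINAL_BATCH_TOKENS = 2048 * 2048  # 2048 sequences of 2048 tokens (~4M)

# Under C ~ 6ND, the tokens a 62B model could consume within PaLM 540B's budget.
small_tokens_same_budget = PALM_FLOPS / (6 * SMALL_N)

# Assumption: same chips, same FLOPs per optimizer step. Each token is ~540/62
# times cheaper for the small model, so each step must carry that many more tokens.
small_batch_tokens = FINAL_BATCH_TOKENS * BIG_N / SMALL_N

print(f"62B tokens at equal FLOPs: {small_tokens_same_budget / 1e12:.1f}T")
print(f"62B batch at equal per-step work: {small_batch_tokens / 1e6:.1f}M tokens")
```

Under those assumptions, the 62B model would need batches roughly 8.7 times larger than PaLM 540B's already large ~4M-token batch, which is the regime where sample efficiency becomes uncertain.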
Historical Effect
PaLM demonstrates dense LLM training at TPU v4 pod scale without pipeline parallelism. Historically, it sits between single-pod dense scaling and later training systems that treat very large multi-pod accelerator pools as routine. Its contribution is partly the model, but just as much the compute structure: model parallelism and fully sharded data parallelism within a pod, weak-scaling data parallelism across pods, and explicit accounting for rematerialization, batch size, and network bursts.
It also helped establish MFU as a practical systems metric for large language models. By separating required model FLOPs from extra recomputation FLOPs, PaLM made it easier to compare different model and compiler strategies. The result is a compute card that says not just "6144 chips," but how those chips were fed, synchronized, and measured.
Limits
The major compute limit is token allocation. Chinchilla-style analysis suggests PaLM is likely undertrained relative to its parameter count: 540B parameters saw 780B tokens, while later compute-optimal practice would spend more of the same budget on data. PaLM itself notes that some subcorpora begin repeating around the 780B-token scale, which constrains simply extending training under the same data mixture.
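A rough version of the Chinchilla-style comparison: the sketch assumes the commonly quoted approximation of about 20 training tokens per parameter together with the 6ND compute rule, then solves for the compute-optimal split of PaLM 540B's reported budget. Both are approximations, not the fitted scaling laws themselves.

```python
import math

TOKENS_PER_PARAM = 20  # assumption: common rule-of-thumb reading of Chinchilla


def compute_optimal_split(c_flops: float, tokens_per_param: float = TOKENS_PER_PARAM):
    """Solve C = 6 * N * D with D = k * N, so N = sqrt(C / (6 * k)) and D = k * N."""
    n = math.sqrt(c_flops / (6 * tokens_per_param))
    return n, tokens_per_param * n


n_opt, d_opt = compute_optimal_split(2527.2e21)
palm_tokens_per_param = 780e9 / 540e9

print(f"PaLM 540B: {palm_tokens_per_param:.2f} tokens/param")
print(f"rule-of-thumb optimum: {n_opt / 1e9:.0f}B params on {d_opt / 1e12:.1f}T tokens")
```

Under this rough rule, the same budget would favor a model of roughly 145B parameters trained on about 2.9T tokens, far past the 780B-token point where some of PaLM's subcorpora begin repeating.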
The batch and hardware arguments also have limits. Training a smaller model for more tokens with the same total FLOPs might require fewer chips, increasing wall-clock time, or the same chips with even larger batches. The paper states that the 540B batch was already 4M tokens and that it was unclear whether larger batches would maintain sample efficiency. Finally, hardware and efficiency are reported in depth, but deployment impact, fairness, and toxicity analysis remain limited relative to the scale of the model.
Links
- Parent regime: compute spine
- Related card: Chinchilla 2022
- Method index: transformer, parallelism, scaling_laws
- Ledger updates: compute bottlenecks