REALM: Retrieval-Augmented Language Model Pre-Training

Metadata

Compute Setup

The paper reports the setup at the system level but does not name a TPU generation. It pretrains REALM for 200K steps on 64 Google Cloud TPUs with batch size 512, learning rate 3e-5, and BERT's default optimizer. The document-embedding step for the MIPS index is parallelized over 16 TPUs. The authors also state that the entire model can be run on a single machine with a 12GB GPU. Under the project rule, the TPU generation must not be inferred beyond what the source states, so "64 Google Cloud TPUs" remains the paper-stated device description. On the local accelerator-era map, the run falls in the 2018-2020 Google Cloud TPU Transformer period.

The retrieval corpus is large enough to be a system component. The English Wikipedia snapshot is split into chunks of up to 288 BERT wordpieces, yielding just over 13 million retrieval candidates. During pretraining, each example retrieves and marginalizes over 8 candidate documents, including a null document; during fine-tuning inference, the system considers the top 5 candidates.
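
The chunking step can be illustrated with a minimal sketch. It assumes a plain fixed-length greedy split over an already tokenized document; the function name `split_into_chunks` is hypothetical, and whether the paper's splitter also respects sentence or passage boundaries is not stated in this card.

```python
def split_into_chunks(tokens, max_len=288):
    """Greedily split a token sequence into consecutive chunks of at most max_len.

    Assumption: a simple fixed-window split; the 288 default mirrors the
    wordpiece limit quoted above, everything else is illustrative.
    """
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    return [tokens[i:i + max_len] for i in range(0, len(tokens), max_len)]


if __name__ == "__main__":
    # Stand-in for BERT wordpieces: 600 dummy tokens.
    document = [f"tok{i}" for i in range(600)]
    chunks = split_into_chunks(document)
    print([len(c) for c in chunks])
```

Each resulting chunk would become one retrieval candidate, which is how a single Wikipedia snapshot expands into millions of index entries.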

Bottleneck

The bottleneck is the mismatch between gradient learning and large external memory. REALM wants a neural retriever whose document scores improve during masked-language-model pretraining, but exact marginalization over millions of documents is intractable at every training step. Maximum Inner Product Search (MIPS) makes retrieval sublinear in the number of documents, but it requires precomputed document embeddings and an index. As soon as the retriever parameters change, those cached document embeddings become stale.
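
The marginalization behind this cost can be written out explicitly. The block below uses the retrieve-then-predict notation as recalled from the paper (x the input, y the prediction, z a document from the corpus Z, f the relevance score); treat the symbol names as assumptions rather than a quotation.

```latex
p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid z, x)\, p(z \mid x),
\qquad
p(z \mid x) = \frac{\exp f(x, z)}{\sum_{z' \in \mathcal{Z}} \exp f(x, z')},
\qquad
f(x, z) = \mathrm{Embed}_{\mathrm{input}}(x)^{\top}\, \mathrm{Embed}_{\mathrm{doc}}(z)
```

Both the outer sum and the softmax normalizer range over every chunk in the corpus, which is why the exact objective cannot be evaluated per step, and why restricting f to an inner product matters: it is the form that a MIPS index can search.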

This is a memory-bandwidth and systems problem as much as an NLP problem. Pure parametric models store facts in weights, making updates expensive and opaque. Retrieval-augmented models expose knowledge in a corpus, but move cost into index construction, embedding refresh, MIPS lookup, and cross-attention over retrieved text.

Method Adaptation

REALM adapts to the hardware regime by decoupling retrieval search from full neural training. The retriever scores each document with an inner product between an input embedding and a document embedding. Because the document side can be cached, MIPS can find approximate top-k candidates without scanning every chunk of Wikipedia on every batch item. After candidates are retrieved, REALM recomputes probabilities and gradients for those top documents using the current parameters, reducing the damage from stale cached embeddings.
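
The retrieve, rescore, and marginalize steps described above can be sketched in a few lines. This is a toy illustration under stated assumptions, not the authors' implementation: the search is brute force rather than a real sublinear MIPS index, the null document is omitted, and `top_k_by_inner_product`, `marginal_likelihood`, `fresh_embed`, and `reader_prob` are hypothetical names.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_by_inner_product(query, doc_embeddings, k):
    """Brute-force stand-in for MIPS: indices of the k largest inner products."""
    scores = [dot(query, d) for d in doc_embeddings]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def marginal_likelihood(query, cached_docs, fresh_embed, reader_prob, k):
    """Approximate p(y|x) by summing over top-k documents only.

    cached_docs: possibly stale document embeddings used for candidate search.
    fresh_embed(i): embedding of document i under the current parameters.
    reader_prob(i): p(y | z_i, x) from the reader (hypothetical callback).
    """
    candidates = top_k_by_inner_product(query, cached_docs, k)
    # Rescore candidates with current parameters so p(z|x) is not stale.
    fresh_scores = [dot(query, fresh_embed(i)) for i in candidates]
    p_z = softmax(fresh_scores)
    p_y = sum(p * reader_prob(i) for p, i in zip(p_z, candidates))
    return p_y, candidates
```

The stale cache only decides which candidates are considered; the probabilities that carry gradient come from the rescoring step, which is why moderately stale embeddings are tolerable while badly stale ones are not.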

The key engineering move is asynchronous MIPS refresh. The authors run a primary trainer job for gradient updates and a secondary index-builder job that embeds and indexes documents. The paper says this yields about one index refresh per 500 training steps, a cadence that keeps retrieval aligned without stalling the 64-TPU pretraining job.
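
The refresh cadence can be made concrete with a small bookkeeping sketch. It assumes an idealized schedule in which a new index is swapped in instantly every `refresh_every` steps; the real index builder runs asynchronously and takes time to embed the corpus, so actual staleness would be somewhat higher. The function name `simulate_refresh` is hypothetical.

```python
def simulate_refresh(total_steps, refresh_every):
    """Count index refreshes and the worst-case staleness for a fixed cadence.

    Staleness at a step is the number of trainer updates applied since the
    parameters that built the currently served index were snapshotted.
    Assumption: the rebuilt index becomes available instantly.
    """
    if refresh_every <= 0:
        raise ValueError("refresh_every must be positive")
    snapshot_step = 0
    refreshes = 0
    max_staleness = 0
    for step in range(1, total_steps + 1):
        max_staleness = max(max_staleness, step - snapshot_step)
        if step % refresh_every == 0:
            snapshot_step = step  # swap in an index built from current params
            refreshes += 1
    return refreshes, max_staleness


if __name__ == "__main__":
    # 200K pretraining steps with about one refresh per 500 steps.
    print(simulate_refresh(200_000, 500))
```

Under these assumptions, the reported cadence corresponds to roughly 400 index rebuilds over the 200K-step run, with the served index never more than 500 updates behind the trainer.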

The model itself is kept small enough for downstream use. The comparison tables list REALM at about 330M parameters, and the paper emphasizes that it can run on one 12GB GPU. This is the opposite of scaling all factual knowledge into a giant dense model. The large object is the corpus and its index, not the neural network's parameters.

Evidence

The main Open-QA table reports REALM outperforming prior systems by 4-16 absolute points across NaturalQuestions-Open, WebQuestions, and CuratedTREC. With Wikipedia as both pretraining and knowledge corpus, REALM reports 39.2, 40.2, and 46.8 on the three benchmarks; with CC-News as pretraining corpus and Wikipedia as knowledge corpus, it reports 40.4, 40.7, and 42.9. The comparison row for ORQA is 33.3, 36.4, and 30.1 with the same 330M parameter scale.

The ablation table is the strongest compute-structure evidence. On NaturalQuestions-Open, REALM reaches 38.2 exact match and 38.5 recall@5. Swapping in the baseline retriever while keeping the REALM encoder drops recall@5 to 13.9 and exact match to 35.3. The ORQA baseline is at 31.3 exact match and 13.9 recall@5. A 30x stale MIPS setting collapses to 28.7 exact match and 15.1 recall@5, below the baseline's exact match. This directly shows that the index refresh schedule is not an implementation detail; stale memory breaks the method.

The paper also compares against implicit-knowledge scaling. It notes that moving T5 from Base to 11B makes the model about 50 times larger for roughly 5 points of accuracy, and that REALM outperforms T5-11B despite being about 30 times smaller. The claim is that some factual knowledge is cheaper to retrieve than to bake into dense parameters.

Historical Effect

REALM is an early example of dense-retrieval pretraining for retrieval-augmented language models. It showed that external memory can be learned with the language-model objective itself, not merely bolted on with a frozen BM25 or heuristic retriever. The historical compute effect is that "model size" becomes only one component of a larger serving and training system: corpus size, embedding cache, MIPS index, refresh frequency, and top-k marginalization are all first-class compute variables.

It also anticipates later RAG systems in which factual freshness and provenance are handled outside dense weights. The paper is especially useful because its stale-index ablation quantifies the systems failure mode: stale retrieval erases the gain.

Limits

REALM is engineering-heavy. It needs a large text corpus, dense document embeddings, an approximate search index, a secondary index-builder job, and tuning of refresh frequency. The paper does not name the exact TPU generation, and the hardware accounting is less complete than later frontier-model reports. Retrieval also adds inference-time latency and failure modes: if the answer is not in the top candidates, the reader cannot recover it; if the corpus is outdated or noisy, the model may retrieve plausible but wrong support. The method is strongest when factual knowledge can be represented in retrievable text chunks.

Links

  • Compute regime: history/compute_regimes/inference_time_compute_post_training/README.md
  • Source PDF and extracted text are listed in metadata above.
  • Queue status: read_complete.
  • Method index: rag
  • Ledger updates: compute bottlenecks