Scalable Diffusion Models with Transformers
Metadata
- Slug: dit_2022
- Year: 2022
- Venue: arXiv
- Authors: William Peebles, Saining Xie
- Reading status: read complete
- Compute regime: Generative media compute (generative_media_compute)
- Primary sources: PDF, extracted text
- Reading card created: 2026-06-15
Compute Setup
The paper explicitly names the hardware and software stack: models are implemented in JAX and trained on TPU-v3 pods. Its concrete device benchmark is for the largest 256x256 model: DiT-XL/2 trains at roughly 5.7 iterations per second on a TPU v3-256 pod with global batch size 256. The training recipe uses AdamW with a constant learning rate of 1e-4 and no weight decay, horizontal-flip augmentation, and an EMA decay of 0.9999. Because every model size uses the same hyperparameters, differences in results reflect compute allocation rather than per-model tuning.
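As a rough reading of that benchmark, 5.7 iterations per second at global batch 256 is about 1,460 training images per second. The sketch below turns the rate into idealized wall-clock time for the 7M-step run, assuming (our assumption, not a figure reported in the paper) that the benchmarked rate holds for the entire run with no time lost to evaluation, checkpointing, or restarts. The helper names are hypothetical.

```python
# Back-of-envelope arithmetic from the reported DiT-XL/2 throughput benchmark.
# Assumption: the 5.7 it/s rate on a TPU v3-256 pod holds for the whole run,
# ignoring evaluation, checkpointing, and restarts.
ITERS_PER_SEC = 5.7
GLOBAL_BATCH = 256
SECONDS_PER_DAY = 86_400


def images_per_second(iters_per_sec: float, batch: int) -> float:
    """Training images processed per second at a given iteration rate."""
    return iters_per_sec * batch


def idealized_days(steps: int, iters_per_sec: float) -> float:
    """Wall-clock days for `steps` iterations at a constant rate."""
    return steps / iters_per_sec / SECONDS_PER_DAY


print(f"{images_per_second(ITERS_PER_SEC, GLOBAL_BATCH):.1f} images/s")
print(f"{idealized_days(7_000_000, ITERS_PER_SEC):.1f} days for 7M steps")
```

Real runs take longer than this lower bound, so treat the result only as a scale check on the "millions of steps" cost.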
The model runs in latent space rather than pixel space. A Stable-Diffusion-style VAE compresses 256x256 images into 32x32x4 latents and 512x512 images into 64x64x4 latents. The reported DiT parameter and FLOP counts exclude the 84M-parameter VAE. At 256x256, DiT-XL/2 has 675M parameters and 118.64G forward FLOPs; at 512x512 the same XL/2 configuration processes 1024 latent tokens and uses 524.60G forward FLOPs. The main long runs are 7M steps at 256x256 and 3M steps at 512x512, both with batch size 256.
Bottleneck
DiT frames diffusion image quality as a compute-allocation problem. A diffusion sample requires many model evaluations, so per-sample cost is roughly the denoiser's forward FLOPs times the number of sampling steps. Transformer denoisers add token count as a second lever. Reducing the latent patch size from p=4 to p=2 quadruples sequence length and at least quadruples transformer FLOPs while leaving parameter count nearly unchanged. That makes FLOPs, not just parameters, the bottleneck and the scaling variable.
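A minimal sketch of the token-count lever, assuming the paper's patchify rule T = (I/p)^2 for an I x I latent and patch size p; the function name is illustrative. Halving p multiplies T by 4, so per-token layers (projections and MLPs) cost 4x as much while the attention score computation, which grows with T^2, costs 16x. That mix is why transformer FLOPs at least quadruple.

```python
def num_tokens(latent_side: int, patch: int) -> int:
    """Sequence length after cutting a latent_side x latent_side latent into p x p patches."""
    if latent_side % patch:
        raise ValueError("patch size must divide the latent side length")
    return (latent_side // patch) ** 2


# 256x256 images -> 32x32x4 latents; 512x512 images -> 64x64x4 latents.
for side in (32, 64):
    for p in (8, 4, 2):
        print(f"latent {side}x{side}, p={p}: T={num_tokens(side, p)}")
```

With p=2, the 32x32 latent gives 256 tokens and the 64x64 latent gives the 1024 tokens quoted for the 512x512 model.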
The paper also distinguishes training compute from sampling compute. Training compute is approximated as model Gflops times batch size times training steps times 3, with the factor of 3 treating the backward pass as about twice the forward compute. Sampling compute can be increased by using more denoising steps, but the paper tests whether that compensates for a smaller backbone and finds that it does not. Thus the remedy is not simply to run more sampling steps; the model-side Gflops spent in each denoising step are central.
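The training-compute approximation can be written out directly. In this hedged sketch the function name is hypothetical, "Gflops" means the DiT-only forward cost as reported in the paper (VAE excluded), and the factor of 3 is the paper's forward-plus-backward approximation. Plugging in DiT-XL/2 at 256x256 (118.64 Gflops, batch 256, 400K steps) gives about 3.6e19 FLOPs.

```python
def training_compute_flops(gflops_per_forward: float, batch: int, steps: int) -> float:
    """Approximate training FLOPs: forward Gflops * batch * steps * 3.

    The factor of 3 counts one forward pass plus a backward pass taken
    to cost about twice the forward pass.
    """
    return gflops_per_forward * 1e9 * batch * steps * 3


XL2_GFLOPS = 118.64  # DiT-XL/2 forward Gflops at 256x256, VAE excluded
print(f"{training_compute_flops(XL2_GFLOPS, 256, 400_000):.3e} FLOPs at 400K steps")
print(f"{training_compute_flops(XL2_GFLOPS, 256, 7_000_000):.3e} FLOPs at 7M steps")
```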
Method Adaptation
DiT adapts the diffusion backbone to accelerator-friendly Transformer scaling. It first keeps the latent-diffusion compression trick: operate on VAE latents because training directly in high-resolution pixel space is computationally prohibitive. It then replaces the U-Net denoiser with a ViT-like sequence model. The patchify layer turns the latent grid into tokens; smaller patches buy quality by spending more attention/MLP compute per denoising pass. This is a clean compute dial because changing patch size strongly changes FLOPs while having little effect on parameter count.
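A minimal NumPy sketch of the patchify step, assuming a channels-last (H, W, C) latent and row-major patch ordering. The paper's layer also applies a linear embedding to the hidden size and adds positional embeddings; both are omitted here, and the function name is illustrative rather than taken from the DiT codebase. On a 32x32x4 latent, p=8, 4, and 2 give 16, 64, and 256 tokens, while each token's raw width shrinks from 256 to 64 to 16 values. Only the small embedding layer that consumes that width changes size, which is why parameter count barely moves.

```python
import numpy as np


def patchify(latent: np.ndarray, p: int) -> np.ndarray:
    """Split an (H, W, C) latent into non-overlapping p x p patches.

    Returns a (T, p*p*C) array with T = (H/p) * (W/p), patches in row-major order.
    """
    h, w, c = latent.shape
    if h % p or w % p:
        raise ValueError("patch size must divide the latent height and width")
    x = latent.reshape(h // p, p, w // p, p, c)  # (grid_h, p, grid_w, p, C)
    x = x.transpose(0, 2, 1, 3, 4)               # (grid_h, grid_w, p, p, C)
    return x.reshape((h // p) * (w // p), p * p * c)


latent = np.random.default_rng(0).standard_normal((32, 32, 4))
for p in (8, 4, 2):
    print(p, patchify(latent, p).shape)
```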
Conditioning is designed to preserve dense Transformer throughput. The paper tests in-context conditioning, cross-attention, adaptive layer norm (adaLN), and adaLN-Zero. Cross-attention adds the most FLOPs, roughly a 15% overhead. Plain adaLN adds the least and is described as the most compute-efficient. The winning choice, adaLN-Zero, regresses scale/shift and residual-scaling parameters from the sum of the timestep and class embeddings, initializes the layer that predicts the residual scales to output zeros so that each DiT block starts as the identity function, and adds negligible FLOPs. This turns conditioning into a normalization/modulation problem rather than a separate attention problem, which keeps the main kernel mix close to standard Transformer blocks.
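A hedged NumPy sketch of one adaLN-Zero residual branch. Assumptions not taken from the paper: the modulation is a single linear layer on the conditioning vector, the scale is applied as (1 + gamma) so a zero output leaves LayerNorm unchanged, and `sublayer` is a toy stand-in for attention or the MLP. A full DiT block would apply this twice, once around self-attention and once around the MLP. The property it demonstrates is the one the paper relies on: when the layer predicting alpha outputs zeros, the branch reduces to the identity.

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """LayerNorm over the last axis with no learned affine; adaLN supplies scale and shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def adaln_zero_residual(x, cond, w_mod, b_mod, sublayer):
    """One residual branch: x + alpha * sublayer(LN(x) * (1 + gamma) + beta).

    x: (T, d) tokens. cond: (d_c,) sum of timestep and class embeddings.
    w_mod, b_mod: hypothetical linear layer mapping cond to (gamma, beta, alpha), each (d,).
    """
    gamma, beta, alpha = np.split(cond @ w_mod + b_mod, 3)
    h = layer_norm(x) * (1.0 + gamma) + beta
    return x + alpha * sublayer(h)


rng = np.random.default_rng(0)
T, d, d_c = 16, 8, 8
x = rng.standard_normal((T, d))
cond = rng.standard_normal(d_c)
w_mlp = rng.standard_normal((d, d))


def mlp(h: np.ndarray) -> np.ndarray:
    """Toy stand-in for the block's attention or pointwise MLP sublayer."""
    return np.tanh(h @ w_mlp)


# adaLN-Zero initialization: the modulation layer starts at zero, so alpha = 0.
w_zero, b_zero = np.zeros((d_c, 3 * d)), np.zeros(3 * d)
out = adaln_zero_residual(x, cond, w_zero, b_zero, mlp)
print(np.allclose(out, x))  # True: the residual branch starts as the identity
```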
The method also fixes the batch and evaluation setup. All models use batch size 256 and identical training settings, and FID is computed with ADM's TensorFlow evaluation suite. Benchmark comparisons report FID-50K with 250 DDPM sampling steps, except in the sampling-step study.
Evidence
The scaling tables make the compute story direct. At 400K steps on ImageNet 256x256, DiT-S/8 uses 0.36G FLOPs and has no-guidance FID 153.60, while DiT-XL/2 uses 118.64G FLOPs and reaches FID 19.47. Holding XL parameters roughly fixed, token count improves FID: XL/8 has 676M parameters, 7.39G FLOPs, and FID 106.41; XL/4 has 675M, 29.05G, and FID 43.01; XL/2 has 675M, 118.64G, and FID 19.47. Image-token compute, not parameter count alone, drives quality.
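The XL numbers above can be checked against the token-count rule with a few lines of arithmetic. This sketch only re-tabulates the reported values and assumes a 32x32 latent for the token counts. Each halving of p multiplies tokens by 4 and reported Gflops by roughly 4 (about 3.93x from p=8 to p=4 and 4.08x from p=4 to p=2), while parameters stay at 675 to 676M. The first ratio landing slightly under 4 is presumably due to per-image costs, such as the conditioning layers, that do not grow with token count; that explanation is our inference, not a claim from the paper.

```python
# Reported DiT-XL results at 400K steps on ImageNet 256x256 (VAE excluded).
# patch size p -> (parameters in M, forward Gflops, FID-50K without guidance)
xl = {
    8: (676, 7.39, 106.41),
    4: (675, 29.05, 43.01),
    2: (675, 118.64, 19.47),
}
LATENT_SIDE = 32  # 256x256 images map to 32x32x4 latents


def tokens(p: int) -> int:
    """Sequence length T = (I/p)^2 for an I x I latent."""
    return (LATENT_SIDE // p) ** 2


for coarse, fine in ((8, 4), (4, 2)):
    token_ratio = tokens(fine) / tokens(coarse)
    flop_ratio = xl[fine][1] / xl[coarse][1]
    print(
        f"p={coarse}->{fine}: tokens x{token_ratio:.0f}, Gflops x{flop_ratio:.2f}, "
        f"params {xl[coarse][0]}M->{xl[fine][0]}M, FID {xl[coarse][2]}->{xl[fine][2]}"
    )
```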
The block-design comparison also supports the method choice. Four high-GFLOP XL/2 variants are trained for 400K steps. In-context conditioning has 449M parameters, 119.37G FLOPs, and FID 35.24; cross-attention has 598M parameters, 137.62G FLOPs, and FID 26.14; vanilla adaLN has 600M parameters, 118.56G FLOPs, and FID 25.21; adaLN-Zero has 675M parameters, 118.64G FLOPs, and FID 19.47. The best FID therefore comes not from the highest-FLOP cross-attention block but from adaLN-Zero, whose initialization and modulation structure lets the dense backbone train better at essentially the same FLOPs as vanilla adaLN.
After extended training, no-guidance DiT-XL/2 improves from FID 19.47 at 400K steps to 10.67 at 2.352M and 9.62 at 7M. With classifier-free guidance on ImageNet 256x256, DiT-XL/2-G reaches FID 2.27, sFID 4.60, Inception Score 278.24, precision 0.83, and recall 0.57, improving over the listed prior LDM-4-G FID 3.60. At 512x512, the 3M-step XL/2 model uses 524.6G forward FLOPs and reaches guided FID 3.04, versus FID 3.85 for the prior ADM-G/ADM-U.
The sampling-compute experiment is especially compute-focused. DiT-L/2 with 1000 sampling steps uses 80.7 Tflops per image, while DiT-XL/2 with 128 steps uses 15.2 Tflops per image and still has better FID-10K, 23.7 versus 25.9. The paper concludes that increasing sampling compute cannot compensate for insufficient model compute.
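The per-image sampling costs quoted above follow from forward Gflops times sampling steps. A small sketch, assuming one denoiser forward pass per step, no classifier-free guidance, and no VAE decode; DiT-L/2's 80.7 Gflops per pass is inferred from the quoted 80.7 Tflops over 1000 steps. The helper name is hypothetical.

```python
def sampling_tflops_per_image(gflops_per_forward: float, steps: int) -> float:
    """Denoiser-only sampling cost in Tflops: one forward pass per step.

    Assumes no classifier-free guidance and excludes the VAE decoder.
    """
    return gflops_per_forward * steps / 1_000


l2 = sampling_tflops_per_image(80.7, 1000)    # DiT-L/2 with 1000 steps
xl2 = sampling_tflops_per_image(118.64, 128)  # DiT-XL/2 with 128 steps
print(f"L/2: {l2:.1f} Tflops, XL/2: {xl2:.1f} Tflops, L/2 costs {l2 / xl2:.1f}x more")
```

The roughly 5x gap in sampling compute favors L/2, yet the larger backbone still wins on FID, which is the paper's point about model compute versus sampling compute.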
Historical Effect
DiT moved diffusion backbones onto the same scaling surface as language and vision Transformers. The paper's historical effect is not only replacing a U-Net, but showing that latent token count, model width/depth, and forward-pass GFLOPs form an orderly design space for image generation. It gave later text-to-image systems a familiar compute language: scale dense Transformer blocks, measure FLOPs, train longer, and use modulation rather than separate cross-attention when conditioning is simple.
Limits
- The strongest training run requires a TPU v3-256 pod and millions of steps.
- Reported DiT FLOPs and parameters exclude the external 84M-parameter VAE, even though the VAE is required for the end-to-end image pipeline.
- Best benchmark FIDs use classifier-free guidance, which increases sampling-time work.
- The experiments are ImageNet class-conditional; the paper does not prove the same compute curves for all text-conditioned or domain-specific settings.
Links
- Compute regime: generative media compute
- Method index: diffusion, transformer, generative_models
- Ledger updates: compute bottlenecks