Scalable Diffusion Models with Transformers
Scalable Diffusion Models with Transformers - 中文验证版
英文原始依据卡片:dit_2022.md
状态:已翻译。
元数据
- Slug:
dit_2022 - 年份: 2022
- 会议: arXiv
- 作者: William Peebles、Saining Xie
- 阅读状态: read complete
- 计算范式: 生成式媒体计算 (
generative_media_compute) - 主要来源: PDF、抽取文本
- 阅读卡创建日期: 2026-06-15
计算设置
论文明确列出了硬件和软件栈:模型在 JAX 中实现并在 TPU-v3 pod 上训练。其具体设备基准是对于最大的 256x256 模型:DiT-XL/2 在 TPU v3-256 pod 上以全局 batch size 256 大约以 5.7 次迭代/秒的速度训练。该方案使用 AdamW、恒定学习率 1e-4、无权重衰减、水平翻转和 EMA 0.9999。跨规模使用相同的超参数有助于分离计算分配而非调优。
模型在潜在空间而非像素空间中运行。一个 Stable-Diffusion 风格的 VAE 将 256x256 图像压缩为 32x32x4 的 latent,将 512x512 图像压缩为 64x64x4 的 latent。报告的 DiT 参数和 FLOP 计数排除了 84M 参数的 VAE。在 256x256 下,DiT-XL/2 有 675M 参数和 118.64G 前向 FLOP;在 512x512 下,相同的 XL/2 配置处理 1024 个 latent token 并使用 524.60G 前向 FLOP。主要长时间运行是 256x256 下的 7M 步和 512x512 下的 3M 步,均使用 batch size 256。
瓶颈
DiT 将扩散图像质量框架化为计算分配问题。一个扩散样本需要多次模型评估,因此成本为去噪器前向 FLOP 乘以采样步数。Transformer 去噪器将 token 计数作为第二个杠杆。将 latent patch 大小从 p=4 减小到 p=2 四倍化了序列长度,并至少四倍化了 transformer FLOP,同时参数计数几乎不变。这使得 FLOP(而不仅仅是参数)成为瓶颈和缩放变量。
论文还区分了训练计算和采样计算。训练计算近似为模型 Gflops 乘以 batch size 乘以训练步数乘以 3,其中因子 3 将反向传播视为大约两倍的前向计算。采样计算可以通过使用更多去噪步骤来增加,但论文测试了这是否能补偿较小的骨干网络,发现不能。因此瓶颈不仅仅是"运行更多采样步数";每次去噪步骤中的模型端 Gflops 是核心。
方法适配
DiT 将扩散骨干网络适配到加速器友好的 Transformer 缩放。它首先保留 latent 扩散压缩技巧:在 VAE latent 上操作,因为直接在高分辨率像素空间中训练在计算上是不可行的。然后它将 U-Net 去噪器替换为类似 ViT 的序列模型。patchify 层将潜在网格转换为 token;较小的 patch 通过每个去噪遍花费更多注意力/MLP 计算来换取质量。这是一个清晰的计算旋钮,因为改变 patch 大小强烈改变 FLOP,但对参数计数几乎没有影响。
条件化被设计为保持密集 Transformer 吞吐量。论文测试了上下文条件化、交叉注意力、自适应层归一化和 adaLN-Zero。交叉注意力增加了最多的 FLOP,大约 15% 的开销。纯 adaLN 增加了最少的开销,是计算效率最高的。胜出的选择 adaLN-Zero,从时间步和类别嵌入回归 scale/shift 以及残差缩放参数,将残差路径初始化为零,并增加了可忽略不计的 FLOP。这使得条件化成为归一化/调制问题而非单独的注意力问题,从而使主要 kernel 组合接近标准 Transformer 块。
该方法还固定了 batch 和评估结构。所有模型使用 batch size 256 和通用设置,FID 使用 ADM 的 TensorFlow 评估套件计算。基准比较使用 FID-50K 和 250 步 DDPM 采样,除采样步数研究外。
证据
缩放表格直接给出了计算故事。在 ImageNet 256x256 上训练 400K 步,DiT-S/8 使用 0.36G FLOP,无引导 FID 为 153.60,而 DiT-XL/2 使用 118.64G FLOP,达到 FID 19.47。在保持 XL 参数大致固定的情况下,token 计数改善 FID:XL/8 有 676M 参数、7.39G FLOP 和 FID 106.41;XL/4 有 675M 参数、29.05G FLOP 和 FID 43.01;XL/2 有 675M 参数、118.64G FLOP 和 FID 19.47。图像 token 计算,而非参数计数本身,驱动质量。
块设计比较也支持方法选择。四个高 GFLOP 的 XL/2 变体训练了 400K 步。上下文条件化有 449M 参数、119.37G FLOP 和 FID 35.24;交叉注意力有 598M 参数、137.62G FLOP 和 FID 26.14;普通 adaLN 有 600M 参数、118.56G FLOP 和 FID 25.21;adaLN-Zero 有 675M 参数、118.64G FLOP 和 FID 19.47。因此最佳结果不是最高 FLOP 的交叉注意力块,而是更好地训练密集骨干网络的初始化和调制结构。
经过扩展训练后,无引导 DiT-XL/2 在 400K 步的 FID 从 19.47 改善到 2.352M 步的 10.67 和 7M 步的 9.62。在 ImageNet 256x256 上使用分类器无关引导,DiT-XL/2-G 达到 FID 2.27、sFID 4.60、Inception Score 278.24、precision 0.83 和 recall 0.57,优于列出的先前 LDM-4-G FID 3.60。在 512x512 下,3M 步模型使用 524.6G FLOP 达到引导 FID 3.04,而先前 ADM-G/ADM-U 的 FID 为 3.85。
采样计算实验尤其以计算为重点。DiT-L/2 使用 1000 步采样的每张图像消耗 80.7 Tflops,而 DiT-XL/2 使用 128 步消耗 15.2 Tflops,且仍然具有更好的 FID-10K,23.7 对 25.9。论文的结论是,增加采样计算不能补偿不足的模型计算。
历史影响
DiT 将扩散骨干网络搬到了与语言和视觉 Transformer 相同的缩放面上。本文的历史影响不仅仅是用 Transformer 替换 U-Net,而是展示了 latent token 计数、模型宽度/深度和前向传递 GFLOP 构成图像生成的有序设计空间。它赋予后来的 text-to-image 系统一种熟悉的计算语言:缩放密集 Transformer 块、测量 FLOP、训练更长时间,并在条件化简单时使用调制而非独立的交叉注意力。
局限
- 最强的训练运行需要 TPU v3-256 pod 和数百万步。
- 报告的 DiT FLOP 和参数排除了外部 84M 参数 VAE,即使 VAE 是端到端图像流水线所必需的。
- 最佳基准 FID 使用分类器无关引导,增加了采样时的工作量。
- 实验是 ImageNet 类别条件化的;论文并未证明相同的计算曲线适用于所有文本条件化或特定领域的设置。