Distilling the Knowledge in a Neural Network

下载 PDF

Distilling the Knowledge in a Neural Network - 中文验证版

英文原文卡片:distillation_2015.md

状态:已翻译。

元数据

计算设置

论文明确将 Android Voice Search 作为一个部署目标,但未命名确切的训练硬件。按项目规则,研究期间的设置推断为 Google 的分布式数据中心 CPU 时代/早期加速器基础设施,而产品约束是移动端或对延迟敏感的服务。论文明确说明 Android 声学模型是一个大规模的生产型 DNN,而非玩具 MNIST 模型。

对于 ASR 实验,基线声学模型有 8 个隐藏层,每层 2560 个 ReLU 单元,最终 softmax 包含 14,000 个 HMM 标签,约 85M 参数。输入为 26 帧 40 个 Mel 尺度滤波器组系数,帧移为 10 ms,预测第 21 帧的 HMM 状态。训练使用约 2000 小时口语英语,产生约 700M 训练样本。论文称声学模型使用分布式随机梯度下降训练,但未列出设备类型、核心数或加速器。

对于 JFT,论文对分布式结构更加明确。JFT 有 100M 标记图像和 15,000 个标签。Google 的基线模型在大量核心上使用异步 SGD 训练了约六个月,许多副本处理不同的 mini-batch,梯度发送到分片参数服务器,每个副本通过将不同神经元子集放在每个核心上分布到多个核心。这是数据并行加模型并行神经元分片加参数服务器同步。

瓶颈

大型集成提高准确率,但对延迟敏感的生产服务来说过于昂贵。核心计算问题是在部署单一紧凑模型的同时保持繁琐教师的准确率。在 ASR 情况下,一个 10 模型集成会在本就庞大的 85M 参数声学网络上倍增推理工作量。这对 Android Voice Search 式服务来说并不吸引人,因为延迟和吞吐量至关重要。

JFT 的瓶颈不同但相关。庞大图像模型的完整集成并不可行,因为基线全模型本身就已训练了约六个月。训练多个专家可以并行化,但对每张图像运行每个专家在推理时过于昂贵。因此论文将训练时并行性与服务时成本分开:在并行数据中心计算上花费教师和专家的计算资源,然后尝试迁移或有选择地使用这些知识。

方法适配

蒸馏通过训练一个学生来匹配软化的教师概率,将模型压缩适配到生产计算。提高 softmax 温度产生更软的目标分布,暴露出硬标签隐藏的类别相似性。蒸馏时学生使用相同的高温来匹配教师分布,训练后以温度 1 运行。这将计算从推理转移到训练:用繁琐模型生成教师概率,然后部署更便宜的学生。

ASR 实验使用一个 10 模型教师集成,架构与基线模型相同。蒸馏尝试温度 1、2、5 和 10,并对硬标签交叉熵使用 0.5 的相对权重。这是一个务实的训练目标:在迁移集成暗知识的同时保留真实标签。

专家部分适配巨大的标签空间和稀缺的全集成计算。每个专家覆盖一个易混淆的类别子集加上一个杂项类;JFT 实验训练了 61 个专家,每个 300 个类加上杂项类。专家从训练好的基线网络出发,在几天而非数周内独立训练,并可在推理时使用通才模型的预测来选择。这是 embarrassingly parallel 训练加上在小规模活跃专家集上的条件推理。

证据

对于 Android 声学建模,基线达到 58.9% 测试帧准确率和 10.9% WER。10 模型集成达到 61.1% 帧准确率和 10.7% WER。蒸馏后的单模型达到 60.8% 帧准确率和 10.7% WER。论文指出集成帧准确率提升的 80% 以上迁移到了单一蒸馏模型,WER 的提升同样迁移。从计算角度看,这是核心结果:近乎集成级的服务质量,而无需在推理时运行 10 个网络。

在 JFT 上,来源报告 100M 标记图像和 15,000 个标签,基线全网络训练约六个月。从该基线出发,61 个专家在几天而非数周内完成训练。通才-专家组合系统总体提升测试准确率 4.4% 相对值。表格还显示更多专家覆盖正确类别时收益更大;例如被 9 个专家覆盖的样本获得 16.6% 相对 top-1 准确率提升,而 10 个或更多覆盖时提升 14.1%。

证据支持两种计算模式:将大型或集成教师蒸馏为一个服务模型,以及使用独立专家利用并行训练而不让每条推理路径承担全部集成成本。

历史影响

这成为标杆式的昂贵训练/紧凑部署模式:在数据中心计算上花费教师成本,然后将行为压缩到符合生产延迟和内存预算的模型中。它还使输出分布包含超越标签的有用结构这一观念常态化,使教师生成的数据本身成为一种计算产物。

对于后来的高效推理工作,历史影响是直接的。蒸馏成为大型集成或基础模型与较小部署模型之间的标准桥梁。论文的 ASR 示例尤为重要,因为它围绕一个真实的、服务级的声学模型展开,而非仅围绕小基准。

局限

论文不是内核或加速器的重新设计,硬件描述不足。ASR 部分提到分布式 SGD,但未说明设备或规模。JFT 部分提到大量核心和参数服务器,但未说明硬件类型。学生只能继承教师暴露的内容,因此教师的质量和校准很重要。最后,论文未展示将 61 个 JFT 专家蒸馏回一个模型;它展示了专家集成收益并论证了独立专家训练易于并行化。

链接