Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour with Batch Normalization

2018 Multi-GPU dense training 53 citations

Metadata

  • Slug: sync_batchnorm_2018
  • Year: 2018
  • Venue: queue duplicate/source coverage
  • Authors: Priya Goyal et al.
  • Reading status: read complete, duplicate source
  • Compute regime: Multi-GPU dense training
  • Primary sources: PDF, extracted text

Compute Setup

This card is a duplicate/source-coverage row for the one-hour ImageNet paper, not a separate SyncBatchNorm paper. The source explicitly describes the hardware: Facebook Big Basin servers, each with 8 NVIDIA Tesla P100 GPUs connected by NVIDIA NVLink, 3.2TB NVMe SSD local storage, a Mellanox ConnectX-4 50Gbit Ethernet network card, and Wedge100 Ethernet switches. The main result scales to 256 GPUs, i.e. 32 Big Basin servers, and the paper also reports a 352-GPU timing point.

The training workload is ResNet-50 on ImageNet: about 1.28M training images, 50,000 validation images, 90 epochs, and Nesterov momentum. The baseline uses 8 GPUs in one server, 32 images/GPU, total minibatch 256, and reference learning rate 0.1. The large run keeps 32 images/GPU and increases worker count to 256, giving total minibatch 8192 and reference learning rate 3.2.

Bottleneck

The bottleneck is distributed synchronous SGD at very large minibatch size. To scale data parallelism, each worker needs enough local computation to hide communication. The per-worker batch therefore stays fixed, and the total minibatch grows with the number of workers, which can change optimization behavior. The paper argues that for ImageNet minibatches of up to 8k images, the main issue is early optimization difficulty rather than an inherent loss of generalization, provided the implementation preserves the right training objective.

Batch normalization is central. BN statistics depend on the per-worker minibatch size n, so changing n changes the loss being optimized. The paper therefore keeps n = 32 and scales by adding workers. It explicitly says BN statistics should not be computed across all workers, not only to reduce communication but also to keep the same underlying loss function. This is the opposite of global SyncBatchNorm: local BN is part of the method's correctness argument.
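
To make the BN-locality argument concrete, here is a minimal sketch in plain Python. It compares per-worker normalization over n = 32 samples with normalization over the full 256-sample minibatch. It assumes a single scalar feature, no learned scale or shift, i.i.d. toy activations, and a hypothetical epsilon of 1e-5. The function names are illustrative and do not come from the paper or from any framework.

```python
import random

EPS = 1e-5  # hypothetical numerical-stability constant

def batch_norm(xs):
    """Normalize scalars with the batch's own mean and biased variance (no scale/shift)."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / (var + EPS) ** 0.5 for x in xs]

def local_bn(batch, n):
    """Per-worker BN: each worker normalizes only its own n samples."""
    out = []
    for start in range(0, len(batch), n):
        out.extend(batch_norm(batch[start:start + n]))
    return out

def global_bn(batch):
    """SyncBatchNorm-style BN: statistics pooled over the whole minibatch."""
    return batch_norm(batch)

random.seed(0)
k, n = 8, 32  # 8 workers x 32 images = minibatch 256, as in the baseline
batch = [random.gauss(0.0, 1.0) for _ in range(k * n)]  # i.i.d. toy activations

local = local_bn(batch, n)
synced = global_bn(batch)
max_diff = max(abs(a - b) for a, b in zip(local, synced))
print(f"max |local - global| = {max_diff:.3f}")
```

Each worker's outputs have zero mean over its own 32 samples, but the individual normalized values depend on that worker's sampling noise. Because n stays at 32, this noise is the same at 256 GPUs as at 8, which is why n is treated as part of the objective.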

Communication is still substantial. ResNet-50 has about 25M parameters, or 100MB in FP32. The paper estimates single-P100 backprop at 120ms. Because an allreduce moves roughly 2x the parameter bytes, finishing it within one backprop window needs about 12.8Gbit/s of peak bandwidth, or about 15Gbit/s with overhead. The 50Gbit Ethernet fabric is therefore sufficient if allreduce is pipelined with backpropagation.
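
The bandwidth estimate is simple arithmetic. The sketch below reproduces it under the card's assumptions: 25M FP32 parameters, an allreduce that moves about 2x the buffer, and decimal units (1MB = 1e6 bytes, 1Gbit = 1e9 bits). The function name is illustrative. A 120ms window gives about 13.3Gbit/s, and the 12.8Gbit/s figure corresponds to dividing by 125ms. Both are rough numbers well below the 50Gbit link.

```python
def allreduce_bandwidth_gbit_s(num_params, bytes_per_param, backprop_s, traffic_factor=2.0):
    """Peak bandwidth needed to finish the allreduce within one backprop window.

    traffic_factor ~2 reflects that an allreduce moves roughly twice the buffer size.
    Uses decimal units: bytes * 8 bits, divided by 1e9 for Gbit.
    """
    bytes_moved = num_params * bytes_per_param * traffic_factor
    return bytes_moved * 8 / backprop_s / 1e9

# ResNet-50: ~25M FP32 parameters (~100 MB).
for window in (0.120, 0.125):
    gbit = allreduce_bandwidth_gbit_s(25e6, 4, window)
    print(f"{window * 1000:.0f} ms window: {gbit:.1f} Gbit/s")
```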

Method Adaptation

The main algorithmic adaptation is the linear scaling rule: when total minibatch grows by k while per-worker batch stays fixed, scale the learning rate by k. For 8192 images, the reference learning rate is 0.1 times 8192/256 = 3.2. The second adaptation is gradual warmup: start at the small-batch learning rate and increase linearly to the scaled learning rate over 5 epochs. This prevents early instability when weights are changing rapidly.
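
Here is a minimal sketch of the resulting schedule. It assumes a linear ramp applied per iteration and about 156 iterations per epoch (roughly 1.28M images / 8192, an approximation). The function and parameter names are illustrative. The step decay applied later in training is deliberately left out.

```python
def learning_rate(iteration, iters_per_epoch, total_batch,
                  base_lr=0.1, base_batch=256, warmup_epochs=5):
    """Linear scaling rule with gradual warmup; post-warmup decay is omitted."""
    # Linear scaling rule: k = total_batch / base_batch, target rate = k * base_lr.
    target_lr = base_lr * total_batch / base_batch
    warmup_iters = warmup_epochs * iters_per_epoch
    if iteration < warmup_iters:
        # Gradual warmup: grow by a constant amount per iteration from base_lr to target_lr.
        return base_lr + (iteration / warmup_iters) * (target_lr - base_lr)
    return target_lr

ITERS_PER_EPOCH = 1_280_000 // 8192  # ~156, approximate
for it in (0, 390, 780, 2000):
    print(it, round(learning_rate(it, ITERS_PER_EPOCH, 8192), 4))
```

With total_batch = 256 the ramp is flat, so the baseline recipe is unchanged. Constant warmup, which the card reports as worse, would instead hold a low rate for the first 5 epochs and then jump straight to 3.2.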

The implementation details are also compute-structure adaptations. Loss normalization must use the total minibatch size kn rather than the per-worker size n, because allreduce primitives sum rather than average by default. Momentum correction is needed when using an SGD variant that absorbs the learning rate into the momentum buffer: whenever the learning rate changes, the buffer must be rescaled by the ratio of the new rate to the old one. Data shuffling must use a single random order per epoch, partitioned among workers. Independent per-worker shuffling changes the sampling behavior.
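
The three adaptations can be checked on a toy problem. The sketch below is a serial, pure-Python illustration under stated assumptions: a scalar model with per-sample loss 0.5(w - x)^2, a simulated sum-allreduce, a hypothetical shared seed per epoch, and fixed toy gradients for the momentum comparison. All names are illustrative. None of them come from the paper or from a real distributed library.

```python
import random

def worker_grad(w, shard, total_batch):
    """Gradient of 0.5 * (w - x)^2 summed over one shard, divided by the TOTAL minibatch kn."""
    return sum(w - x for x in shard) / total_batch

def allreduce_sum(values):
    """Stand-in for a sum-allreduce: every worker receives the plain sum."""
    total = sum(values)
    return [total] * len(values)

def epoch_shards(data, k, epoch, seed=0):
    """One shared random permutation per epoch, split into k disjoint shards."""
    rng = random.Random(seed + epoch)  # hypothetical: every worker derives the same seed
    perm = list(data)
    rng.shuffle(perm)
    n = len(perm) // k
    return [perm[r * n:(r + 1) * n] for r in range(k)]

def sgd_standard(grads, lrs, m=0.9):
    """Momentum with the rate outside the buffer: v = m*v + g; w -= lr*v."""
    w, v = 0.0, 0.0
    for g, lr in zip(grads, lrs):
        v = m * v + g
        w -= lr * v
    return w

def sgd_lr_in_buffer(grads, lrs, m=0.9, correct=True):
    """Rate absorbed into the buffer: u = m*(lr/prev_lr)*u + lr*g; w -= u."""
    w, u, prev_lr = 0.0, 0.0, lrs[0]
    for g, lr in zip(grads, lrs):
        scale = lr / prev_lr if correct else 1.0  # momentum correction factor
        u = m * scale * u + lr * g
        w -= u
        prev_lr = lr
    return w

random.seed(0)
k, n = 4, 8  # toy sizes: 4 workers x 8 samples
data = [random.gauss(0.0, 1.0) for _ in range(k * n)]
w0 = 0.5
shards = epoch_shards(data, k, epoch=0)
reduced = allreduce_sum([worker_grad(w0, s, k * n) for s in shards])
full_batch = sum(w0 - x for x in data) / (k * n)
print("allreduced grad:", reduced[0], "full-batch mean grad:", full_batch)

grads = [1.0] * 10                        # fixed toy gradients
lrs = [0.1 + 0.1 * t for t in range(10)]  # warmup-like increasing rate
print(sgd_standard(grads, lrs),
      sgd_lr_in_buffer(grads, lrs),
      sgd_lr_in_buffer(grads, lrs, correct=False))
```

Dividing each worker's summed gradient by kn makes the summed allreduce equal the full-minibatch mean gradient. Dividing by n instead would inflate it by a factor of k. The corrected buffer variant tracks the standard momentum update exactly, because u equals lr times v. The uncorrected variant lags whenever the rate changes, as it does throughout warmup.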

The communication stack uses three phases: reduce gradients from the 8 local GPUs into one server buffer, allreduce across servers, then broadcast the result back to the local GPUs. NCCL handles intra-server collectives for buffers of 256KB or larger, and Gloo implements the inter-server allreduce. The paper compares halving/doubling and ring-style allreduce and finds halving/doubling 3x faster than ring on 32 servers for the relevant buffer sizes. Local reduction, cross-server allreduce, and broadcast are pipelined where possible.
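
The three-phase structure and the halving/doubling algorithm can be sketched as a serial simulation in plain Python. This illustrates the data movement only. It assumes a power-of-two server count, a buffer length divisible by that count, and toy integer-valued gradients, and the names are illustrative. It does not model NCCL, Gloo, or the pipelining of phases with backpropagation.

```python
def halving_doubling_allreduce(buffers):
    """Recursive halving (reduce-scatter), then recursive doubling (allgather).

    Serial simulation: buffers[r] is rank r's vector; returns every rank's result.
    Assumes a power-of-two rank count and a length divisible by it.
    """
    p, size = len(buffers), len(buffers[0])
    assert p & (p - 1) == 0 and size % p == 0
    bufs = [list(b) for b in buffers]
    lo, hi = [0] * p, [size] * p  # segment each rank is currently responsible for
    d = p // 2
    while d >= 1:  # reduce-scatter: each rank keeps half its segment, adds partner r XOR d
        snap = [list(b) for b in bufs]
        for r in range(p):
            mid = (lo[r] + hi[r]) // 2
            if r & d:
                lo[r] = mid
            else:
                hi[r] = mid
            for i in range(lo[r], hi[r]):
                bufs[r][i] += snap[r ^ d][i]
        d //= 2
    d = 1
    while d < p:  # allgather: swap fully reduced segments, doubling the owned range
        snap, old = [list(b) for b in bufs], list(zip(lo, hi))
        for r in range(p):
            plo, phi = old[r ^ d]
            bufs[r][plo:phi] = snap[r ^ d][plo:phi]
            lo[r], hi[r] = min(lo[r], plo), max(hi[r], phi)
        d *= 2
    return bufs

def hierarchical_allreduce(grads):
    """grads[s][g] is the gradient vector on GPU g of server s."""
    # Phase 1: reduce the local GPUs' gradients into one buffer per server.
    server_bufs = [[sum(vals) for vals in zip(*server)] for server in grads]
    # Phase 2: allreduce the per-server buffers across servers.
    reduced = halving_doubling_allreduce(server_bufs)
    # Phase 3: broadcast the result back to every local GPU.
    return [[list(reduced[s]) for _ in grads[s]] for s in range(len(grads))]

# Hypothetical toy sizes: 4 servers x 8 GPUs, 16-element integer-valued gradients.
servers, gpus, size = 4, 8, 16
grads = [[[float(s * 100 + g * 10 + i) for i in range(size)] for g in range(gpus)]
         for s in range(servers)]
out = hierarchical_allreduce(grads)
expected = [sum(grads[s][g][i] for s in range(servers) for g in range(gpus))
            for i in range(size)]
print(all(gpu == expected for server in out for gpu in server))
```

Both algorithms move about twice the buffer size per rank in total. Halving/doubling does so in 2 log2(p) steps, whereas a ring needs 2(p - 1) steps: 10 versus 62 at p = 32. Fewer latency-bound steps is a common explanation for its advantage on the buffer sizes involved here.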

Evidence

The paper's headline result is ResNet-50 trained with minibatch 8192 on 256 GPUs in one hour while matching small-batch accuracy. The baseline 256-minibatch run has 23.60% top-1 validation error with standard deviation 0.12 over 5 runs. The 8k run with no warmup is worse at 24.84 +/- 0.37; constant warmup is worse still at 25.88 +/- 0.56; gradual warmup reaches 23.74 +/- 0.09, within 0.14 points of the baseline.

The batch-size sweep shows the useful range. With the linear scaling rule and gradual warmup, ImageNet error remains stable from minibatch 64 through 8k and then rises; beyond 64k the run diverges under the same rule. ResNet-101 also scales: minibatch 256 gives 22.08 +/- 0.06 top-1 error, while minibatch 8k gives 22.36 +/- 0.09, a 0.28 point increase. The ResNet-101 8k run takes 92.5 minutes on 256 Tesla P100 GPUs.

Transfer evidence uses Mask R-CNN on COCO. ResNet-50 models pretrained on ImageNet with minibatches of 256, 2k, 4k, and 8k all transfer to nearly identical box and mask AP; the 16k model deteriorates, matching its ImageNet degradation. Mask R-CNN training itself also follows the linear scaling rule from 1 to 8 GPUs with stable AP.

Runtime evidence shows near-linear scaling. Time per iteration increases only 12% while the total minibatch scales 44x, from 256 to 11264. Time per ImageNet epoch falls from over 16 minutes to about 30 seconds at 352 GPUs. Throughput scaling is about 90% efficient on commodity Ethernet, with most allreduce communication hidden behind backpropagation by pipelining.
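
The 12% and roughly-90% figures are consistent with each other. With the per-worker batch fixed, images processed per second scale as the batch ratio divided by the per-iteration time ratio. Here is a quick check that takes the 12% figure at face value. The function name is illustrative, and this is arithmetic on the card's numbers, not a separate measurement.

```python
def weak_scaling(per_iter_time_ratio, batch_ratio):
    """Throughput speedup and efficiency when per-worker batch is fixed and workers grow."""
    speedup = batch_ratio / per_iter_time_ratio
    return speedup, speedup / batch_ratio

batch_ratio = 11264 / 256  # 44x more images per iteration
speedup, eff = weak_scaling(per_iter_time_ratio=1.12, batch_ratio=batch_ratio)
print(f"batch x{batch_ratio:.0f}, speedup x{speedup:.1f}, efficiency {eff:.0%}")
```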

Historical Effect

This card matters for compute history because it shows that large-batch training required hardware, optimization, and normalization details to align. It is not a generic "more GPUs" story: per-worker batch size, BN locality, warmup, allreduce scheduling, and loss normalization are all part of the result. For the SyncBatchNorm topic specifically, the paper is a counterexample: it keeps BN local rather than synchronizing it globally.

Limits

This is not an independent SyncBatchNorm source and should be consolidated with the ImageNet-in-one-hour card in higher-level summaries. The successful regime has a breaking point around 8k-16k for ResNet-50/ImageNet under this recipe. The communication analysis is specific to ResNet-50 size, P100 compute, and 50Gbit Ethernet; later models and optimizers can shift the bottleneck.
