Summary

A detailed technical walkthrough of data-parallel distributed training: replicating the model across N nodes, splitting each mini-batch across them, computing gradients independently, and synchronizing the gradients with MPI AllReduce before each weight update. The post covers the “pebble graph” mental model, a bandwidth and latency analysis of Ring-AllReduce, and the interleaving of communication with computation as implemented in PyTorch DDP.

Prerequisites

  • Backpropagation and gradient computation — data parallelism synchronizes gradients, so understanding the gradient computation graph is essential
  • MPI communication primitives: AllReduce is the core collective operation, and knowing that it decomposes into ReduceScatter followed by AllGather helps with the bandwidth analysis
  • Floating-point arithmetic: floating-point addition is not associative, which explains why distributed training results may differ slightly from sequential training
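The floating-point prerequisite can be illustrated in a few lines of plain Python. The three values below are hypothetical stand-ins for per-node partial sums of one gradient entry; an AllReduce that combines them in a different order than a single process would can produce a result that differs in the last bits.

```python
# Hypothetical partial sums of one gradient entry held by three nodes.
a, b, c = 0.1, 0.2, 0.3

order_1 = (a + b) + c  # e.g. nodes 0 and 1 are combined first
order_2 = a + (b + c)  # e.g. nodes 1 and 2 are combined first

print(order_1)             # 0.6000000000000001
print(order_2)             # 0.6
print(order_1 == order_2)  # False
```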

Core Idea

Data parallelism works because the gradient of a mean loss is the mean of the per-sample gradients. This separability means N nodes can each compute the gradient on their own equal-sized chunk of the batch and average the results (AllReduce with sum, then divide by N) before updating the shared weights. The key optimization beyond naïve “compute, then AllReduce” is interleaving: backpropagation runs from the last layer to the first, so each layer’s gradients are final as soon as its backward step completes, and a non-blocking AllReduce on them can overlap with the backward computation of the earlier layers. PyTorch DDP implements this by registering autograd hooks that fire as each parameter’s gradient becomes ready and by grouping parameters into buckets, launching one asynchronous AllReduce per bucket to amortize per-call communication overhead.
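A minimal sketch of the separability argument, in plain Python with a one-parameter linear model and mean squared error (both chosen for illustration, not taken from the post). The helper `allreduce_mean` is a hypothetical stand-in for AllReduce(SUM) followed by division by the number of nodes; the sketch also assumes equal-sized shards, since unequal shards would need a size-weighted average.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(local_values):
    """Hypothetical stand-in for AllReduce(SUM) followed by division by N."""
    return sum(local_values) / len(local_values)


def data_parallel_grad(w, xs, ys, world_size):
    """Split the batch into equal shards, one per simulated node."""
    assert len(xs) % world_size == 0, "sketch assumes equal-sized shards"
    shard = len(xs) // world_size
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return allreduce_mean(local_grads)


xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]  # synthetic targets
w = 0.5

full = grad_mse(w, xs, ys)                        # single-node, full-batch gradient
dp = data_parallel_grad(w, xs, ys, world_size=4)  # 4 simulated nodes
print(full, dp)  # -94.5 -94.5
```

The values here are small integers and halves, so both paths are exact; with arbitrary data the two results agree only up to floating-point rounding.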

Results

Implementation           | ResNet training time | Notes
Non-interleaved DDP      | Baseline             | Compute, then AllReduce sequentially
Interleaved DDP (NCCL)   | Significantly faster | Compute/communication overlap
Interleaved DDP (Gloo)   | Faster               | CPU-based collectives
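The gap between the first two rows comes from overlap, which a toy timeline makes concrete. The sketch below is a simplified model under stated assumptions, not a measurement and not DDP’s actual scheduler: per-layer backward and AllReduce times are hypothetical, gradients become ready in reverse layer order, buckets group consecutive layers, and communication runs on a single channel one bucket at a time.

```python
# Toy timeline model with hypothetical numbers (not measurements):
# "all backward, then all AllReduce" versus bucketed overlap.

def non_interleaved(backward_times, comm_times):
    """Run the whole backward pass, then every AllReduce back to back."""
    return sum(backward_times) + sum(comm_times)


def interleaved(backward_times, comm_times, bucket_size):
    """Launch one AllReduce per bucket as soon as all of its gradients are ready.

    Both lists are in backward order (last layer first). Communication is
    modeled as a single channel, so buckets are reduced one at a time.
    """
    t_compute = 0.0     # time at which the current layer's gradient is ready
    comm_free_at = 0.0  # time at which the communication channel is next idle
    bucket_comm = 0.0   # accumulated AllReduce time of the open bucket
    n_layers = len(backward_times)
    for i, (bwd, comm) in enumerate(zip(backward_times, comm_times)):
        t_compute += bwd
        bucket_comm += comm
        if (i + 1) % bucket_size == 0 or i == n_layers - 1:
            start = max(t_compute, comm_free_at)  # wait for data and channel
            comm_free_at = start + bucket_comm
            bucket_comm = 0.0
    # The step ends when both the backward pass and the last AllReduce are done.
    return max(t_compute, comm_free_at)


bwd = [2.0, 2.0, 2.0, 2.0]   # hypothetical per-layer backward times (ms)
comm = [1.0, 1.0, 1.0, 1.0]  # hypothetical per-layer AllReduce times (ms)
print(non_interleaved(bwd, comm))             # 12.0
print(interleaved(bwd, comm, bucket_size=2))  # 10.0
print(interleaved(bwd, comm, bucket_size=1))  # 9.0
```

In this model smaller buckets always overlap better, because each AllReduce is charged only for its data. The per-call overhead that motivates bucketing is deliberately left out; adding a fixed cost per bucket is what makes larger buckets pay off.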

Limitations

  • Author-stated: Data parallelism requires large batch sizes to be effective; small batches limit the degree of parallelism
  • Author-stated: Works best for models with a high ratio of FLOPs per parameter, such as CNNs, since compute grows with FLOPs while gradient communication grows with parameter count; less suited to large language models
  • Unstated: The blog post covers only data parallelism; pipeline and tensor parallelism (needed for models that don’t fit in one GPU’s memory) are covered separately

Reproducibility

  • Code: ShallowSpeed library by the author (GitHub: siboehm/shallowspeed)
  • Datasets: Standard examples (MNIST referenced)
  • Compute: Multi-node setup required; examples scale from 2 nodes upward

Insights

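The “pebble graph” mental model for visualizing which activations are cached during the forward and backward passes is unusually clean and carries over directly to understanding gradient checkpointing. The bandwidth-optimality argument for Ring-AllReduce is arguably the key insight behind scalable data-parallel training: for a K-element gradient, each node sends 2(N-1)/N × K elements, which approaches 2K and is nearly independent of the node count N, so adding nodes does not linearly increase per-node communication volume. The latency term still grows with N, since the ring needs 2(N-1) sequential steps.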
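The bandwidth claim can be checked with a small single-process simulation of Ring-AllReduce as ReduceScatter followed by AllGather. This is a plain-Python sketch, not MPI: the nodes are list entries, the chunk-indexing convention is one common choice, the vector length is assumed divisible by N, and only elements sent are counted (latency is not modeled).

```python
def ring_allreduce(node_vectors):
    """Sum equal-length vectors across simulated nodes on a ring.

    Returns (results, elements_sent_per_node). Assumes len(vector) % N == 0.
    """
    n = len(node_vectors)
    k = len(node_vectors[0])
    assert k % n == 0, "sketch assumes the vector splits into N equal chunks"
    c = k // n
    # chunks[r][i] is node r's copy of chunk i.
    chunks = [[list(v[i * c:(i + 1) * c]) for i in range(n)] for v in node_vectors]
    sent = [0] * n

    # ReduceScatter: after N-1 steps node r holds the full sum of chunk (r+1) % n.
    for s in range(n - 1):
        # Snapshot every outgoing message first: all sends in a step are simultaneous.
        msgs = [(r, (r - s) % n, list(chunks[r][(r - s) % n])) for r in range(n)]
        for r, idx, payload in msgs:
            dst = (r + 1) % n
            chunks[dst][idx] = [a + b for a, b in zip(chunks[dst][idx], payload)]
            sent[r] += len(payload)

    # AllGather: pass each fully reduced chunk around the ring, overwriting copies.
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n, list(chunks[r][(r + 1 - s) % n])) for r in range(n)]
        for r, idx, payload in msgs:
            chunks[(r + 1) % n][idx] = payload
            sent[r] += len(payload)

    results = [[x for chunk in node for x in chunk] for node in chunks]
    return results, sent


# Four simulated nodes, each holding a constant 8-element "gradient".
vectors = [[float(r + 1)] * 8 for r in range(4)]
results, sent = ring_allreduce(vectors)
print(results[0])  # [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
print(sent)        # [12, 12, 12, 12]

# Per-node traffic for a fixed K = 8 as the ring grows.
for n in (2, 4, 8):
    _, sent_n = ring_allreduce([[1.0] * 8 for _ in range(n)])
    print(n, sent_n[0])
```

For K = 8 the loop prints 8, 12 and 14 elements per node for N = 2, 4 and 8, matching 2(N-1)K/N and approaching 2K = 16 rather than growing linearly with N.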

Connections

Raw Excerpt

The most commonly used loss functions are means over the loss of individual samples… Conveniently, the gradient of a sum is the sum of the gradients of each term. Hence, we can calculate the gradients of the samples independently on each machine and sum them up before performing the weight update.