A 7.5 billion parameter model, trained with Adam in mixed precision, needs about 120 GB just to hold its own state. An H100 has 80 GB. An A100 has 40 or 80. The arithmetic does not work, and no amount of clever batching fixes it, because the problem is not the batch. The problem is the model.
This post is about the technique that broke that wall. It is called sharding, and the two names you will see most are ZeRO, from Microsoft's DeepSpeed team, and FSDP, the PyTorch-native version that ships in torch.distributed. Both rest on the same idea: when one GPU cannot hold the model, stop pretending it can.
Origin: why DDP runs out of room
Distributed Data Parallel, covered in an earlier post in this series on data parallel training, is the workhorse of multi-GPU training. Every GPU keeps a full copy of the model. Each worker processes a different slice of the batch, runs forward and backward, then all the GPUs sync their gradients with an all-reduce. Simple, fast, and built on the assumption that the model fits.
The model state has three parts. Parameters are the weights themselves. Gradients are the slope you compute during backward. Optimizer states are the bookkeeping your optimizer keeps between steps. With Adam, the most common choice, that bookkeeping is heavier than the model. The canonical breakdown from the ZeRO paper (Rajbhandari et al., 2019) is 16 bytes per parameter for mixed precision training: 2 bytes for the FP16 parameter copy, 2 bytes for the FP16 gradient, 4 bytes for an FP32 master parameter, 4 bytes for Adam's first moment, and 4 bytes for its second moment. Three quarters of that total, 12 of the 16 bytes, is the optimizer: the FP32 master copy plus both Adam moments. The FP16 weights and gradients together are the smaller chunk.
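To make that accounting concrete, here is a small Python sketch that tallies the per-parameter bytes from the ZeRO paper's breakdown and works out the optimizer's share. The dictionary keys and function names are illustrative only, not taken from DeepSpeed or PyTorch, and 1 GB is taken as 10^9 bytes.

```python
# Per-parameter byte accounting for mixed-precision Adam,
# following the 2 + 2 + 4 + 4 + 4 breakdown in the ZeRO paper.
BYTES_PER_PARAM = {
    "fp16_params": 2,
    "fp16_grads": 2,
    "fp32_master_params": 4,
    "adam_first_moment": 4,
    "adam_second_moment": 4,
}

# Everything the optimizer owns: the FP32 master copy and both Adam moments.
OPTIMIZER_KEYS = ("fp32_master_params", "adam_first_moment", "adam_second_moment")


def model_state_gb(num_params: float) -> float:
    """Total model state in GB (1 GB = 1e9 bytes) for a given parameter count."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


def optimizer_fraction() -> float:
    """Share of the per-parameter bytes that belongs to the optimizer."""
    total = sum(BYTES_PER_PARAM.values())
    return sum(BYTES_PER_PARAM[k] for k in OPTIMIZER_KEYS) / total


if __name__ == "__main__":
    print(f"{sum(BYTES_PER_PARAM.values())} bytes per parameter")
    print(f"optimizer share: {optimizer_fraction():.0%}")
    print(f"7.5B model: {model_state_gb(7.5e9):.0f} GB")
```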
DDP replicates all 16 bytes per parameter on every GPU. For a 7.5B model that is 120 GB on each card, plus activations, plus the workspace your kernels need. Past a point, no card holds it. Worse, every replica is duplicating data that nobody needed two copies of in the first place. The same optimizer states, the same gradients, the same parameters, copied N times across N GPUs.
In 2019 Samyam Rajbhandari and colleagues at Microsoft asked the obvious question. What if you split that state across the GPUs instead of replicating it? They called the result the Zero Redundancy Optimizer, or ZeRO, and shipped it inside DeepSpeed.
The three ZeRO stages, in order of how much they hurt
ZeRO comes in three flavors, layered. Each one shards a different part of the model state across the data-parallel workers. The ZeRO paper labels them P_os, P_os+g, and P_os+g+p, for partitioned optimizer states, gradients, and parameters.
ZeRO-1, optimizer states only. Each GPU keeps the full parameters and full gradients but only a 1/N slice of the optimizer states. Since the optimizer states are three quarters of the total, this alone gives a roughly 4x memory cut. For a 7.5B model on 64 GPUs, that is 120 GB per card down to about 31 GB. The communication pattern barely changes from DDP. You still all-reduce gradients once per step.
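ZeRO-2, optimizer states plus gradients. Now the gradients are also sharded, and each GPU keeps only the reduced slice it owns. The all-reduce is replaced by a reduce-scatter of the gradients, followed after the optimizer step by an all-gather of the updated parameters. Together those two collectives move the same volume as the all-reduce they replace, but the reduced gradients land already partitioned across the workers. The memory cut approaches 8x as the GPU count grows. For the same 7.5B model on 64 GPUs, you are now at about 16.6 GB per card.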
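The idea that a reduce-scatter plus an all-gather does the same job as an all-reduce is easy to check on a toy. The sketch below simulates the three collectives in a single process over plain Python lists. The function names are hypothetical stand-ins, not NCCL or torch.distributed calls, and the toy assumes the buffer length divides evenly by the number of ranks.

```python
from typing import List

Vector = List[float]


def reduce_scatter(per_rank: List[Vector]) -> List[Vector]:
    """Sum buffers across ranks; rank r keeps only chunk r of the sum."""
    n = len(per_rank)
    size = len(per_rank[0])
    assert size % n == 0, "toy version: length must divide evenly by rank count"
    chunk = size // n
    out = []
    for r in range(n):
        lo, hi = r * chunk, (r + 1) * chunk
        out.append([sum(buf[i] for buf in per_rank) for i in range(lo, hi)])
    return out


def all_gather(shards: List[Vector]) -> List[Vector]:
    """Every rank receives the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_reduce(per_rank: List[Vector]) -> List[Vector]:
    """Every rank receives the full elementwise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two ranks, four elements each
print(reduce_scatter(grads))              # each rank holds half of the sum
print(all_gather(reduce_scatter(grads)) == all_reduce(grads))
```

In ZeRO-2 the all-gather runs on the updated parameters rather than on the gradients themselves; the simulation only shows that the two halves compose into an all-reduce, which is why the total volume does not change.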
ZeRO-3, everything sharded. Parameters themselves are split. No GPU holds the full model, ever. Memory scales as 1/N up to a 64x cut on 64 GPUs, give or take. The 7.5B model that needed 120 GB per card now fits in about 2 GB. The cost is paid in communication. You have to gather the missing weights every time you need to compute with them, which means twice per layer, once on the forward pass and once on the backward. We will come back to what that costs.
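The per-GPU figures for each stage follow from the ZeRO paper's memory formulas, with K = 12 bytes of optimizer state per parameter. Here is a sketch that counts model state only; activations, kernel workspace, and allocator fragmentation are deliberately left out, so real usage will be higher.

```python
def zero_memory_gb(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory in GB (1e9 bytes) for ZeRO stages 0-3.

    Stage 0 is plain DDP. k is bytes of optimizer state per parameter
    (12 for mixed-precision Adam). Activations and workspace are not counted.
    """
    psi, n = num_params, num_gpus
    if stage == 0:    # everything replicated
        total = (2 + 2 + k) * psi
    elif stage == 1:  # optimizer states sharded
        total = (2 + 2) * psi + k * psi / n
    elif stage == 2:  # optimizer states and gradients sharded
        total = 2 * psi + (2 + k) * psi / n
    elif stage == 3:  # parameters sharded too
        total = (2 + 2 + k) * psi / n
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return total / 1e9


for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7.5e9, 64, stage):.1f} GB per GPU")
```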
What makes ZeRO-3 workable is that although no GPU holds the full model at any one time, every GPU can reconstruct any piece of it on demand by talking to its peers. The model is never lost; it is just spread across 64 places, each holding its own slice.
FSDP, or how PyTorch reimplemented ZeRO-3 natively
DeepSpeed sat outside PyTorch as a separate library. Useful, but a heavy dependency, and the team at Meta wanted the same capabilities first-class in the framework. In July 2021 they shipped Fully Sharded Data Parallel, pitching it as the PyTorch-native take on full parameter sharding, the idea that ZeRO-3 had popularised. The full paper landed in April 2023 (Zhao et al., 2023). The selling point was tighter integration with PyTorch's autograd, dispatcher, and memory allocator, plus a configuration surface that fit existing nn.Module code.
FSDP shards model state across the data-parallel workers in groups called sharding units. In the original FSDP, those units were called FlatParameters, big concatenated tensors per layer or per transformer block. The newer rewrite, FSDP2, switched to per-parameter sharding using DTensor, which is cleaner for mixed parallelism and for working with frozen weights (FSDP2 design). The user-visible idea is the same. You wrap your model, pick a unit boundary (usually each transformer decoder layer), and FSDP handles the rest.
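The difference between the two layouts is easiest to see on a tiny unit. The sketch below is a pure-Python toy, not either FSDP implementation: the function names are hypothetical, parameters are flattened 1-D lists, and real FSDP also tracks shapes, dtypes, and padding metadata. The flat layout concatenates everything and pads to a multiple of the world size; the per-parameter layout splits each parameter on its own first dimension.

```python
import math
from typing import Dict, List


def shard_flat(params: Dict[str, List[float]], world_size: int) -> List[List[float]]:
    """FlatParameter-style: concatenate the unit, pad, give each rank one chunk."""
    flat = [x for values in params.values() for x in values]
    chunk = math.ceil(len(flat) / world_size)
    flat += [0.0] * (chunk * world_size - len(flat))  # padding, dropped on unshard
    return [flat[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def shard_per_param(params: Dict[str, List[float]], world_size: int) -> List[Dict[str, List[float]]]:
    """Per-parameter style: split every parameter separately along dim 0."""
    shards: List[Dict[str, List[float]]] = [{} for _ in range(world_size)]
    for name, values in params.items():
        chunk = math.ceil(len(values) / world_size)
        for r in range(world_size):
            shards[r][name] = values[r * chunk:(r + 1) * chunk]
    return shards


unit = {"weight": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], "bias": [7.0, 8.0, 9.0]}
print(shard_flat(unit, 2))       # rank 1's chunk mixes weight, bias, and padding
print(shard_per_param(unit, 2))  # every rank holds a named piece of each parameter
```

The flat layout gives one large buffer per unit, which suits a single big collective, but shard boundaries ignore parameter boundaries. Keeping each parameter's identity, as the per-parameter layout does, is what makes it simpler to treat some weights differently, for example frozen ones.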
The execution pattern is the part worth understanding, because it is what costs you the wall-clock time.
Gather, compute, discard
Picture a transformer with 32 decoder layers, sharded across 8 GPUs. Each GPU permanently owns 1/8 of every layer's weights. When forward begins:
- Layer 0 needs its full weights. All 8 GPUs run an all-gather, each sending their 1/8 slice and receiving the other 7 slices in return. Now every GPU has the full Layer 0.
- Every GPU computes forward through Layer 0 on its own slice of the batch.
- Layer 0 weights are immediately discarded. Each GPU drops the slices it does not own. Memory is returned to the pool.
- Repeat for Layer 1, Layer 2, all the way through the stack.
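The loop above can be simulated in a few lines to show why peak memory stays low. In this hypothetical toy, each "layer" just multiplies its input by the sum of its weights, and the function tracks how many weight elements one rank holds at once: its own permanent shards plus the borrowed slices of the single layer currently running.

```python
from typing import List, Tuple

# layers[l][r] is rank r's permanent slice of layer l's weights.
Layers = List[List[List[float]]]


def fsdp_forward(layers: Layers, rank: int, x: float) -> Tuple[float, int]:
    """Toy gather-compute-discard forward pass as seen by one rank.

    Returns the output and the peak number of weight elements resident
    on this rank at any point during the pass.
    """
    owned = sum(len(layer[rank]) for layer in layers)  # shards kept for the whole run
    peak = owned
    for layer in layers:
        # all-gather: receive every other rank's slice of this layer
        full = [w for shard in layer for w in shard]
        # own slice is already counted in `owned`; add only the borrowed slices
        resident = owned + len(full) - len(layer[rank])
        peak = max(peak, resident)
        x = x * sum(full)  # compute with the full layer on this rank's local batch
        del full           # discard: borrowed slices go back to the pool
    return x, peak


# 4 layers, 2 ranks, 4 weights per layer (16 in total), 2 owned per rank per layer.
toy_layers = [[[1.0, 1.0], [1.0, 1.0]] for _ in range(4)]
out, peak = fsdp_forward(toy_layers, rank=0, x=1.0)
print(out, peak)  # peak is well below the 16 weights of the full model
```

Peak residency works out to the rank's 1/N share of the model plus one layer's borrowed slices, which is why the size of each shard unit matters so much.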
Backward runs the same dance in reverse. Each layer's weights are all-gathered again because they are not in memory anymore. Gradients are computed, then reduce-scattered, so each GPU comes out holding only the 1/N slice of gradients that matches the parameter slice it owns. The optimizer step then runs locally on that slice, against the matching slice of optimizer state. Nothing is ever fully assembled in one place except the layer currently doing work.
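The local optimizer step works because Adam and SGD update each element independently, so a step on the owned slice gives exactly what a step on the whole tensor would. The toy check below uses plain SGD with averaged gradients as a stand-in for Adam (Adam's moment buffers would simply be sliced the same way). Both function names are hypothetical, and an even split across ranks is assumed.

```python
from typing import List

Vector = List[float]


def unsharded_step(params: Vector, grads: List[Vector], lr: float) -> Vector:
    """DDP-style reference: all-reduce (average) gradients, update everything."""
    avg = [sum(vals) / len(grads) for vals in zip(*grads)]
    return [p - lr * g for p, g in zip(params, avg)]


def sharded_step(params: Vector, grads: List[Vector], lr: float) -> Vector:
    """Sharded style: reduce-scatter gradients, update only the owned slice."""
    n = len(grads)
    chunk = len(params) // n  # toy: assume an even split across ranks
    new_shards = []
    for r in range(n):
        lo, hi = r * chunk, (r + 1) * chunk
        # reduce-scatter: rank r receives the averaged gradient for its slice only
        g_slice = [sum(g[i] for g in grads) / n for i in range(lo, hi)]
        # local optimizer step on params[lo:hi], the slice rank r owns
        new_shards.append([p - lr * g for p, g in zip(params[lo:hi], g_slice)])
    # stand-in for the all-gather the next forward pass performs
    return [p for shard in new_shards for p in shard]


params = [1.0, 2.0, 3.0, 4.0]
grads = [[0.5] * 4, [1.5] * 4]  # two ranks saw different micro-batches
print(sharded_step(params, grads, lr=0.5) == unsharded_step(params, grads, lr=0.5))
```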
The thing that makes this tolerable in practice is overlap. While Layer 0 is computing, FSDP is already prefetching Layer 1's all-gather. The communication and the compute run on different streams, and a well-tuned setup hides most of the gather latency behind the matmul. The PyTorch FSDP tutorial lists backward_prefetch=BACKWARD_PRE as the setting that pulls this off on the backward pass. Get the auto-wrap policy wrong (the rule that says where one shard unit ends and the next begins) and you can lose the overlap entirely. The whole model collapses into one giant unit, the gather happens once and is huge, and your throughput tanks.
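A back-of-the-envelope timeline model shows both effects: what prefetching buys, and why one giant unit throws it away. This is a toy under strong assumptions (a fixed gather and compute time per unit, perfect overlap on separate streams, one unit of lookahead), and the millisecond figures are invented purely for illustration.

```python
def forward_time_ms(num_units: int, gather_ms: float, compute_ms: float,
                    prefetch: bool) -> float:
    """Toy forward-pass timeline: each unit needs an all-gather, then compute.

    Without prefetch the two serialize. With prefetch, unit i+1's gather runs
    while unit i computes, so after the first gather each unit costs whichever
    of the two is slower, and the last compute finishes the pass.
    """
    if not prefetch:
        return num_units * (gather_ms + compute_ms)
    return gather_ms + (num_units - 1) * max(gather_ms, compute_ms) + compute_ms


# Hypothetical per-layer costs for a 32-layer model.
print(forward_time_ms(32, 2.0, 3.0, prefetch=False))  # serialized
print(forward_time_ms(32, 2.0, 3.0, prefetch=True))   # gathers hidden behind compute
print(forward_time_ms(1, 64.0, 96.0, prefetch=True))  # one giant unit: nothing to overlap
```

With a single unit there is no next gather to overlap, so the time falls back to the serialized case, and the whole model has to be resident at once.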
The communication math, plainly
In a ring-pipelined implementation, both reduce-scatter and all-gather move roughly Ψ elements per GPU for a buffer of Ψ elements (precisely (N-1)/N × Ψ across N workers), and an all-reduce is a reduce-scatter followed by an all-gather, so it moves 2Ψ. The DDP baseline, which is one all-reduce per step, therefore costs 2Ψ. ZeRO-1 and ZeRO-2 are also 2Ψ. ZeRO-3 and FSDP are 3Ψ, an extra Ψ because you have to all-gather the parameters twice, once for forward and once for backward, instead of once (DeepSpeed ZeRO++ blog).
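The ledger can be written down as a table of collectives per step. This sketch counts volume in units of Ψ and ignores the (N-1)/N ring factor. The strategy keys are hypothetical labels; ZeRO-1 is omitted because the ZeRO paper's analysis puts it at the same 2Ψ as ZeRO-2.

```python
from typing import Dict, List

# Collectives issued per training step, each over a buffer of Ψ elements.
COLLECTIVES: Dict[str, List[str]] = {
    "ddp": ["all_reduce"],                                    # gradients
    "zero2": ["reduce_scatter", "all_gather"],                # grads, then updated params
    "zero3": ["all_gather", "all_gather", "reduce_scatter"],  # fwd params, bwd params, grads
}

# A ring all-reduce is a reduce-scatter followed by an all-gather,
# so it moves twice what either half does.
PSI_PER_COLLECTIVE = {"all_reduce": 2, "reduce_scatter": 1, "all_gather": 1}


def volume_in_psi(strategy: str) -> int:
    """Per-step communication volume, in multiples of Ψ."""
    return sum(PSI_PER_COLLECTIVE[c] for c in COLLECTIVES[strategy])


for name in COLLECTIVES:
    print(f"{name}: {volume_in_psi(name)}Ψ")
```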
That is the deal. A 50 percent jump in communication volume buys you roughly 1/N memory per GPU. Whether you take that deal depends entirely on whether you can afford the extra traffic. Inside a node with NVLink, where bandwidth is hundreds of gigabytes per second, the answer is usually yes. Between nodes on a slow interconnect, the answer is often no, and FSDP can run several times slower than DDP on the same hardware for a model small enough to fit.
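To see why the same 50 percent costs so differently on different links, convert volume into seconds. The sketch assumes FP16 traffic, a ring in which each GPU sends (N-1)/N of every buffer, no latency, and no overlap with compute. The two bandwidth figures are hypothetical round numbers picked only to contrast a fast link with a slow one; they are not measurements of any particular hardware.

```python
def comm_seconds(num_params: float, psi_multiple: int, world_size: int,
                 bytes_per_elem: int, link_gb_per_s: float) -> float:
    """Time for one GPU to send psi_multiple * Ψ elements around a ring.

    Assumes the link sustains link_gb_per_s (GB/s, 1e9 bytes) with no latency
    and no overlap with compute.
    """
    per_gpu_bytes = (psi_multiple * num_params * bytes_per_elem
                     * (world_size - 1) / world_size)
    return per_gpu_bytes / (link_gb_per_s * 1e9)


# 7.5B parameters, 8 GPUs, FP16 (2 bytes); bandwidths are hypothetical.
for label, bw in [("fast link", 200.0), ("slow link", 25.0)]:
    ddp = comm_seconds(7.5e9, 2, 8, 2, bw)
    fsdp = comm_seconds(7.5e9, 3, 8, 2, bw)
    print(f"{label}: DDP {ddp:.2f}s, FSDP {fsdp:.2f}s, extra {fsdp - ddp:.2f}s")
```

The ratio is always 1.5x; what changes is whether the extra Ψ is a rounding error next to compute or the dominant cost of the step.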
When FSDP beats DDP, and when it does not
The honest version of the choice looks like this.
Use DDP when the model fits. A 1B parameter model in mixed precision Adam needs 16 GB of state. An A100 holds it with room for a healthy batch. DDP gives you a single all-reduce per step and gets out of the way. Real benchmarks back this up: for sub-3B-parameter models, DDP and FSDP land in the same throughput range, and DDP usually wins on the simplest setups. FSDP's extra all-gathers buy you nothing if there was nothing to spill.
Use FSDP when the model does not fit. Past 3B or 7B parameters depending on the GPU, you are forced into sharding to fit at all. FSDP becomes the floor of what is achievable, not the ceiling. A 70B model in mixed precision Adam needs over a terabyte of state. There is no DDP route to that. The communication cost is real, but the alternative is not running.
Use HSDP (hybrid sharding) when you have a fast inner network and a slower outer one. HSDP shards inside a node, where NVLink is fast, and replicates across nodes the way DDP does, where the network is slow. The ZeRO papers call this two-dimensional partitioning. The PyTorch API exposes it as HYBRID_SHARD. It is the de facto default for multi-node training on commodity InfiniBand or Ethernet (FSDP sharding strategies).
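The two dimensions of hybrid sharding come down to how ranks are grouped. The sketch below works out, for a hypothetical cluster, which ranks shard together over the fast intra-node link and which replicate over the slower inter-node link. It assumes ranks are numbered node by node, and it reproduces no PyTorch API; it only shows the grouping idea.

```python
from typing import List, Tuple


def hsdp_groups(num_nodes: int, gpus_per_node: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Split global ranks into shard groups and replicate groups.

    Assumes rank = node * gpus_per_node + local_index.
    Shard groups: all GPUs in one node; parameters are sharded within each.
    Replicate groups: the same local index on every node; these ranks hold
    identical shards and reduce gradients with each other, DDP-style.
    """
    shard_groups = [[node * gpus_per_node + local for local in range(gpus_per_node)]
                    for node in range(num_nodes)]
    replicate_groups = [[node * gpus_per_node + local for node in range(num_nodes)]
                        for local in range(gpus_per_node)]
    return shard_groups, replicate_groups


shard, replicate = hsdp_groups(num_nodes=2, gpus_per_node=4)
print("shard groups (intra-node):", shard)
print("replicate groups (inter-node):", replicate)
```

The all-gathers that make full sharding expensive stay inside the shard groups, so only gradient reduction ever crosses the slow link.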
Use tensor parallel or pipeline parallel when even FSDP is not enough. When the activations themselves are the memory hog, or when the model is so big that the gather of a single layer is too large to handle, you start splitting the layer itself across GPUs. Those techniques get covered in other posts in this distributed-training series. They are not free either. They demand more bandwidth, just of a different kind.
The general rule, useful even when it is wrong: sharding trades memory for communication, and you choose the strategy by deciding which one is your bottleneck.
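Here is that rule of thumb as a toy decision helper. The 16 bytes per parameter comes from the mixed-precision Adam accounting earlier in the post. The state_fraction default of 0.5, which leaves half of each GPU for activations and workspace, is purely an illustrative assumption, as are the thresholds, so treat this as a sketch of the reasoning rather than a sizing tool.

```python
def pick_strategy(num_params: float, gpu_mem_gb: float, gpus_per_node: int,
                  num_nodes: int, state_fraction: float = 0.5) -> str:
    """Toy decision rule: pick the least-sharded strategy whose model state fits.

    state_fraction is an assumed share of GPU memory available for model
    state; the rest is left for activations and workspace.
    """
    state_gb = 16 * num_params / 1e9  # mixed-precision Adam, bytes per parameter
    budget = gpu_mem_gb * state_fraction
    if state_gb <= budget:
        return "DDP"  # model fits; one all-reduce per step
    if num_nodes > 1 and state_gb / gpus_per_node <= budget:
        return "HSDP"  # shard inside a node, replicate across nodes
    if state_gb / (gpus_per_node * num_nodes) <= budget:
        return "FSDP"  # shard across every GPU
    return "FSDP + tensor/pipeline parallel"  # even full sharding is not enough


print(pick_strategy(1e9, 40, 8, 1))     # small model on 40 GB cards
print(pick_strategy(7.5e9, 80, 8, 4))   # fits once sharded inside a node
print(pick_strategy(70e9, 80, 8, 4))    # needs every GPU in the job
print(pick_strategy(70e9, 80, 8, 1))    # one node is not enough
```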
Future and impact: the wall is moving, not gone
The original ZeRO paper aimed at a trillion parameter model. Two years later ZeRO-Infinity (Rajbhandari et al., 2021) extended the same trick across CPU memory and NVMe SSDs, letting a single DGX-2 node fine-tune trillion-parameter models by treating the SSD as the cold tier. The cost is throughput; the win is fitting models on hardware that has no right to hold them.
The PyTorch team rewrote FSDP into FSDP2 in 2024, using DTensor for cleaner per-parameter sharding, fewer constraints around frozen weights and LoRA adapters, and tighter composition with tensor parallel. SimpleFSDP and TorchTitan extend the line further, with torch.compile integration that hides communication better and a stack designed for production pretraining.
The bigger trend underneath: training stacks have stopped picking one parallelism. Modern pretraining runs combine FSDP across nodes, tensor parallel within a node, sequence parallel for activations, and pipeline parallel across racks, picking each one for the dimension it shards best. ZeRO-3 stopped being the headline and became one layer in a multi-layer cake. The wall is still there. We just have more ways to climb it.
The trick worth remembering, the thing that made all of this possible, is the one in ZeRO's name. The redundancy was the bug. Every GPU holding a full copy of state nobody needed twice was costing the field nearly an order of magnitude of capability. The fix was to stop. Everything else is bookkeeping.
For the upstream open-source writing this post drew from, see the Conscious Engines deep dive on FSDP, which is more hands-on than the official docs about what fails in practice.
Council summary
The argument is built on one number: 16 bytes per parameter for mixed-precision Adam, three quarters of it optimizer bookkeeping. DDP replicates all 16 across every GPU and runs out of room; ZeRO and FSDP partition those bytes across N workers and reach 1/N memory at the cost of one extra all-gather of each layer's parameters during the backward pass. The three stages, the gather-compute-discard cycle, and the 2Ψ versus 3Ψ communication ledger all trace back to that one accounting decision in the 2019 ZeRO paper. The practical takeaway is that sharding is a deliberate trade: memory for bandwidth, and the right answer depends on which one you are actually short of.