GPT Adventures — Part 4: Distribute It

In this part, I will explore the various methods of, and factors influencing, distributed training. Things I would like to understand more include exploring how training speed depends on:

  • the parallelism scheme (DDP vs FSDP);
  • the interconnect over which gradients are synchronised, and its topology;
  • the number of devices;
  • communication settings such as the gradient bucket size.

And, of course, the interdependencies between those factors.

Training Loop Refactor

We have to refactor the training loop in order to make it compatible with distributed training. Rather than a single for loop inside the script’s main method, main is now only used to configure and spawn the distributed process group. Each rank executes a training loop analogous to the original, single-device loop, with each rank allocated an equal, disjoint shard of the training set data. In our case, we shard the data using PyTorch’s DistributedSampler.

Per-device batch size is also adjusted to keep the same overall batch size; this means that the amount of per-forward/backward-pass compute is reduced, which reduces the amount of compute that can be overlapped with comms as the number of devices increases. Depending on the distributed training scheme and the topology of our compute resource, the overall comms latency can increase, creating a double-whammy of sorts; we’ll need to keep a close eye on this to ensure we are not limited by our comms bandwidth.
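
To make this concrete, here is a minimal sketch of the refactored structure. The dataset, epoch count and inner loop body are stand-ins, not the actual training code; the point is the spawn/init pattern and the per-rank DistributedSampler.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

GLOBAL_BATCH = 512  # kept constant regardless of world size

def train_one_rank(rank, world_size):
    # each spawned process joins the process group under its own rank
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # stand-in dataset; the real loop streams tokenised training text
    dataset = TensorDataset(torch.randint(0, 50_000, (4096, 1024)))

    # DistributedSampler gives each rank a disjoint shard of the data;
    # per-device batch size is scaled so the global batch stays the same
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=GLOBAL_BATCH // world_size, sampler=sampler)

    for epoch in range(1):
        sampler.set_epoch(epoch)  # reshuffle the sharding each epoch
        for (tokens,) in loader:
            tokens = tokens.cuda(non_blocking=True)
            ...  # forward / backward / optimiser step, as in the single-device loop

    dist.destroy_process_group()

def main():
    world_size = torch.cuda.device_count()
    mp.spawn(train_one_rank, args=(world_size,), nprocs=world_size)

if __name__ == "__main__":
    main()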

Distributed Data Parallel

Let us start with the conceptually simplest of frameworks. In DDP, we simply fully replicate the model on each of the N accelerators, and have every accelerator process a different part of the overall batch. Once all machines have finished processing and have produced gradients for their shard of the batch, we synchronise the gradients by performing an AllReduce.
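
Conceptually, the synchronisation boils down to an averaging AllReduce over every gradient tensor. Below is a hand-rolled sketch of what DDP automates for us; in practice we simply wrap the model in DistributedDataParallel and never write this ourselves.

import torch.distributed as dist

def sync_gradients(model, world_size):
    # after loss.backward(): sum each gradient across all ranks, then average
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size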

Note that there are actually many smaller AllReduces. When running experiments, it struck me that in the initial implementation I never had to specify how to partition the gradient communications; clearly, PyTorch was doing it for me, but it was not entirely clear how I ended up with the number of AllReduce calls I did.

It turns out that PyTorch's DDP wrapper performs something called gradient bucketing. During the backward pass, it accumulates gradients into a flat buffer as they become available. Once the capacity of that buffer is exceeded, an AllReduce fires and the bucketed gradients are synchronised; the default bucket size is 25MB. Below we can observe the effect of varying the bucket size. While medium and large buckets perform similarly, if the bucket size becomes too small, the overhead of launching kernels dominates — using 1MB buckets was considerably slower. Note also that the parameters of a single, particularly large weight matrix can exceed the capacity of a bucket in a single increment — hence even in the small-bucket example some of the AllReduces are larger and take longer to complete. In particular, the first and last buckets correspond to the embedding/unembedding matrices, and these immediately fill the bucket for all three sizes tested.

(Note: the bottom row is a separate CUDA stream, and this is where we can see the collective operations being overlapped with the backward pass compute.)

Gradient bucketing — effect of varying bucket size on AllReduce patterns.
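
For reference, the bucket size is a constructor argument on the DDP wrapper. A minimal sketch of requesting 100MB buckets (model and rank come from the surrounding training code):

from torch.nn.parallel import DistributedDataParallel as DDP

# bucket_cap_mb sets how many megabytes of gradients are accumulated
# before an AllReduce is launched (default: 25)
distributed_model = DDP(model, device_ids=[rank], bucket_cap_mb=100)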

This flexibility in how we synchronise gradients stems from the fact that, in DDP, while gradients must be fully in sync before we can perform the optimiser step, the gradient tensors themselves are not on the critical path — meaning that they are not inputs into any of the subsequent backward pass computations; they are independent end-products in their own right. This (and many other concepts) is excellently laid out in Part 5 of The Scaling Book.

Interconnects — PCIe, DCI and Host Bridges

All other things being equal, the duration of an AllReduce should be inversely proportional to the bandwidth of the interconnect over which it is being performed. To see how well this holds, we can experiment with different types of interconnects. For example, one of the GPU slices I have access to is an 8×A100 slice, composed of two subgroups of 4 GPUs each. Within each subgroup, each GPU is connected to a CPU host via PCIe, but GPUs from different subgroups can only communicate over a (much slower) DCI connection between the CPU hosts, as shown below.

GPU topology diagram showing PCIe connections within subgroups and DCI host bridge connections between subgroups.

Therefore, a distributed training run on two GPUs within the same subgroup should be faster (owing to the decreased AllReduce durations) than an equivalent run performed on two GPUs from disparate subgroups (which requires data to egress the node and traverse the DCI host-to-host bridge).

We can test this by profiling a training run as usual. But we can also get a quicker, cleaner, lower-level insight into how the latency of the collectives themselves is affected by the networking configuration. NCCL comes with a suite of utilities for testing collective op latencies. One of these tools is:

all_reduce_perf -b 1M -e 1G -f 2 -g 2

where -b and -e set the minimum and maximum sizes of the message being AllReduced, -f sets the multiplicative factor by which the message size grows between iterations, and -g sets the number of GPUs to use in the test (i.e. here we sweep message sizes from 1MB to 1GB, doubling each time, with the communication between two GPUs).

First, we run nvidia-smi topo -m to verify which devices are PCIe connected (NODE) and which are only connected via the host (SYS):

       GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  NIC0  NIC1
GPU0    X    NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS
GPU1   NODE   X    NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS
GPU2   NODE  NODE   X    NODE  SYS   SYS   SYS   SYS   SYS   SYS
GPU3   NODE  NODE  NODE   X    SYS   SYS   SYS   SYS   SYS   SYS
GPU4   SYS   SYS   SYS   SYS    X    NODE  NODE  NODE  NODE  NODE
GPU5   SYS   SYS   SYS   SYS   NODE   X    NODE  NODE  NODE  NODE
GPU6   SYS   SYS   SYS   SYS   NODE  NODE   X    NODE  PHB   PHB
GPU7   SYS   SYS   SYS   SYS   NODE  NODE  NODE   X    NODE  NODE

Note how the block matrix reflects the configuration shown in the diagram above.

GPUs 0 and 1 (PCIe-connected):

Size (B)        Count         Time OOP (µs)  AlgBW OOP (GB/s)  BusBW OOP (GB/s)  Time IP (µs)  AlgBW IP (GB/s)  BusBW IP (GB/s)
8,388,608       2,097,152             580.2          14.46             14.46            632.5          13.26            13.26
16,777,216      4,194,304             947.6          17.71             17.71            926.0          18.12            18.12
33,554,432      8,388,608            1798.5          18.66             18.66           1802.0          18.62            18.62
67,108,864      16,777,216           3528.4          19.02             19.02           3520.6          19.06            19.06
134,217,728     33,554,432           6924.6          19.38             19.38           6914.9          19.41            19.41
268,435,456     67,108,864           13,698          19.60             19.60           13,699          19.59            19.59
536,870,912     134,217,728          27,104          19.81             19.81           27,090          19.82            19.82
1,073,741,824   268,435,456          53,866          19.93             19.93           54,288          19.78            19.78

Avg bus bandwidth: 18.51 GB/s

GPUs 3 and 7 (host bridge / DCI-connected):

Size (B)        Count         Time OOP (µs)  AlgBW OOP (GB/s)  BusBW OOP (GB/s)  Time IP (µs)  AlgBW IP (GB/s)  BusBW IP (GB/s)
8,388,608       2,097,152             679.7          12.34             12.34            643.6          13.03            13.03
16,777,216      4,194,304            1274.7          13.16             13.16           1250.0          13.42            13.42
33,554,432      8,388,608            2466.6          13.60             13.60           2482.7          13.52            13.52
67,108,864      16,777,216           4928.8          13.62             13.62           4898.2          13.70            13.70
134,217,728     33,554,432           9763.1          13.75             13.75           9762.8          13.75            13.75
268,435,456     67,108,864           19,506          13.76             13.76           19,517          13.75            13.75
536,870,912     134,217,728          38,960          13.78             13.78           38,888          13.81            13.81
1,073,741,824   268,435,456          77,783          13.80             13.80           77,824          13.80            13.80

Avg bus bandwidth: 13.54 GB/s

We will try to unpick the exact bandwidth numbers later, but for now we just note that communicating over the host bridge is slower, so we should expect the AllReduces to take longer on the latter device pair. Given that the AllReduce operations on the PCIe-connected pair are already nearly densely packed (see below), we expect these slower AllReduces to start adding to the overall step latency.

Even though, as previously mentioned, AllReduces are not on the critical path of the backward pass itself, they must finish before the optimiser step can be taken. Even in the PCIe case, the final AllReduce can only start after the backward pass has propagated all the way through to the embedding layer, and therefore extends beyond (i.e. is not overlapped with) backward compute. Worse, if many AllReduces are queued up and haven't finished by the time the embedding backward completes, the final AllReduce may not even start until a while after that. This is exactly the pattern seen in the traces below.

CUDA timeline showing AllReduce operations and their overlap (or lack thereof) with backward compute.

Even though the final AllReduce will necessarily introduce a gap between the end of the backward pass and the optimiser step, for the PCIe case in the example above this is relatively minor and more than made up for by the fact that our compute-limited training loop now does ~half the work per device, leading to a nearly 2× speed-up.

However, as the number of GPUs increases, the PCIe connection quickly becomes inadequate.

Increasing the Number of Devices

As the number of GPUs involved in the computation increases, the volume of gradient data each device must communicate also goes up: each gradient shard is \(\frac{2}{N}S\) bytes, and \(N-1\) shards (all but the one already present on the device) need to be communicated to each device. For \(N=2\) this amounts to \(S\) bytes per device; for \(N=4\) it is \(3 \times \frac{2}{4}S = 1.5S\) bytes, a factor of 1.5 more.

Moreover, host-mediated communication does not allow different pairs of devices to communicate in parallel at full interconnect bandwidth each: while in two-device DDP two comms streams shared the bus, in four-device DDP the same bus now serves all four streams, so each communication stream gets half as much bandwidth. Combining the two effects, we'd expect the AllReduce latency to increase by at least 3× for the four-device configuration.
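
Putting the two effects together (the comms volume from the previous paragraph, the bandwidth sharing from this one):

\[
\frac{t_{\text{AllReduce}}(N{=}4)}{t_{\text{AllReduce}}(N{=}2)}
\;\approx\;
\underbrace{\frac{1.5S}{S}}_{\text{more data per device}}
\times
\underbrace{2}_{\text{half the per-stream bandwidth}}
\;=\; 3.
\]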

In reality, the observed increase is even larger — likely due to imperfect handling of bandwidth contention on the shared bus.

The amount of backward-pass compute per device is also smaller when the same effective batch is shared between four devices, leaving less work to overlap with communication and leading to a prolonged stall waiting for the NCCL collectives to complete. Overall, this results in a significant decrease in training speed; here, increasing the number of devices is actually counterproductive.

Training performance comparison across device counts — adding GPUs hurts when communication bottlenecks on PCIe.

Sanity Checks

So far, we have been looking at trends: do collective ops get faster with increased bandwidth; does compute time reduce as work is shared between more GPUs. But before moving on, let us be more quantitative with respect to the numbers we are seeing. We know the maximum bandwidth numbers for individual configurations — but is our training loop anywhere near those numbers? How far can we push our hardware?

There is some overhead associated with launching and routing collective ops, so for the purpose of comparison it is best to amortise it by increasing the bucket size; we saw that 100MB is a sensible bucket size, so we will use that instead of the default 25MB.

Testing on two different GPU boxes, each supposedly identically networked in terms of NVLink count (12 per device), NVLink generation (3.0) and NVSwitch generation (2.0), we get different numbers for AllReducing a 100MB bucket: 1250µs and 800µs for the 40GB SKU and the 80GB SKU, respectively. This implies around 80 GB/s and 120 GB/s, respectively — neither anywhere close to the theoretical maximum of 300 GB/s. What gives?

AllReduce latency comparison between the two GPU boxes.

First, let us consider the differences between the boxes themselves. At face value, only the HBM capacity and HBM bandwidth are different — and neither of those two parameters should matter; HBM bandwidth is more than an order of magnitude larger than the transfer speeds in question, so writing to/from HBM is not the limiting factor here. Since the count and generation of the NVLinks are identical, the only remaining difference is the NVSwitch. While this reports as being the same generation as well, NVSwitch silicon varies within a generation, whereas, as far as I am aware, there are no in-generation differences between NVLink interconnects. This example shows that all components of the network must be considered when optimising throughput — the system is only as fast as the slowest link in the chain.

As for the absolute shortfall with respect to the theoretical maximum, this is a combination of several factors. Firstly, AllReduces observed in training suffer from contention with other operations taking place on-chip. We can see that this is the case by once again resorting to all_reduce_perf: when the op is performed in isolation, the same 100MB AllReduce reports 100 GB/s and 180 GB/s for the two GPU boxes.

Secondly, the 100MB bucket size is still relatively small; for bucket sizes of the order of 1GB and more, we get 130 GB/s and 205 GB/s, respectively.

Finally, AllReduces themselves require some computation — such as summing the local and arriving shards — and therefore are reliant on SM slots being available. The computations are not particularly heavy-weight, but they do add some scheduling contention and latency. To expose this difference, let us construct a bandwidth test that does not require any computation:

import torch
import time

# 4 GiB source tensor on GPU 0, destination buffer on GPU 1;
# a plain device-to-device copy involves no reduction compute
a = torch.ones(1024*1024*1024, dtype=torch.float32, device='cuda:0')
b = torch.empty(1024*1024*1024, dtype=torch.float32, device='cuda:1')

# warm up
for _ in range(5):
    b.copy_(a)

torch.cuda.synchronize('cuda:0')
torch.cuda.synchronize('cuda:1')

N = 100

start = time.perf_counter()

for _ in range(N):
    b.copy_(a)

torch.cuda.synchronize('cuda:0')
torch.cuda.synchronize('cuda:1')

elapsed = time.perf_counter() - start
bytes_moved = N * a.numel() * a.element_size()
print(f'{bytes_moved / elapsed / 1e9:.1f} GB/s')

This test reports 180 GB/s and 280 GB/s. For the more recent box, we nearly saturate the NVLink bandwidth; this suggests that the other box, networked with identical interconnects, must indeed be limited by a slower NVSwitch internal bandwidth.

So, what kind of speed-up can we get from going from a single GPU up to eight? Well, since almost all of the AllReduces are overlapped with compute, unsurprisingly we get a near-linear speed-up with the number of devices.

Near-linear scaling efficiency with NVSwitch — AllReduces largely overlapped with compute.

FSDP

In DDP, we feed identical replicas of the model different partitions of the training data, and synchronise the resulting gradients. For very large models, the model itself (alongside the associated gradients and optimiser states) can take up a significant proportion of the on-chip memory. Replicating all of this on each chip can leave little room for data, and since the size of the gradients which need to be AllReduced is fixed, this can lead to the ratio of compute to data transfer being insufficient to overlap the transfers and, consequently, the system being communication rather than compute limited.

To avoid this, we can distribute the model itself between GPUs. At any given point, each GPU only holds a fraction of the model’s weights, and when a particular chip requires the entire weights matrix to carry out a computation, chips communicate via the networking hardware to gather all the shards onto the chip in question. If the bandwidth is sufficient, this AllGather operation will be performed ahead of time, so that a continuous compute pipeline is maintained.

Also, since each GPU only holds a smaller proportion of weights instead of a full replica, only the gradients corresponding to the local weights are needed to perform the optimizer step, so there is no need to AllReduce the entire gradient — a ReduceScatter operation, which reduces and materialises just the gradients corresponding to the local shard of model parameters, is performed instead.
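
As a hand-rolled illustration of the two collectives FSDP leans on (flat, equal-sized 1-D shards are assumed for simplicity; in practice FSDP issues these for us):

import torch
import torch.distributed as dist

def gather_full_weight(local_shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # AllGather: every rank contributes its shard and ends up with the full flat weight
    full = torch.empty(local_shard.numel() * world_size,
                       dtype=local_shard.dtype, device=local_shard.device)
    dist.all_gather_into_tensor(full, local_shard)
    return full

def reduce_scatter_grad(full_grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # ReduceScatter: sum gradients across ranks, but materialise only the local shard
    local = torch.empty(full_grad.numel() // world_size,
                        dtype=full_grad.dtype, device=full_grad.device)
    dist.reduce_scatter_tensor(local, full_grad, op=dist.ReduceOp.SUM)
    return local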

In our case, however, the model itself only occupies a small fraction of the memory. We have ample space to fit a batch size large enough to mask the AllReduce communications. And while it is true that the final AllReduce cannot be overlapped with compute, and that in FSDP this AllReduce is replaced with a smaller ReduceScatter, FSDP also has to perform a separate operation that cannot (or at least not easily) be overlapped with compute: the initial AllGather of the first layer in the forward pass. All in all, in our case the difference is marginal.

DDP vs FSDP — performance comparison for our model size.

For completeness, here are some notes regarding FSDP:

  • We must specify how to shard the model; a sensible choice is to shard every layer, meaning with \(N\) devices each device holds \(\frac{1}{N}\)th of each layer, and during the computation of layer \(l\) the weights corresponding to layer \(l+1\) are being gathered. This means that at any given time only approximately \(\frac{2}{L} \times \text{model\_size}\) is materialised on each GPU. If we do not specify a sharding policy, PyTorch will fall back to wrapping the entire model as a single unit, meaning that the entire model will be gathered at the start of the forward pass, entirely defeating the purpose of applying FSDP. Sharding can be specified via the wrapping policy:
from functools import partial
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = partial(
    transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
)

distributed_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, use_orig_params=True)
  • Some operations used in the process of sharding/unsharding the model introduce graph breaks, so in order to torch.compile the sharded model we have to allow graph breaks by setting fullgraph=False.
  • FSDP has a mixed precision policy, settable via the mixed_precision argument during initialisation. This defaults to fp32, so some care is advisable when using autocast. FSDP and autocast are perfectly compatible, but, particularly if working in lower precision, it is worth checking that we are not casting dtypes back and forth for no good reason; see the sketch below.
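
A minimal sketch of setting the policy to bf16 (assuming the same auto_wrap_policy as above):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# keep parameters, gradient reductions and buffers in bf16
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

distributed_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=bf16_policy,
    use_orig_params=True,
)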

Gradient Checkpointing

By default, activations generated during the forward pass are cached for later use in the backward pass. While this does save compute — as otherwise those activations would have to be recomputed — the memory footprint of storing activations is significant. In fact, without any treatment, this footprint can get so large that it dwarfs every other memory allocation. To see this, let us carry out a quick back-of-the-envelope estimation of different memory requirements.

Component                                              Footprint
Model weights (120M × bf16)                            ~240 MB
Gradients (one per parameter)                          ~240 MB
Optimiser states (two momentum params per parameter)   ~480 MB
Total (model + gradients + optimiser)                  ~960 MB

And yet, when we attempt to run with a batch size of 512 and context length of 1024 (which, with DDP over 8 GPUs, gives a per-chip batch size of 64), we hit out-of-memory errors on an 80GB device.

In transformer layers, we have some fairly large matrix multiplication ops:

Layer                          Activation Shape      Size (bf16)
Attention Wqkv projection      (64, 1024, 768×3)     ~300 MB
Feed-forward Wup projection    (64, 1024, 3072)      ~400 MB
Vocab projection               (64, 1024, 50256)     ~6.6 GB

The attention and feed-forward activations scale with the number of layers (12 in our case), and even looking at just some of the larger activation tensors, we see that the required memory quickly grows to fill and exceed the chip’s memory.

Instead of saving all the activations, we can choose to save none at all and recompute everything, or to save some but not all of them. Recomputation will of course add overhead compared to a fully cached training run; however, since the memory required by fully cached runs is so disproportionately higher, reducing the amount of caching may let us fit a much larger batch on the chip, and thus increase token throughput overall. The caveat is that a larger batch size is not always what we want, and we should think carefully about whether this increased throughput is what we really care about. For now, we are only seeking to maximise throughput.

For simplicity, it is common to checkpoint the model per layer. We can achieve this with:

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

if args.checkpoint_grads:

    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    check_fn = lambda layer: isinstance(layer, TransformerBlock)

    apply_activation_checkpointing(
        distributed_model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=check_fn,
    )

This traverses the model and applies checkpointing to any nn.Module instance picked up by check_fn.

A more optimised treatment would be to cache operations that are expensive to compute but produce relatively small activations, and recompute otherwise. For example, in the feed-forward layer, the up- and down-projection require the same amount of computation, but the up projection produces activations that are four times larger than the down-projection. Similarly, the un-embedding produces an enormous matrix owing to the vocab_size — it may be worth experimenting with not caching those activations.

To be more selective about how we checkpoint, we can either perform some logic inside check_fn to verify the FLOPs/footprint of the individual module, or we can wrap layers we wish to checkpoint inside custom classes. For example, if we wish to checkpoint the vocab projection op:

import torch.nn as nn

class LayerWrapper(nn.Module):
    """Wrap the vocab projection in its own module so check_fn can select it for checkpointing."""

    def __init__(self, d_attention, vocab_size):
        super().__init__()
        self.large_projection = nn.Linear(d_attention, vocab_size)

    def forward(self, x):
        x = self.large_projection(x)
        return x

# Modify the lambda to also checkpoint LayerWrapper
check_fn = lambda layer: isinstance(layer, (TransformerBlock, LayerWrapper))

For our model specifically, though, even without checkpointing we can fit a batch large enough to saturate the kernels; using checkpointing to fit an even larger batch not only does not help (as the operations are executed sequentially anyway), it actually hurts throughput due to the recomputation. For completeness, below is an example of running with and without checkpointing and its effect on memory allocation and throughput.

Allocated GPU memory — with and without gradient checkpointing.

Throughput — with and without gradient checkpointing.