While data parallelism using `pmap`, combined with techniques like gradient accumulation and checkpointing, allows scaling to larger datasets and effective batch sizes, we eventually encounter another fundamental limit: the model's parameters and intermediate activations become too large to fit into the memory of a single accelerator device. Even with mixed precision, the sheer size of state-of-the-art models can exceed the capacity of the largest available GPUs or TPUs.
When a single replica of your model cannot reside on one device, you need to partition the model itself across multiple devices. This is the core idea behind model parallelism. Unlike data parallelism, where the same model runs on different data slices, model parallelism involves splitting the model's components (layers or even parts of layers) and assigning them to different accelerators. Data flows between these devices as it progresses through the partitioned model.
There are two primary strategies for implementing model parallelism: tensor parallelism and pipeline parallelism.
Tensor parallelism focuses on splitting the computation within a single layer (or operation) across multiple devices. This is particularly effective for layers with massive weight matrices, like the large linear layers found in Transformers.
Consider a large matrix multiplication, a fundamental operation in neural networks: Y=XW. If the weight matrix W is too large for one device, we can partition it across devices. For instance, W could be split column-wise across two devices, W=[W1,W2]. The input X is sent to both devices. Each device computes a part of the output: Y1=XW1 on device 1 and Y2=XW2 on device 2. The final result Y is obtained by concatenating the partial results: Y=[Y1,Y2].
A simplified view of tensor parallelism applied to matrix multiplication Y=XW. The weight matrix W is split into W1 and W2, and computations XW1 and XW2 occur on separate devices. The input X is duplicated, and the partial results Y1 and Y2 are gathered.
Alternatively, W could be split row-wise, requiring communication (e.g., an all-reduce operation) to sum partial results. More complex operations, like attention mechanisms, can also be parallelized this way.
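As a concrete illustration, the following sketch shards a weight matrix column-wise with `jax.pmap`, replicates the input, and gathers the partial outputs. It is a minimal demonstration under assumed shapes, not a production layer, and it assumes at least two local devices are available.

```python
import jax
import jax.numpy as jnp

# Minimal column-wise tensor parallelism sketch (assumes >= 2 local devices).
n_dev = 2
key = jax.random.PRNGKey(0)
kx, kw = jax.random.split(key)

X = jax.random.normal(kx, (8, 512))       # input activations
W = jax.random.normal(kw, (512, 1024))    # weight matrix to partition

# Split W column-wise: device i holds W_i of shape (512, 1024 // n_dev).
W_shards = jnp.stack(jnp.split(W, n_dev, axis=1))       # (n_dev, 512, 512)
X_replicated = jnp.broadcast_to(X, (n_dev,) + X.shape)  # copy X to every device

# Each device computes its partial output Y_i = X @ W_i.
partial_matmul = jax.pmap(lambda x, w: x @ w)
Y_parts = partial_matmul(X_replicated, W_shards)        # (n_dev, 8, 512)

# Gather: concatenate the per-device outputs to form Y = [Y1, Y2].
Y = jnp.concatenate(list(Y_parts), axis=1)              # (8, 1024)
assert jnp.allclose(Y, X @ W, atol=1e-2)  # loose tolerance for accelerator matmuls
```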
Challenges: tensor parallelism requires collective communication on every forward and backward pass, so it works best when the participating devices share a high-bandwidth interconnect. It also ties the partitioning scheme to the structure of each layer, which adds implementation complexity.
In JAX, tensor parallelism typically involves using `pmap` over a subset of devices assigned to the parallel layer, often combined with collective communication primitives like `jax.lax.psum` or `jax.lax.all_gather` applied over specific `axis_name` dimensions defined in the `pmap`. Libraries built on JAX may offer abstractions for common tensor parallelism patterns.
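For the row-wise variant mentioned above, a sketch like the following shards the input's columns and the weight's rows, then uses `jax.lax.psum` over the `pmap` axis to sum the partial products. It assumes the number of local devices evenly divides the hidden dimension used here.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Row-wise tensor parallelism sketch: assumes local_device_count() divides 512.
n_dev = jax.local_device_count()
key = jax.random.PRNGKey(0)
kx, kw = jax.random.split(key)

X = jax.random.normal(kx, (8, 512))
W = jax.random.normal(kw, (512, 1024))

# Device i holds a slice of X's columns and the matching rows of W.
X_shards = jnp.stack(jnp.split(X, n_dev, axis=1))   # (n_dev, 8, 512 // n_dev)
W_shards = jnp.stack(jnp.split(W, n_dev, axis=0))   # (n_dev, 512 // n_dev, 1024)

@partial(jax.pmap, axis_name="tp")
def row_parallel_matmul(x, w):
    partial_y = x @ w                     # per-device partial product, (8, 1024)
    # All-reduce across the "tp" axis so every device ends up with the full Y.
    return jax.lax.psum(partial_y, axis_name="tp")

Y = row_parallel_matmul(X_shards, W_shards)[0]   # result is identical on each device
assert jnp.allclose(Y, X @ W, atol=1e-2)
```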
Pipeline parallelism takes a different approach: it partitions the model between layers, assigning sequential stages or blocks of layers to different devices. The output activations from one stage (on one device) become the input for the next stage (on another device).
Imagine a model with four layers (L1, L2, L3, L4) distributed across four devices (D1, D2, D3, D4).
A naive implementation processes one batch sequentially: D1 computes L1, sends activations to D2; D2 computes L2, sends to D3; D3 computes L3, sends to D4; D4 computes L4. During this process, only one device is active at any given time, leading to significant idle time, often called the "pipeline bubble".
Simple pipeline parallelism: Layers are assigned to sequential devices. Activations are passed between devices.
To improve efficiency and reduce the bubble, micro-batching is used. The input batch is split into smaller micro-batches. As soon as Device 1 finishes processing the first micro-batch, it sends the activations to Device 2 and immediately starts processing the second micro-batch. This allows multiple devices to work concurrently on different micro-batches, filling the pipeline.
Pipeline parallelism with micro-batching (MB1, MB2, etc.). Devices operate concurrently on different micro-batches, reducing idle time (represented by '-'). The initial and final bubbles remain but are amortized over many micro-batches.
Challenges: even with micro-batching, some pipeline bubble remains at the start and end of each batch. The stages must be balanced so that no single device becomes a bottleneck, and activations for in-flight micro-batches must be kept in memory until the corresponding backward passes run.
In JAX, pipeline parallelism often involves manually placing different parts of the computation onto specific devices using `jax.device_put` or backend-specific placement mechanisms. Communication between stages typically uses point-to-point transfers, which might be abstracted by libraries or require lower-level primitives depending on the setup (single-host vs. multi-host). Managing the micro-batching and scheduling often falls to the user or a higher-level framework.
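As a rough illustration of this manual approach, the sketch below places two hypothetical stages on separate devices with `jax.device_put` and streams micro-batches through them. It assumes at least two local devices and deliberately ignores the scheduling needed to overlap stages efficiently.

```python
import jax
import jax.numpy as jnp

# Minimal two-stage pipeline sketch (assumes >= 2 local devices).
dev0, dev1 = jax.local_devices()[:2]

key = jax.random.PRNGKey(0)
k1, k2, kx = jax.random.split(key, 3)

# Stage 1 parameters live on device 0, stage 2 parameters on device 1.
W1 = jax.device_put(jax.random.normal(k1, (256, 512)), dev0)
W2 = jax.device_put(jax.random.normal(k2, (512, 128)), dev1)

@jax.jit
def stage1(x, w):
    return jax.nn.relu(x @ w)

@jax.jit
def stage2(h, w):
    return h @ w

# Split the batch into micro-batches and feed them through the pipeline.
batch = jax.random.normal(kx, (32, 256))
micro_batches = jnp.split(batch, 4)          # 4 micro-batches of 8 examples each

outputs = []
for mb in micro_batches:
    mb = jax.device_put(mb, dev0)
    h = stage1(mb, W1)                       # runs on device 0 (committed inputs)
    h = jax.device_put(h, dev1)              # point-to-point transfer of activations
    outputs.append(stage2(h, W2))            # runs on device 1

y = jnp.concatenate(outputs)                 # (32, 128)
```

Because JAX dispatches work asynchronously, device 0 can begin the next micro-batch while device 1 is still busy, but a real pipeline schedule (and the backward pass) requires considerably more orchestration than this sketch shows.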
In practice, complex large-scale training setups often combine these strategies: for example, data parallelism across groups of devices, tensor parallelism within each group, and pipeline parallelism across sequential stages, an arrangement sometimes referred to as 3D parallelism.
Implementing these advanced model parallelism strategies directly in JAX using only primitives like `pmap`, `lax.p*` collectives, and device placement is feasible but demanding. It requires a deep understanding of device topology, communication patterns, and careful orchestration. For this reason, practitioners often rely on higher-level libraries within the JAX ecosystem (such as extensions to Flax or specialized libraries built for large models) that provide abstractions for common tensor and pipeline parallelism patterns, making them easier to apply.
While this section provides an overview, actually implementing robust and efficient model parallelism often involves using these specialized libraries or investing significant engineering effort to build custom solutions tailored to the specific model architecture and hardware setup. Understanding these strategies, however, is essential for choosing the right approach and tools when data parallelism alone is insufficient.