When training large language models, processing the vast amounts of required data on a single accelerator quickly becomes infeasible. Data parallelism offers a fundamental strategy to distribute this workload across multiple processing units, typically GPUs or TPUs, allowing you to train models faster and handle larger effective batch sizes.
The core idea behind data parallelism is simple: replicate the entire model on each available worker (device), feed each worker a different slice (shard) of the input data batch, and then combine the results after each step. Let's break down the typical workflow:
Data parallelism workflow: the global batch is split into shards; each worker computes gradients on its shard using its own model replica. Gradients are synchronized via AllReduce, and each worker then applies an identical update to its model copy.
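To make the synchronization step concrete, here is a minimal, hand-written sketch of a single data-parallel step. It assumes torch.distributed has already been initialized and that each worker already holds its own shard of the batch; the function name is purely illustrative. Frameworks perform these steps for you, and far more efficiently (for example, by overlapping communication with computation).

import torch
import torch.distributed as dist

def manual_data_parallel_step(model, local_batch, local_targets, optimizer, criterion):
    # 1. Each worker computes the loss and local gradients on its own shard.
    optimizer.zero_grad()
    loss = criterion(model(local_batch), local_targets)
    loss.backward()
    # 2. AllReduce: average the gradients across workers so every replica sees the same values.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
    # 3. Every worker applies the identical update, keeping the replicas in sync.
    optimizer.step()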
Frameworks like PyTorch (DistributedDataParallel, or DDP), TensorFlow (tf.distribute.Strategy), and Horovod abstract away much of the complexity of implementing data parallelism. DeepSpeed also builds upon these concepts, adding further optimizations.
A typical structure using PyTorch's DDP might look like this (simplified):
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# Example hyperparameters (illustrative values; adjust for your setup)
local_batch_size = 32
learning_rate = 1e-4
num_epochs = 3

def setup(rank, world_size):
    # Rendezvous information for single-node training (example values)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Initialize the process group (e.g., using the NCCL backend for GPUs)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train_step(rank, world_size, model, data_loader, optimizer, criterion):
    model.train()
    # Data sharding is handled by the DistributedSampler attached to the DataLoader
    for data, target in data_loader:
        data, target = data.to(rank), target.to(rank)  # Move data to this worker's device
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # loss.backward() computes local gradients; DDP triggers the AllReduce during
        # the backward pass, so gradients are synchronized across workers here
        loss.backward()
        # optimizer.step() updates local model parameters using the synchronized gradients
        optimizer.step()

def main_worker(rank, world_size, model_definition, dataset):
    setup(rank, world_size)
    model = model_definition().to(rank)
    # Wrap the model with DDP
    ddp_model = DDP(model, device_ids=[rank])
    # Use DistributedSampler so each worker sees a distinct shard of the dataset
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=local_batch_size, sampler=sampler)
    optimizer = optim.AdamW(ddp_model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()  # Example criterion
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Reshuffle shards each epoch
        train_step(rank, world_size, ddp_model, data_loader, optimizer, criterion)
        # Add validation, checkpointing, etc.
    cleanup()

# --- Main execution logic to spawn one process per GPU ---
# if __name__ == "__main__":
#     world_size = torch.cuda.device_count()
#     mp.spawn(main_worker, args=(world_size, model_def, dataset), nprocs=world_size, join=True)
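Here mp.spawn launches one process per GPU, each running main_worker with its own rank. The same structure also works with the torchrun launcher, with small changes: torchrun starts the processes for you and supplies the rank and world size through environment variables, which you would read instead of passing them via mp.spawn.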
Key operational points arise when using data parallelism at scale. A common one is memory: when the desired effective batch size cannot be reached by increasing local_batch_size, because each worker's device memory is limited, gradient accumulation is used. Each worker processes several smaller "micro-batches" sequentially, accumulating gradients locally before performing the AllReduce and optimizer step. This simulates a larger local batch size without increasing memory requirements, trading extra computation time for memory. The AllReduce synchronization happens only once per N micro-batches, where N is the number of accumulation steps (a short sketch follows below).

Data parallelism is often the starting point for distributed training due to its relative simplicity and effectiveness, especially when the computation per data point is high. However, when models grow too large for a single device's memory, you must combine data parallelism with model parallelism techniques, which we will discuss next.
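As a closing illustration, here is a minimal sketch of the gradient accumulation pattern referenced above, reusing the names from the DDP example. The function name and the accumulation_steps value are illustrative; DDP's no_sync() context manager is used to skip the AllReduce on all but the last micro-batch of each accumulation window.

import contextlib

def train_step_with_accumulation(rank, ddp_model, data_loader, optimizer, criterion,
                                 accumulation_steps=4):
    ddp_model.train()
    optimizer.zero_grad()
    for i, (data, target) in enumerate(data_loader):
        data, target = data.to(rank), target.to(rank)
        # Synchronize (AllReduce) only on the last micro-batch of each accumulation window
        is_sync_step = (i + 1) % accumulation_steps == 0
        # no_sync() suppresses gradient synchronization for this backward pass
        context = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()
        with context:
            output = ddp_model(data)
            # Scale the loss so the accumulated gradient averages over the micro-batches
            loss = criterion(output, target) / accumulation_steps
            loss.backward()  # gradients accumulate locally; AllReduce fires only on sync steps
        if is_sync_step:
            optimizer.step()       # update using the synchronized, accumulated gradients
            optimizer.zero_grad()  # reset gradients for the next accumulation window

With this pattern, the effective global batch size becomes local_batch_size x accumulation_steps x world_size.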