Running pmap-ped Functions

Let's put theory into practice and see how jax.pmap works. In this section, we'll walk through distributing a simple computation across available devices. Remember, pmap operates under the Single Program, Multiple Data (SPMD) principle: the same Python function code runs on all participating devices, but each device gets its own slice of the input data.
For these examples, JAX will automatically use all available devices (CPU cores if no GPU/TPU is found, or all available GPUs/TPUs). You can check the number of devices JAX recognizes using jax.device_count().
import jax
import jax.numpy as jnp
import numpy as np
# Check available devices
num_devices = jax.device_count()
print(f"Number of devices available: {num_devices}")
# Get the list of devices
devices = jax.devices()
print(f"Available devices: {devices}")
If you are running this on a machine with only a CPU, jax.device_count() might return 1 or the number of CPU cores JAX is configured to see as separate devices (often 1 by default unless configured otherwise). If you have multiple GPUs or are on a TPU Pod slice, it will report the number of available accelerators. pmap demonstrates its real power when num_devices > 1. Even with one device, the code will run, but without parallel execution across devices.
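If you only have a CPU and still want to experiment with more than one device, a common trick is to ask XLA to expose several logical CPU devices. The sketch below is optional; it assumes the flag is set before JAX is imported for the first time in the process, and the count of 8 is arbitrary.

# Optional: simulate multiple devices on a CPU-only machine for experimentation.
# This must run before the first `import jax` in the process.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.device_count())  # should now report 8 logical CPU devices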
Let's define a simple function we want to run in parallel: scaling an array by 2.
# A simple function to apply element-wise
def scale_by_two(x):
    print("Compiling and running scale_by_two...")  # To see when JAX traces/runs
    return x * 2
Now, let's create some data. For pmap, the input data needs to be sharded (split) across devices. The easiest way is to ensure the input array has a leading axis whose size equals the number of devices. Each slice along this axis will be sent to a corresponding device.
# Create data that can be split across devices
# Let's create an array with a leading dimension equal to num_devices
data_size_per_device = 4
total_data_size = num_devices * data_size_per_device
global_data = jnp.arange(total_data_size)
# Reshape the data so the first dimension matches the number of devices
# Each row will be sent to one device
sharded_data = global_data.reshape((num_devices, data_size_per_device))
print(f"Global data shape: {global_data.shape}")
print(f"Sharded data shape: {sharded_data.shape}")
print(f"Sharded data:\n{sharded_data}")
Next, we transform our function using jax.pmap. By default, pmap assumes the first axis (axis 0) of the input arguments should be mapped across devices.
# Apply pmap to our function
parallel_scale = jax.pmap(scale_by_two)
# Run the parallel function on the sharded data
result = parallel_scale(sharded_data)
# Let's examine the result
print(f"\nOutput shape: {result.shape}")
print(f"Output type: {type(result)}")
print(f"Output content:\n{result}")
You should observe:
- result has the same shape as the input sharded_data. It's a JAX ShardedDeviceArray (or a similar type, depending on the JAX version), representing data distributed across the devices.
- Each row is scale_by_two applied to the corresponding input slice on each device. For instance, if num_devices=2, the first row of result is [0*2, 1*2, 2*2, 3*2], computed on device 0, and the second row is [4*2, 5*2, 6*2, 7*2], computed on device 1.
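If you want to confirm the placement yourself, recent JAX versions (where pmap returns a unified jax.Array) expose this information directly on the result; older versions use different attributes, so treat this as an optional, version-dependent check.

# Optional, version-dependent check of how the output is placed across devices.
print(result.sharding)   # describes how rows of `result` are split across devices
print(result.devices())  # the set of devices holding pieces of `result`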
in_axes

What if our function takes multiple arguments, and we only want to parallelize over some of them, or parallelize over different axes? This is where the in_axes argument to pmap comes in.

in_axes specifies which axis of each input argument should be mapped over. It can be an integer (the axis index), None (broadcast the argument to all devices), or a tuple/list matching the number of arguments.
Let's modify our function to take two arguments: an array x to be sharded and a scalar y to be broadcast.
# Function taking two arguments
def scale_and_add(x, y):
    # This print runs at trace time on the host; jax.process_index() is the host/process index.
    print(f"Compiling and running scale_and_add in process {jax.process_index()}...")
    return x * 2 + y
# Data remains the same for x
# sharded_data = jnp.arange(num_devices * data_size_per_device).reshape((num_devices, data_size_per_device))
# Scalar value to be broadcasted
scalar_y = jnp.float32(100.0)
# Apply pmap, specifying how each argument is handled
# x (sharded_data): map over axis 0
# y (scalar_y): broadcast (send the same value to all devices)
parallel_scale_add = jax.pmap(scale_and_add, in_axes=(0, None))
# Run the parallel function
result_add = parallel_scale_add(sharded_data, scalar_y)
print(f"\nInput x shape: {sharded_data.shape}")
print(f"Input y: {scalar_y}")
print(f"\nOutput shape: {result_add.shape}")
print(f"Output content:\n{result_add}")
Here, in_axes=(0, None) tells pmap:
- For the first argument (x), split it along axis 0 and send each slice to a device.
- For the second argument (y), send the entire value scalar_y to every device.

The result result_add will again have the shape (num_devices, data_size_per_device), with each element calculated as x_slice * 2 + 100.0 on its respective device.
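The paragraph above also mentions parallelizing over a different axis. As a small optional sketch (reusing scale_by_two and num_devices from earlier, with an illustrative array), passing in_axes=1 tells pmap to treat the second axis as the device axis; by default the mapped axis reappears as axis 0 of the output.

# Optional sketch: map over axis 1 instead of axis 0.
data_axis1 = jnp.arange(3 * num_devices).reshape((3, num_devices))  # device axis is axis 1
scale_over_axis1 = jax.pmap(scale_by_two, in_axes=1)
result_axis1 = scale_over_axis1(data_axis1)
print(result_axis1.shape)  # (num_devices, 3): the mapped axis becomes axis 0 of the output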
lax.psum
A common requirement in parallel computing is to aggregate results from different devices, for example calculating the total sum or average of a value computed independently on each device. JAX provides collective primitives in jax.lax for this purpose. These collectives only work inside a pmap-ped function.
Let's compute the sum of all elements processed across all devices.
# Function that includes a collective operation (sum across devices)
def scale_and_sum(x):
    # This print runs at trace time on the host; jax.process_index() is the host/process index.
    print(f"Compiling and running scale_and_sum in process {jax.process_index()}...")
    scaled_x = x * 2
    # Calculate the sum *on each device* first
    local_sum = jnp.sum(scaled_x)
    # Now, sum the local sums from all devices
    global_sum = jax.lax.psum(local_sum, axis_name='devices')
    # Note: every device now holds the same 'global_sum'
    return scaled_x, global_sum
# We need to tell pmap about the axis name used in psum
# The name 'devices' is arbitrary but must match between pmap and psum
parallel_scale_sum = jax.pmap(scale_and_sum, axis_name='devices')
# Run the parallel function
sharded_data = jnp.arange(num_devices * data_size_per_device).reshape((num_devices, data_size_per_device))
result_scaled, result_sum = parallel_scale_sum(sharded_data)
# Calculate the expected global sum manually for verification
expected_sum = jnp.sum(jnp.arange(total_data_size) * 2)
print(f"\nOutput scaled data shape: {result_scaled.shape}")
print(f"Output scaled data:\n{result_scaled}") # Still sharded
print(f"\nOutput sum shape: {result_sum.shape}") # Should be replicated
print(f"Output sum (replicated across devices):\n{result_sum}")
print(f"Expected global sum: {expected_sum}")
# Verify that the sum is the same on all devices
# (Accessing the data from the first device for demonstration)
print(f"Sum computed by pmap (device 0): {result_sum[0]}")
Key points in this example:
- We pass axis_name='devices' to jax.pmap. This name logically groups the devices participating in the parallel computation.
- jax.lax.psum(local_sum, axis_name='devices') performs the sum reduction. It takes the local_sum computed independently on each device and sums these values across all devices grouped under the name 'devices'.
- The global_sum returned by psum is replicated across all devices. Notice the shape of result_sum is (num_devices,), and all its elements are identical, holding the true total sum. The result_scaled array remains sharded as before.

These hands-on examples illustrate the core mechanics of using jax.pmap for data parallelism: preparing sharded data, mapping function arguments using in_axes, executing the function in parallel, and using collective operations like psum for cross-device communication. This forms the basis for scaling JAX computations, particularly large machine learning model training, across multiple accelerators.
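As a pointer toward that training use case, here is a minimal, self-contained sketch (not part of the example above) of the common pattern of averaging per-device gradients with jax.lax.pmean. The toy loss_fn, parameter shape, and batch sizes are illustrative assumptions.

import functools
import jax
import jax.numpy as jnp

num_devices = jax.device_count()

# Illustrative toy loss for a linear model; purely for demonstration.
def loss_fn(params, batch):
    inputs, targets = batch
    preds = inputs @ params
    return jnp.mean((preds - targets) ** 2)

# Broadcast params (in_axes=None), shard the batch (in_axes=0), and average the
# per-device gradients with lax.pmean so every device ends up with the same result.
@functools.partial(jax.pmap, axis_name='devices', in_axes=(None, 0))
def averaged_grads(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    return jax.lax.pmean(grads, axis_name='devices')

params = jnp.zeros(3)
batch = (jnp.ones((num_devices, 8, 3)), jnp.ones((num_devices, 8)))
print(averaged_grads(params, batch).shape)  # (num_devices, 3), with identical rows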