While jax.vmap automatically handles adding a batch dimension, you often need more control. What if only some arguments represent batches of data? What if your data isn't conveniently batched along the first axis (axis 0)? This is where the in_axes and out_axes arguments come into play. They provide fine-grained control over how vmap transforms your function's inputs and outputs.
in_axes
The in_axes argument tells vmap which axis of each input argument should be mapped over (vectorized). It's typically provided as a tuple or list whose length matches the number of positional arguments of the function being vectorized.
Each element in the in_axes tuple corresponds to an argument:
- An integer i means that the i-th axis of the corresponding argument is the batch dimension; vmap will effectively iterate over slices along this axis.
- None means that the corresponding argument should not be mapped over. Instead, the entire argument is broadcast and reused across all vectorized calls. This is useful for parameters or constants that are shared across the batch.
Let's look at an example. Suppose we have a function that adds a scalar value to each element of a vector:
import jax
import jax.numpy as jnp
def add_scalar(vector, scalar):
    # Adds a scalar to every element of a vector
    return vector + scalar
# Example data
vectors = jnp.arange(12).reshape(4, 3) # A batch of 4 vectors, each size 3
scalar_val = 100.0 # A single scalar value
If we want to apply add_scalar to each vector in our vectors batch, using the same scalar_val for every vector, we tell vmap to map over axis 0 of vectors but not to map over scalar_val:
# Map over axis 0 of the first argument (vectors)
# Broadcast the second argument (scalar_val)
vectorized_add_scalar = jax.vmap(add_scalar, in_axes=(0, None))
result = vectorized_add_scalar(vectors, scalar_val)
print("Input vectors (shape {}):\n{}".format(vectors.shape, vectors))
print("Input scalar:", scalar_val)
print("Result (shape {}):\n{}".format(result.shape, result))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Input scalar: 100.0
Result (shape (4, 3)):
[[100. 101. 102.]
[103. 104. 105.]
[106. 107. 108.]
[109. 110. 111.]]
As you can see, vmap applied add_scalar four times. In each application, it took one row (axis 0) from vectors and the entire scalar_val. The output result collects these individual results, stacked along axis 0, matching the input batch dimension.
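Conceptually, this is equivalent to calling the original function once per row and stacking the results, though vmap does not actually loop in Python; it emits a single batched operation. A quick sketch using the arrays defined above:
# Semantically equivalent explicit loop over the batch dimension
looped = jnp.stack([add_scalar(vectors[i], scalar_val) for i in range(vectors.shape[0])])
print(jnp.allclose(looped, result))  # True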
What if we had a batch of scalars as well, and wanted to add the i-th scalar to the i-th vector? We would specify in_axes=(0, 0):
scalars = jnp.array([100., 200., 300., 400.]) # A batch of 4 scalars
# Map over axis 0 of the first argument (vectors)
# Map over axis 0 of the second argument (scalars)
vectorized_add_scalar_batch = jax.vmap(add_scalar, in_axes=(0, 0))
result_batch = vectorized_add_scalar_batch(vectors, scalars)
print("Input vectors (shape {}):\n{}".format(vectors.shape, vectors))
print("Input scalars (shape {}):\n{}".format(scalars.shape, scalars))
print("Result (shape {}):\n{}".format(result_batch.shape, result_batch))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Input scalars (shape (4,)):
[100. 200. 300. 400.]
Result (shape (4, 3)):
[[100. 101. 102.]
[203. 204. 205.]
[306. 307. 308.]
[409. 410. 411.]]
Notice that JAX automatically handled the broadcasting of the scalar scalars[i] across the elements of vectors[i] within each mapped function call.
You can also map over axes other than 0. For instance, in_axes=(1, None) would map over axis 1 of the first argument. This requires the shapes to align: the mapped axes of all mapped input arguments must have the same size, and JAX will raise an error if they don't match, as the sketch below shows.
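For example, reusing vectors and add_scalar from above with a batch of only three scalars (current JAX versions raise a ValueError here):
# Mapped axis sizes disagree: 3 scalars vs. 4 vectors
mismatched_scalars = jnp.array([10., 20., 30.])
try:
    jax.vmap(add_scalar, in_axes=(0, 0))(vectors, mismatched_scalars)
except ValueError as err:
    print("vmap error:", err)
With matching sizes, mapping over an axis other than 0 works as expected: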
# Example: Map over axis 1 of vectors
vectors_T = vectors.T # Shape (3, 4)
# Map over axis 1 (columns) of vectors_T, broadcast scalar_val
vectorized_add_scalar_axis1 = jax.vmap(add_scalar, in_axes=(1, None))
result_axis1 = vectorized_add_scalar_axis1(vectors_T, scalar_val)
print("Input vectors_T (shape {}):\n{}".format(vectors_T.shape, vectors_T))
print("Input scalar:", scalar_val)
# The output batch dimension (size 4) will be axis 0 by default
print("Result (shape {}):\n{}".format(result_axis1.shape, result_axis1))
Input vectors_T (shape (3, 4)):
[[ 0 3 6 9]
[ 1 4 7 10]
[ 2 5 8 11]]
Input scalar: 100.0
Result (shape (4, 3)):
[[100. 101. 102.]
[103. 104. 105.]
[106. 107. 108.]
[109. 110. 111.]]
Even though we mapped over axis 1 of the input vectors_T, the resulting batch dimension in the output is axis 0 by default. We can control this using out_axes.
out_axes
By default, vmap stacks the results along axis 0. The out_axes argument lets you specify which axis of the output should correspond to the mapped dimension.
Let's consider a function that processes a vector and returns a transformed vector:
def process_vector(v):
    # Example: Double the vector
    return v * 2
input_vectors = jnp.arange(12).reshape(4, 3) # Batch of 4 vectors
Using the default out_axes=0:
# Default: map input axis 0 to output axis 0
vectorized_process_default = jax.vmap(process_vector, in_axes=0, out_axes=0)
result_default = vectorized_process_default(input_vectors)
print("Input vectors (shape {}):\n{}".format(input_vectors.shape, input_vectors))
print("Result (out_axes=0, shape {}):\n{}".format(result_default.shape, result_default))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Result (out_axes=0, shape (4, 3)):
[[ 0 2 4]
[ 6 8 10]
[12 14 16]
[18 20 22]]
The output shape is (4, 3), where 4 is the batch dimension placed at axis 0.
Now, let's specify out_axes=1:
# Map input axis 0 to output axis 1
vectorized_process_out1 = jax.vmap(process_vector, in_axes=0, out_axes=1)
result_out1 = vectorized_process_out1(input_vectors)
print("Input vectors (shape {}):\n{}".format(input_vectors.shape, input_vectors))
print("Result (out_axes=1, shape {}):\n{}".format(result_out1.shape, result_out1))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Result (out_axes=1, shape (3, 4)):
[[ 0 6 12 18]
[ 2 8 14 20]
[ 4 10 16 22]]
The output shape is now (3, 4). The original vector dimension (size 3) is now axis 0, and the mapped batch dimension (size 4) has been placed at axis 1.
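If you want the output to keep the same layout as the input, you can combine the two arguments. As a sketch reusing vectors_T, scalar_val, and add_scalar from earlier, mapping over input axis 1 and writing the batch to output axis 1 yields a result with the same (3, 4) layout as vectors_T:
# Map input axis 1 and place the batch along output axis 1
vectorized_same_layout = jax.vmap(add_scalar, in_axes=(1, None), out_axes=1)
result_same_layout = vectorized_same_layout(vectors_T, scalar_val)
print(result_same_layout.shape)  # (3, 4), matching vectors_T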
out_axes with Multiple Return Values (PyTrees)
If your function returns multiple values (e.g., in a tuple or dictionary, which JAX treats as PyTrees), out_axes can also be a PyTree structure matching the output. This allows you to specify different output axes for different return values.
def process_vector_pytree(v):
    # Returns a dictionary with sum and doubled vector
    return {'sum': v.sum(), 'doubled': v * 2}
# Map input axis 0. Place 'sum' batch axis at 0, 'doubled' batch axis at 1.
vectorized_pytree = jax.vmap(
    process_vector_pytree,
    in_axes=0,
    out_axes={'sum': 0, 'doubled': 1}
)
result_pytree = vectorized_pytree(input_vectors)
print("Input vectors (shape {}):\n{}".format(input_vectors.shape, input_vectors))
print("Result PyTree:")
print(" Sum (shape {}):\n{}".format(result_pytree['sum'].shape, result_pytree['sum']))
print(" Doubled (shape {}):\n{}".format(result_pytree['doubled'].shape, result_pytree['doubled']))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Result PyTree:
Sum (shape (4,)):
[ 3. 12. 21. 30.]
Doubled (shape (3, 4)):
[[ 0 6 12 18]
[ 2 8 14 20]
[ 4 10 16 22]]
Here, the batch of sums has shape (4,) (batch axis 0), while the batch of doubled vectors has shape (3, 4) (batch axis 1), exactly as specified in out_axes.
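As a quick sanity-check sketch, each batched entry lines up with the unbatched function applied to a single row; because out_axes placed the 'doubled' batch on axis 1, row i of the input corresponds to column i of that output:
# Compare one row's unbatched result against the batched outputs
single = process_vector_pytree(input_vectors[2])
print(single['sum'], result_pytree['sum'][2])                           # both are 21
print(jnp.allclose(single['doubled'], result_pytree['doubled'][:, 2]))  # True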
Combining in_axes and out_axes
You will frequently use in_axes and out_axes together to precisely control the vectorization process. This combination provides the flexibility needed to adapt functions written for single inputs to complex batching scenarios without rewriting the core logic or resorting to manual dimension shuffling. By specifying which input axes to map and where the resulting batch dimension should appear in the output, you can write cleaner and often more efficient JAX code for batched computations. The sketch below combines these ideas in one example.
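As an illustrative sketch of such a combination (the affine function, weights, and data below are hypothetical, not taken from the earlier examples), consider applying a shared weight matrix and bias to a batch of feature vectors stored column-wise, while placing the batch along axis 0 of the output:
import jax
import jax.numpy as jnp

def affine(W, b, x):
    # Apply a shared affine transform to a single feature vector
    return W @ x + b

W = jnp.ones((2, 3))                # Shared weights, shape (2, 3)
b = jnp.zeros(2)                    # Shared bias, shape (2,)
X = jnp.arange(12.0).reshape(3, 4)  # 4 feature vectors of size 3, batched along axis 1

# Broadcast W and b, map over axis 1 of X, stack results along axis 0
batched_affine = jax.vmap(affine, in_axes=(None, None, 1), out_axes=0)
Y = batched_affine(W, b, X)
print(Y.shape)  # (4, 2): one size-2 output per column of X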