Often in numerical computation and machine learning, you need to apply the same function to multiple data points simultaneously, a process commonly referred to as batching. While you could write explicit loops or rely on functions designed for batch operations, these approaches can sometimes be verbose or require careful manual management of array dimensions.
JAX offers a transformation, jax.vmap, designed specifically for automatic vectorization. It allows you to take a function written to operate on a single data point and efficiently apply it across an entire batch (or multiple batches) of data, often without needing to rewrite the original function's logic. vmap effectively adds a "batch dimension" to your computations automatically.
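To make this concrete, here is a minimal sketch comparing a manual loop with vmap; the function squared_norm and its input shapes are illustrative assumptions, not an example defined later in this chapter:

```python
import jax
import jax.numpy as jnp

# A function written for a single data point (a 1-D vector).
def squared_norm(x):
    return jnp.sum(x ** 2)

batch = jnp.arange(12.0).reshape(4, 3)   # a batch of 4 vectors

# Manual batching: loop over the leading axis and stack the results.
looped = jnp.stack([squared_norm(x) for x in batch])

# Automatic batching: vmap maps squared_norm over axis 0 by default.
vectorized = jax.vmap(squared_norm)(batch)

print(jnp.allclose(looped, vectorized))  # True
```

Both versions produce the same result, but the vmap version keeps the original single-example function untouched and lets JAX generate the batched computation for you.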
In this chapter, you will learn:
- How to use jax.vmap to vectorize functions operating on single or multiple arguments.
- How to control which axes are mapped using the in_axes and out_axes arguments.
- How to handle multiple batched arguments and nest vmap calls.
- How vmap interacts with other JAX transformations like jit and grad (a short preview follows this introduction).
- How to use vmap effectively, including performance considerations.

By the end of this chapter, you'll be able to use vmap to simplify and often accelerate your batch processing code in JAX.
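As a short preview of how these pieces fit together, the sketch below combines in_axes with jit and grad for a toy per-example loss; the loss function, parameter vector, and arrays are illustrative assumptions rather than the chapter's own running example:

```python
import jax
import jax.numpy as jnp

# A per-example loss for a tiny linear model: shared params, one (x, y) pair.
def loss(params, x, y):
    pred = jnp.dot(params, x)
    return (pred - y) ** 2

params = jnp.array([0.5, -0.3, 1.0])
xs = jnp.ones((8, 3))        # batch of 8 inputs
ys = jnp.zeros(8)            # batch of 8 targets

# in_axes: params is shared (None); x and y are batched along axis 0.
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))

# vmap composes with the other transformations: average the per-example
# losses, differentiate with respect to params, and compile the result.
mean_loss = lambda p, x, y: jnp.mean(per_example_loss(p, x, y))
grad_fn = jax.jit(jax.grad(mean_loss))

print(grad_fn(params, xs, ys).shape)   # (3,)
```

The sections below walk through each of these ideas in detail.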
4.1 The Concept of Vectorization
4.2 Introducing `jax.vmap`
4.3 Mapping over Specific Arguments (`in_axes`, `out_axes`)
4.4 Handling Multiple Batched Arguments
4.5 Nesting `vmap`
4.6 Combining `vmap` with `jit` and `grad`
4.7 Performance Considerations for `vmap`
4.8 Hands-on Practical: Vectorizing Functions