Chapter 3: Automatic Differentiation with grad
Many computational tasks, especially in machine learning model training, rely on calculating the gradient of functions. Automatic differentiation provides an efficient way to compute these derivatives. This chapter introduces jax.grad, the core JAX transformation for obtaining gradient functions from Python code operating on numerical inputs.
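As a quick illustration of the idea, here is a minimal sketch (the function name `f` is just an example, not part of the JAX API) showing how `jax.grad` turns a Python function into a new function that evaluates its derivative:

```python
import jax

# Illustrative example: f(x) = x**2, so df/dx = 2x.
def f(x):
    return x ** 2

# jax.grad(f) returns a new function that computes the derivative of f.
grad_f = jax.grad(f)

print(grad_f(3.0))  # 6.0
```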
You will learn how to:
- Apply jax.grad to compute the gradient ∇f(x) of a scalar-valued function f.
- Differentiate with respect to specific function arguments using grad.
- Compute higher-order derivatives by repeatedly applying grad.
- Obtain both a function's value and its gradient in a single call with jax.value_and_grad.

By the end of this chapter, you will be able to use jax.grad effectively to differentiate your numerical functions within the JAX framework.
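As a preview of these objectives, the following small sketch (the names `loss`, `w`, `x`, and `cube` are illustrative, not part of the JAX API) combines argument selection, nested differentiation, and `jax.value_and_grad`:

```python
import jax
import jax.numpy as jnp

# Illustrative sketch: the function and variable names are examples only.
def loss(w, x):
    # A toy scalar loss over parameters w and data x.
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Differentiate with respect to a chosen argument (argnums=0 selects w).
dloss_dw = jax.grad(loss, argnums=0)
print(dloss_dw(w, x))  # gradient of the loss with respect to w

# Higher-order derivatives come from composing grad with itself.
def cube(x):
    return x ** 3

second_derivative = jax.grad(jax.grad(cube))
print(second_derivative(2.0))  # 6 * x evaluated at 2.0 -> 12.0

# value_and_grad returns the loss value and its gradient in one call.
value, grads = jax.value_and_grad(loss)(w, x)
print(value, grads)
```

Each of these patterns is developed in detail in the sections that follow.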
3.1 Understanding Gradients
3.2 Introducing `jax.grad`
3.3 How Autodiff Works: Reverse Mode
3.4 Differentiating with Respect to Arguments
3.5 Higher-Order Derivatives (`grad` of `grad`)
3.6 Value and Gradient (`jax.value_and_grad`)
3.7 Differentiation and Control Flow
3.8 Limitations and Considerations
3.9 Hands-on Practical: Computing Gradients