Chapter 1: Introduction to JAX
Core Design Philosophy: Function Transformations
Device Management: CPU, GPU, TPU
Practice: Basic Array Operations
Chapter 2: Accelerating Functions with JIT Compilation
The Need for Speed: Why Compile?
How JIT Works: Tracing and Compilation
Python Control Flow and jit
Common Challenges with jit
Hands-on Practical: Applying jit
Chapter 3: Automatic Differentiation with grad
How Autodiff Works: Reverse Mode
Differentiating with Respect to Arguments
Higher-Order Derivatives (grad of grad)
Value and Gradient (jax.value_and_grad)
Differentiation and Control Flow
Limitations and Considerations
Hands-on Practical: Computing Gradients
Chapter 4: Automatic Vectorization with vmap
The Concept of Vectorization
Mapping over Specific Arguments (in_axes, out_axes)
Handling Multiple Batched Arguments
Combining vmap with jit and grad
Performance Considerations for vmap
Hands-on Practical: Vectorizing Functions
Chapter 5: Parallelization Across Devices with pmap
Introduction to Data Parallelism (SPMD)
Mapping Data to Devices (in_axes, out_axes)
Device Meshes and Axis Names
Collective Operations (lax.psum, lax.pmean, etc.)
Combining pmap with Other Transformations
Debugging pmapped Functions
Hands-on Practical: Parallel Computation
Chapter 6: Managing State in JAX
Functional Purity and Side Effects
The Challenge of State in Functional Code
Pattern: Explicit State Passing
Using PyTrees for Structured State
Example: Stateful Counter
Example: Simple Optimizer State
Combining State Management with Transformations
Practice: Implementing Stateful Functions