While JAX provides the powerful core primitives (jit, grad, vmap, pmap) for accelerated and distributed numerical computation, building, training, and managing large, intricate machine learning models purely with these tools can become cumbersome. Handling potentially complex state, such as model parameters, optimizer statistics, and random number generator (RNG) keys, requires careful bookkeeping, especially when distributing computations across multiple devices.
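As a rough illustration of that bookkeeping, here is a minimal raw-JAX sketch; the function names and the simple SGD update are illustrative, not part of any library:

import jax
import jax.numpy as jnp

# Parameters, optimizer state, and RNG keys are plain PyTrees that you
# must create and thread through every function yourself.
def init_params(key, in_dim, out_dim):
    w_key, b_key = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros((out_dim,)),
    }

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=1e-2):
    grads = jax.grad(loss_fn)(params, x, y)
    # Manual "optimizer": update every leaf of the parameter PyTree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
params = init_params(key, 4, 1)
x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))
params = sgd_step(params, x, y)

Every piece of this state has to be passed explicitly to every jitted or pmapped function, which becomes tedious as models grow.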
This is where neural network libraries built on top of JAX come into play. They provide higher-level abstractions for defining model architectures, managing state, and organizing training loops, allowing you to focus more on the model design and training logic rather than low-level implementation details. Two prominent libraries in the JAX ecosystem are Flax and Haiku. They offer different programming styles but share the common goal of simplifying the development of complex models within the JAX framework.
Flax, developed by Google, offers a functional approach centered around Linen modules (flax.linen). It emphasizes explicit state management, meaning model parameters and other state variables (like batch normalization statistics) are typically handled outside the module methods themselves.
Key characteristics of Flax include:

- Module system (flax.linen): Models are defined by subclassing nn.Module. Layers and submodules are defined within a setup method. The forward pass logic resides in a __call__ method (or other named methods).
- Explicit initialization (module.init): Parameters are created by calling module.init, requiring an example input and an RNG key. This separates the module structure definition from the actual parameter values.
- Functional apply: The apply method is used to run the forward pass with specific parameters and state. This functional nature makes it easy to compose with JAX transformations such as jax.jit, jax.grad, etc.
- Training utilities: Flax provides flax.training.train_state.TrainState to bundle together the model's apply function, parameters, and optimizer state, simplifying the training loop (see the sketch after the example below).

Here's an example of a simple Flax module:
import jax
import jax.numpy as jnp
import flax.linen as nn

class SimpleMLP(nn.Module):
    features: list[int]  # List defining layer sizes, e.g., [128, 64, 10]

    @nn.compact  # Allows defining submodules inline in __call__
    def __call__(self, x):
        for i, feat in enumerate(self.features):
            x = nn.Dense(features=feat, name=f'dense_{i}')(x)
            if i != len(self.features) - 1:  # Apply ReLU to all but last layer
                x = nn.relu(x)
        return x
# --- Usage ---
key = jax.random.PRNGKey(0)
input_shape = (1, 28*28) # Example: Batch size 1, flattened MNIST image
dummy_input = jnp.ones(input_shape)
output_features = [128, 10] # Define layer sizes
model = SimpleMLP(features=output_features)
# Initialize parameters (requires RNG key and dummy input)
params = model.init(key, dummy_input)['params']
print(f"Parameter PyTree structure:\n{jax.tree_util.tree_map(lambda x: x.shape, params)}")
# Apply the model (forward pass)
output = model.apply({'params': params}, dummy_input)
print(f"\nOutput shape: {output.shape}")
In the context of large models, Flax's structured approach helps organize complex architectures. Its explicit state management fits well with JAX's functional paradigm and simplifies passing parameters and state across distributed devices using pmap, as sketched below.
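A brief sketch of that idea, reusing model and params from above and assuming flax.jax_utils for replication (on a single-device machine the leading device axis is simply 1):

from flax import jax_utils

n_devices = jax.local_device_count()

# Replicate the parameter PyTree onto every local device
replicated_params = jax_utils.replicate(params)

# Shard the batch along a leading device axis: (n_devices, per_device_batch, features)
sharded_batch = jnp.ones((n_devices, 1, 28 * 28))

# pmap runs the forward pass on each device with its parameter copy and batch shard
p_apply = jax.pmap(lambda p, x: model.apply({'params': p}, x))
outputs = p_apply(replicated_params, sharded_batch)
print(outputs.shape)  # (n_devices, 1, 10)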
Haiku, developed by DeepMind, offers an alternative programming model that feels more similar to object-oriented frameworks like PyTorch while still being fundamentally compatible with JAX's functional nature.
Key characteristics of Haiku include:

- Module system (hk.Module): Models are defined by subclassing hk.Module. Layers are typically instantiated within the __init__ method, similar to PyTorch.
- hk.transform: This is the core function that separates the pure, functional computation (suitable for JAX transformations) from the object-oriented module definition. It converts a function that instantiates and calls Haiku modules into a pair of functions: init (for initializing parameters) and apply (for the forward pass).
- Implicit state handling: Parameters are created and tracked automatically inside the hk.transform context, rather than being passed around by hand.

Here's an example using Haiku, similar to the Flax MLP:
import jax
import jax.numpy as jnp
import haiku as hk

class SimpleMLP(hk.Module):
    def __init__(self, features: list[int], name: str | None = None):
        super().__init__(name=name)
        self.features = features

    def __call__(self, x):
        for i, feat in enumerate(self.features):
            # Layers are often created inline here in Haiku
            x = hk.Linear(output_size=feat, name=f'linear_{i}')(x)
            if i != len(self.features) - 1:
                x = jax.nn.relu(x)
        return x
# --- Usage ---
# Define the forward function that uses the Haiku module
def forward_fn(x):
    output_features = [128, 10]
    mlp = SimpleMLP(features=output_features)
    return mlp(x)
# Transform the function
model = hk.transform(forward_fn)
key = jax.random.PRNGKey(0)
input_shape = (1, 28*28)
dummy_input = jnp.ones(input_shape)
# Initialize parameters (only needs RNG key and dummy input)
params = model.init(key, dummy_input)
print(f"Parameter PyTree structure:\n{jax.tree_util.tree_map(lambda x: x.shape, params)}")
# Apply the model (forward pass); unlike Flax, Haiku's apply also expects an RNG key
output = model.apply(params, key, dummy_input)  # the key can be None if no module uses randomness
print(f"\nOutput shape: {output.shape}")
Haiku's hk.transform mechanism cleverly bridges the gap between stateful, object-oriented module definitions and JAX's requirement for pure functions. This can feel more familiar to users coming from other frameworks. For large models, Haiku provides clear state management and integrates well with JAX transformations and distribution primitives.
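When modules carry mutable state such as batch normalization statistics, hk.transform_with_state returns an init/apply pair that threads that state explicitly. A brief sketch, with arbitrary layer sizes chosen for illustration:

def forward_with_bn(x, is_training: bool):
    x = hk.Linear(128)(x)
    x = hk.BatchNorm(create_scale=True, create_offset=True,
                     decay_rate=0.9)(x, is_training=is_training)
    return hk.Linear(10)(jax.nn.relu(x))

stateful_model = hk.transform_with_state(forward_with_bn)
params, state = stateful_model.init(key, dummy_input, is_training=True)
# apply returns both the output and the updated state (new moving averages)
output, new_state = stateful_model.apply(params, state, key, dummy_input,
                                         is_training=True)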
While you can build everything in raw JAX, Flax and Haiku provide significant advantages, especially as model complexity and scale increase: structured module definitions, organized handling of parameters and other state, reusable layers and training utilities, and straightforward integration with JAX transformations and distribution primitives such as pmap.

The choice between Flax and Haiku often comes down to stylistic preference. Flax is more explicitly functional, requiring manual state passing, while Haiku uses hk.transform to provide a more object-oriented feel with implicit state management. Both are powerful tools built on the same JAX core, enabling the construction and training of sophisticated models at scale. Understanding how these libraries structure code and manage state is fundamental to applying the large-scale techniques discussed in the rest of this chapter, such as integrating pmap for data parallelism or implementing checkpointing strategies.
Comparison of state handling in raw JAX, Flax, and Haiku: raw JAX requires manual management, Flax uses explicit state passing (often bundled), and Haiku uses hk.transform to manage state implicitly within its apply function.