Implementing a custom recurrent cell is a practical scenario often encountered in sequence modeling, making it a good application for control flow primitives like `lax.scan`. While basic RNNs are straightforward, many advanced architectures use more complex gating mechanisms. A simplified Gated Recurrent Unit (GRU) cell can be implemented using `lax.scan` to manage sequential processing and hidden state updates, demonstrating how to structure non-trivial computations within the scan body.

## Understanding the Simplified GRU Cell

A GRU is a type of recurrent neural network cell designed to capture dependencies over different time scales by using gating mechanisms. These gates control the flow of information, deciding what to keep from the past state and what to incorporate from the new input.

For our example, we'll implement a slightly simplified version. Let $x_t$ be the input vector at timestep $t$, and $h_{t-1}$ be the hidden state from the previous timestep. The hidden state $h_t$ at timestep $t$ is computed as follows:

1. **Update Gate ($z_t$):** Determines how much of the previous hidden state to keep.
   $$ z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) $$
2. **Reset Gate ($r_t$):** Determines how much of the previous hidden state to forget when computing the candidate state.
   $$ r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) $$
3. **Candidate Hidden State ($\tilde{h}_t$):** Computes a new hidden state proposal based on the current input and the reset previous hidden state.
   $$ \tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h) $$
4. **Final Hidden State ($h_t$):** Linearly interpolates between the previous hidden state $h_{t-1}$ and the candidate hidden state $\tilde{h}_t$, using the update gate $z_t$.
   $$ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t $$

Here, $\sigma$ represents the sigmoid activation function, $\tanh$ is the hyperbolic tangent activation function, and $\odot$ denotes element-wise multiplication. $W_z, U_z, b_z, W_r, U_r, b_r, W_h, U_h, b_h$ are the learnable parameters (weight matrices and bias vectors) of the GRU cell.

## Implementing the GRU with lax.scan

We can implement the processing of an entire sequence through the GRU cell using `lax.scan`. The core idea is:

- The carry in `lax.scan` holds the hidden state $h_t$ at each timestep.
- The `xs` argument is the input sequence $(x_1, x_2, \ldots, x_T)$.
- The function passed to `lax.scan` (let's call it `gru_step`) implements the four equations above, taking the previous hidden state `h_prev` (from the carry) and the current input `x_t` (from `xs`) to compute the new hidden state `h_t`.
- `gru_step` returns `(h_t, h_t)`. The first `h_t` becomes the carry for the next step, and the second `h_t` is accumulated into the output sequence.
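If the carry-and-output convention of `lax.scan` is unfamiliar, the following minimal sketch of the same pattern with a running sum may help. The `cumsum_step` function and the values it scans over are purely illustrative and unrelated to the GRU:

```python
import jax.numpy as jnp
import jax.lax as lax

def cumsum_step(carry, x):
    # carry: running total so far; x: current element of the scanned sequence
    new_total = carry + x
    # The first return value becomes the next carry; the second is stacked into the outputs
    return new_total, new_total

final_total, running_totals = lax.scan(cumsum_step, jnp.array(0.0), jnp.arange(1.0, 5.0))
# final_total == 10.0, running_totals == [1., 3., 6., 10.]
```

The GRU cell follows exactly the same structure, with the hidden state playing the role of the running total.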
Let's write the code. We'll start with imports and defining the parameters. In a real scenario, these parameters would be part of a larger model structure (like Flax or Haiku), but here we'll define them directly for clarity.

```python
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax import random

# Define activation functions
sigmoid = jax.nn.sigmoid
tanh = jnp.tanh

def initialize_gru_params(key, input_dim, hidden_dim):
    """Initializes parameters for the simplified GRU cell."""
    # One PRNG key per weight matrix (Wz, Uz, Wr, Ur, Wh, Uh); biases are initialized to zero
    keys = random.split(key, 6)

    # Update gate parameters
    Wz = random.normal(keys[0], (hidden_dim, input_dim)) * 0.01
    Uz = random.normal(keys[1], (hidden_dim, hidden_dim)) * 0.01
    bz = jnp.zeros((hidden_dim,))

    # Reset gate parameters
    Wr = random.normal(keys[2], (hidden_dim, input_dim)) * 0.01
    Ur = random.normal(keys[3], (hidden_dim, hidden_dim)) * 0.01
    br = jnp.zeros((hidden_dim,))

    # Candidate hidden state parameters
    Wh = random.normal(keys[4], (hidden_dim, input_dim)) * 0.01
    Uh = random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.01
    bh = jnp.zeros((hidden_dim,))

    params = {
        'Wz': Wz, 'Uz': Uz, 'bz': bz,
        'Wr': Wr, 'Ur': Ur, 'br': br,
        'Wh': Wh, 'Uh': Uh, 'bh': bh
    }
    return params

def gru_step(params, h_prev, x_t):
    """Performs one step of the simplified GRU computation."""
    # Update gate
    z_t = sigmoid(jnp.dot(params['Wz'], x_t) + jnp.dot(params['Uz'], h_prev) + params['bz'])
    # Reset gate
    r_t = sigmoid(jnp.dot(params['Wr'], x_t) + jnp.dot(params['Ur'], h_prev) + params['br'])
    # Candidate hidden state
    h_tilde_t = tanh(jnp.dot(params['Wh'], x_t) + jnp.dot(params['Uh'], r_t * h_prev) + params['bh'])
    # Final hidden state
    h_t = (1.0 - z_t) * h_prev + z_t * h_tilde_t
    # Return new hidden state as carry and as this step's output
    return h_t, h_t

def gru_sequence(params, initial_h, inputs):
    """Applies the GRU cell over a sequence of inputs using lax.scan."""
    # Define the scan function, closing over the parameters
    scan_fn = lambda carry, x: gru_step(params, carry, x)
    # final_h is the last hidden state; outputs_h is the sequence [h_1, h_2, ..., h_T]
    final_h, outputs_h = lax.scan(scan_fn, initial_h, inputs)
    return final_h, outputs_h

# Example usage
key = random.PRNGKey(0)
seq_len = 10
input_dim = 5
hidden_dim = 8

# Initialize parameters
gru_params = initialize_gru_params(key, input_dim, hidden_dim)

# Create a dummy input sequence of shape (sequence_length, input_features)
key, subkey = random.split(key)
input_sequence = random.normal(subkey, (seq_len, input_dim))

# Initialize the hidden state
initial_hidden_state = jnp.zeros((hidden_dim,))

# Run the GRU over the sequence
final_state, hidden_states_sequence = gru_sequence(gru_params, initial_hidden_state, input_sequence)

print("Input sequence shape:", input_sequence.shape)
print("Initial hidden state shape:", initial_hidden_state.shape)
print("Final hidden state shape:", final_state.shape)
print("Output sequence of hidden states shape:", hidden_states_sequence.shape)
```

In this code:

- `initialize_gru_params` sets up the necessary weight matrices and bias vectors with appropriate shapes, using small random values for initialization.
- `gru_step` implements the core logic for a single timestep. It takes the parameters, the previous hidden state `h_prev`, and the current input `x_t`, returning the new hidden state `h_t` twice (once as the new carry, once as the output for this step).
- `gru_sequence` orchestrates the process. It defines `scan_fn`, which is just `gru_step` with the `params` argument fixed (closed over), and then calls `lax.scan` with this function, the initial hidden state, and the input sequence.

The example usage demonstrates how to create parameters, generate sample input, and call the `gru_sequence` function. The output shapes confirm that the final state has the hidden dimension and that the output sequence has shape `(sequence_length, hidden_dimension)`.
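As a quick sanity check on the scan itself, the same computation can be written as a plain Python loop over timesteps. This is a minimal sketch reusing `gru_step` and the variables from the example above; the `gru_sequence_loop` name is just for illustration:

```python
def gru_sequence_loop(params, initial_h, inputs):
    """Reference implementation with an explicit Python loop, useful for checking the scan."""
    h = initial_h
    outputs = []
    for t in range(inputs.shape[0]):
        h, out = gru_step(params, h, inputs[t])
        outputs.append(out)
    return h, jnp.stack(outputs)

final_loop, hidden_loop = gru_sequence_loop(gru_params, initial_hidden_state, input_sequence)
print("Loop matches lax.scan:", jnp.allclose(hidden_loop, hidden_states_sequence))
```

Under `jit`, the `lax.scan` version traces `gru_step` only once instead of unrolling the loop, which keeps compilation time manageable for long sequences.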
## Integrating with JAX Transformations

One of the advantages of using `lax.scan` is that the resulting `gru_sequence` function is fully compatible with other JAX transformations like `jit`, `grad`, and `vmap`.

For instance, to compile the GRU computation for faster execution, simply wrap the call with `jax.jit`:

```python
# Compile the GRU function for efficiency
jit_gru_sequence = jax.jit(gru_sequence)

# Run the compiled version (the first run includes compilation time)
key, subkey = random.split(key)
input_sequence_2 = random.normal(subkey, (seq_len, input_dim))
final_state_jit, hidden_states_sequence_jit = jit_gru_sequence(gru_params, initial_hidden_state, input_sequence_2)

print("\nRunning JIT-compiled version:")
print("Final hidden state shape (JIT):", final_state_jit.shape)
print("Output sequence shape (JIT):", hidden_states_sequence_jit.shape)
```

If you want to process a batch of sequences simultaneously, you can use `jax.vmap`. Assuming your inputs have a batch dimension, like `(batch_size, seq_len, input_dim)`, you map over the batch dimension for both the initial hidden state `(batch_size, hidden_dim)` and the inputs, while the parameters are shared across the batch (hence `None` in `in_axes`):

```python
# Create batched inputs and initial states
batch_size = 32
key, subkey = random.split(key)
batched_inputs = random.normal(subkey, (batch_size, seq_len, input_dim))
batched_initial_h = jnp.zeros((batch_size, hidden_dim))

# Map over the batch dimension (axis 0) of the initial state and inputs; params are shared
batched_gru = jax.vmap(gru_sequence, in_axes=(None, 0, 0))
final_states_batch, hidden_sequences_batch = batched_gru(gru_params, batched_initial_h, batched_inputs)

print("Batched final state shape:", final_states_batch.shape)           # (batch_size, hidden_dim)
print("Batched output sequences shape:", hidden_sequences_batch.shape)  # (batch_size, seq_len, hidden_dim)
```

Similarly, you could compute gradients with respect to the parameters (`gru_params`) or the inputs (`input_sequence`) using `jax.grad`, enabling the training of the GRU cell within a larger model; a small gradient sketch follows at the end of this section.

This example illustrates how `lax.scan` provides a powerful mechanism for implementing complex, stateful sequential computations in a way that integrates cleanly with JAX's compilation and automatic differentiation capabilities. By defining the logic for a single step and letting `lax.scan` handle the iteration, you can build sophisticated recurrent models efficiently.
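To make the gradient computation mentioned above concrete, here is a minimal sketch assuming a hypothetical mean-squared-error loss over the hidden states; `sequence_loss` and `dummy_targets` are illustrative and not part of the example above:

```python
def sequence_loss(params, initial_h, inputs, targets):
    """Hypothetical scalar loss: mean squared error between hidden states and targets."""
    _, hidden_states = gru_sequence(params, initial_h, inputs)
    return jnp.mean((hidden_states - targets) ** 2)

# jax.grad differentiates with respect to the first argument (the parameter pytree) by default
key, subkey = random.split(key)
dummy_targets = random.normal(subkey, (seq_len, hidden_dim))
grads = jax.grad(sequence_loss)(gru_params, initial_hidden_state, input_sequence, dummy_targets)

print("Gradient pytree keys:", sorted(grads.keys()))
print("Shape of dL/dWz:", grads['Wz'].shape)  # same shape as gru_params['Wz']
```

Because the gradients mirror the structure of the parameter dictionary, they can be fed directly into a standard gradient-based update rule.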