`jax.custom_vjp` lets you define your own vector-Jacobian product (VJP) rule for a function. This is useful in several scenarios:

- **Numerical stability:** The automatically derived gradient may suffer from numerical issues, such as overflow or underflow, that a carefully crafted analytical gradient can avoid.
- **Performance:** You may know a more efficient way to calculate the gradient than differentiating the forward pass.
- **Non-JAX code:** If your function calls external libraries or uses operations JAX cannot differentiate, you must provide the gradient manually.
- **Stopping gradients/approximations:** You may want to provide an approximate gradient, or prevent gradients from flowing through certain parts of a computation.

In this section, we'll implement a custom VJP for a common function, focusing on the mechanics and on verification.

### Example: Custom Gradient for Softplus

Consider the Softplus function, defined as:

$$ \text{softplus}(x) = \log(1 + e^x) $$

This function is often used as a smooth approximation of the ReLU activation function. Let's first write a standard JAX implementation:

```python
import jax
import jax.numpy as jnp

# Naive implementation: potentially unstable for large x
def softplus_naive(x):
    return jnp.log(1 + jnp.exp(x))

# Example usage
x_val = 10.0
print(f"softplus_naive({x_val}) = {softplus_naive(x_val)}")

# Calculate the gradient using JAX's autodiff
grad_softplus_naive = jax.grad(softplus_naive)
print(f"Gradient at {x_val} (JAX auto): {grad_softplus_naive(x_val)}")
```

For very large positive values of $x$, $e^x$ overflows standard floating-point representations. In float32, whose largest finite value is about $3.4 \times 10^{38}$, this happens once $x$ exceeds roughly 88.7. Beyond that point the naive forward pass returns `inf`, and its automatically derived gradient becomes `nan`.
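To see the overflow concretely, the short check below evaluates the naive version at $x = 100$. It also tries two stable forward-pass alternatives, `jnp.logaddexp(0.0, x)` (that is, $\log(e^0 + e^x)$) and `jax.nn.softplus`. Both are existing JAX functions, but using them here is our suggestion and not part of the original walkthrough.

```python
import jax
import jax.numpy as jnp

def softplus_naive(x):
    return jnp.log(1 + jnp.exp(x))

x = jnp.float32(100.0)

overflowed = jnp.exp(x)                   # e^100 exceeds the float32 maximum -> inf
naive_value = softplus_naive(x)           # log(1 + inf) -> inf
naive_grad = jax.grad(softplus_naive)(x)  # reverse pass computes 0 * inf -> nan

# Stable alternatives that never materialize e^x for large x
stable_value = jnp.logaddexp(0.0, x)      # log(e^0 + e^x), approximately 100
library_value = jax.nn.softplus(x)        # JAX's built-in softplus
library_grad = jax.grad(jax.nn.softplus)(x)

print(overflowed, naive_value, naive_grad)
print(stable_value, library_value, library_grad)
```

A stable forward pass, as the last line suggests, also yields a usable autodiff gradient. The custom VJP below is therefore mainly a vehicle for learning the mechanics.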
We will therefore provide the gradient rule explicitly, partly for pedagogical purposes and partly because the automatically derived gradient genuinely breaks down for such inputs. The analytical derivative of softplus is:

$$ \frac{d}{dx} \text{softplus}(x) = \frac{e^x}{1 + e^x} = \frac{1}{1 + e^{-x}} = \text{sigmoid}(x) $$

Evaluated in the second form, or with `jax.nn.sigmoid`, this derivative stays within $[0, 1]$ and does not produce `nan` for finite inputs. We can therefore use this analytical form in our custom VJP.

### Implementing with `jax.custom_vjp`

To define a custom VJP, we need three components:

1. The original function, decorated with `@jax.custom_vjp`.
2. A forward-pass function (`fwd`) that takes the same arguments as the original function. It computes the original output and saves any intermediate values (residuals) needed for the backward pass, and it must return `(output, residuals)`.
3. A backward-pass function (`bwd`) that takes the residuals and the upstream gradient `g`. Here `g` represents $\partial\mathcal{L}/\partial y$, where $y$ is the output of our function. The function computes the gradient with respect to the original inputs, $\partial\mathcal{L}/\partial x = (\partial\mathcal{L}/\partial y) \cdot (\partial y/\partial x)$. It must return a tuple of gradients, one for each positional argument of the original function.

The `defvjp(fwd, bwd)` method then registers the two rules with the decorated function. Let's implement this for softplus:

```python
import jax
import jax.numpy as jnp

# Naive version from the previous snippet, repeated so this block runs on its own
def softplus_naive(x):
    return jnp.log(1 + jnp.exp(x))

# 1. Define the function and decorate it
@jax.custom_vjp
def softplus_custom(x):
    """Softplus implementation intended for custom VJP."""
    # For demonstration, we keep the potentially unstable forward pass here.
    # In a real scenario, you would use a stable forward pass too.
    return jnp.log(1 + jnp.exp(x))

# 2. Define the forward pass function (fwd).
# It takes the same arguments as the original function (x) and returns the
# output (y = softplus(x)) plus the residuals needed for the backward pass.
# We need the original input 'x' to compute sigmoid(x) in the backward pass.
def softplus_fwd(x):
    y = softplus_custom(x)  # Compute the primal output using the decorated function
    residuals = x           # Save 'x' for the backward pass
    return y, residuals

# 3. Define the backward pass function (bwd).
# It takes the residuals saved by fwd and the upstream gradient 'g', and
# returns a tuple of gradients w.r.t. the inputs of the original function.
# Here the only input is 'x', so we return a one-element tuple.
def softplus_bwd(residuals, g):
    x = residuals  # Unpack the residuals
    # dL/dx = g * sigmoid(x); jax.nn.sigmoid is numerically stable
    grad_x = g * jax.nn.sigmoid(x)
    return (grad_x,)  # Gradient w.r.t. 'x' as a tuple

# Associate the fwd and bwd functions with the custom_vjp function
softplus_custom.defvjp(softplus_fwd, softplus_bwd)

# --- Now let's test it ---
grad_softplus_custom = jax.grad(softplus_custom)

# Test values, including ones where exp(x) overflows float32 (x > ~88.7)
test_values = [-10.0, 0.0, 10.0, 80.0, 100.0]

# Compare the naive autodiff gradient with the custom gradient
print("Input  | Naive Grad (Autodiff) | Custom Grad (VJP)")
print("-------|-----------------------|------------------")
for x_val in test_values:
    # JAX does not raise on overflow: the naive gradient silently becomes nan
    naive_grad = float(jax.grad(softplus_naive)(x_val))
    custom_grad = float(grad_softplus_custom(x_val))
    print(f"{x_val:<6.1f} | {naive_grad:<21.8f} | {custom_grad:<17.8f}")

# JIT compilation works as expected
jit_grad_softplus_custom = jax.jit(jax.grad(softplus_custom))
print(f"\nJIT-compiled custom gradient at 10.0: {float(jit_grad_softplus_custom(10.0)):.8f}")
```

### Verification and Analysis

When you run the code above, the gradients computed by `jax.grad(softplus_naive)` and `jax.grad(softplus_custom)` agree to within floating-point precision for the moderate inputs (-10, 0, 10 and 80). At $x = 100$ they diverge: $e^{100}$ overflows float32, so the naive gradient is `nan`, while the custom rule returns 1.0.

- **Correctness:** The agreement on moderate inputs confirms that our custom VJP rule computes the analytical gradient, $\text{sigmoid}(x)$.
- **Numerical stability (of the gradient):** The backward pass explicitly uses `jax.nn.sigmoid(x)`, which is numerically stable.
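The comparison table can also be turned into automated checks. The sketch below is self-contained, so it repeats the definitions. It checks the custom rule against `jax.nn.sigmoid` and against the naive gradient on a grid, then against central finite differences using `jax.test_util.check_grads`. The grid, test points, `eps` and tolerances are our own assumptions, chosen loosely for float32. Because a `custom_vjp` function supports only reverse-mode differentiation, `check_grads` is restricted to `modes=["rev"]`.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def softplus_naive(x):
    return jnp.log(1 + jnp.exp(x))

@jax.custom_vjp
def softplus_custom(x):
    return jnp.log(1 + jnp.exp(x))

def softplus_fwd(x):
    return softplus_custom(x), x  # save the input x as the residual

def softplus_bwd(x, g):
    return (g * jax.nn.sigmoid(x),)

softplus_custom.defvjp(softplus_fwd, softplus_bwd)

# 1. Compare with the analytical derivative and the naive autodiff gradient
#    on a grid where exp(x) does not overflow.
xs = jnp.linspace(-10.0, 10.0, 21)
custom_grads = jax.vmap(jax.grad(softplus_custom))(xs)
naive_grads = jax.vmap(jax.grad(softplus_naive))(xs)
print("matches sigmoid:", bool(jnp.allclose(custom_grads, jax.nn.sigmoid(xs), atol=1e-6)))
print("matches naive autodiff:", bool(jnp.allclose(custom_grads, naive_grads, atol=1e-6)))

# 2. Compare with central finite differences (reverse mode only).
#    check_grads raises an AssertionError on mismatch.
for x0 in [-3.0, 0.5, 1.5]:
    check_grads(softplus_custom, (jnp.float32(x0),), order=1, modes=["rev"],
                eps=1e-3, atol=1e-2, rtol=1e-2)

# 3. Where exp(x) overflows float32, only the custom rule stays finite.
naive_at_100 = jax.grad(softplus_naive)(100.0)
custom_at_100 = jax.grad(softplus_custom)(100.0)
print("x = 100:", naive_at_100, custom_at_100)
```

The finite-difference check is independent of JAX's autodiff, so it would still catch a wrong rule in a case where no autodiff reference exists (for example, a function wrapping non-JAX code).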
The naive autodiff gradient, in contrast, evaluates the chain rule through $\log(1 + e^x)$ literally. Once $e^x$ overflows to `inf`, the reverse pass multiplies a zero cotangent by `inf` and produces `nan`. Defining the rule explicitly guarantees a finite gradient even where $e^x$ overflows. Note, however, that the forward pass in `softplus_custom` is still the naive version, so `softplus_custom(100.0)` itself returns `inf`. A production implementation would pair a stable forward pass with the custom VJP.

Mechanically, `softplus_fwd` computes the original function's value and caches the input `x`. `softplus_bwd` then uses this cached `x`, together with the incoming gradient `g`, to compute the final gradient `g * sigmoid(x)`. The `defvjp` call links these two functions to the `@jax.custom_vjp`-decorated function.

### Practical Notes

- **Residuals:** Choose carefully what to save in the residuals. Saving unnecessary large intermediate arrays increases memory consumption, so save only what the backward pass strictly needs. Conversely, recomputing values in the backward pass trades extra computation for lower memory usage.
- **Multiple inputs/outputs:** If your function has multiple inputs, the `fwd` function receives all of them. The `bwd` function must return a tuple with exactly one entry per positional input of the original function, and an entry may be `None` to indicate a zero gradient for that input. The upstream gradient `g` has the same structure as the primal output. By default, outputs that do not contribute to the loss receive zero cotangents.
- **`jax.custom_jvp`:** For forward-mode differentiation, the process is analogous. You decorate the function with `@jax.custom_jvp` and register a rule with `defjvp`. The rule takes `(primals, tangents)` and returns `(primal_out, tangent_out)`, that is, `(y, y_dot)` given `(x, x_dot)`. JAX can also derive reverse-mode gradients from a custom JVP rule, whereas a `custom_vjp` function supports only reverse mode.

This practical exercise demonstrates the core mechanics of defining custom gradients in JAX, a powerful tool for handling specific numerical or performance requirements. Remember to test your custom rules thoroughly against JAX's automatic differentiation, or against numerical (finite-difference) gradients where feasible.
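The multiple-inputs rule, including the use of `None`, can be illustrated with a gradient-clipping function. This also covers the "approximations" use case from the start of the section. The sketch below follows a common pattern; `clip_gradient` and `loss` are hypothetical names chosen for illustration. The function is the identity in the forward pass and clips the incoming gradient to `[lo, hi]` in the backward pass. Because it has three positional inputs, `bwd` returns a three-element tuple.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: identity in the forward pass, clips the incoming
# gradient to [lo, hi] in the backward pass.
@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x

def clip_gradient_fwd(lo, hi, x):
    # fwd receives all three inputs; the bounds are the only residuals needed
    return x, (lo, hi)

def clip_gradient_bwd(res, g):
    lo, hi = res
    # One entry per positional input (lo, hi, x); None marks a zero gradient
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def loss(x):
    # Without clipping, d(loss)/dx would be 3.0
    return 3.0 * clip_gradient(-1.0, 1.0, x)

value = loss(2.0)               # forward pass unchanged: 3.0 * 2.0
grad_value = jax.grad(loss)(2.0)  # upstream gradient 3.0 clipped to hi = 1.0
print(value, grad_value)
```

Returning `None` avoids constructing explicit zero arrays for the bounds. The bounds are non-differentiable hyperparameters here.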