After understanding the concept of JAX primitives as the fundamental building blocks known to the JAX system, the next step in defining your own custom operation is to teach JAX about its shape and type signature. This is achieved by implementing an abstract evaluation rule.
Abstract evaluation is a critical part of JAX's tracing mechanism. Before JAX can JIT-compile, automatically differentiate, or vectorize a function containing your custom primitive, it needs to determine the properties (like shape and data type, or `dtype`) of the outputs produced by that primitive, without actually running the computation on real data. It performs this analysis by operating on abstract values, primarily `jax.core.ShapedArray` instances, which encapsulate shape and `dtype` information but hold no actual numerical values.
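For intuition, you can construct a `ShapedArray` directly and inspect it. This is a minimal sketch, using the same `jax.core` import the examples below rely on, showing that an abstract value carries only shape and `dtype`:

```python
import jax.numpy as jnp
from jax.core import ShapedArray

# An abstract value: records shape and dtype, holds no numerical data
aval = ShapedArray((3, 4), jnp.float32)
print(aval.shape)  # (3, 4)
print(aval.dtype)  # float32
```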
Think of abstract evaluation as defining the function signature of your primitive at the type and shape level. When JAX encounters your primitive during tracing (e.g., when `jax.jit` is applied to a function using it), it invokes the primitive's abstract evaluation rule. This rule takes the abstract values corresponding to the primitive's inputs and any necessary metadata (parameters specific to the primitive's behavior) and computes the abstract values (shape and `dtype`) of the primitive's outputs.
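JAX exposes this shape-level tracing directly through `jax.eval_shape`, which computes output shapes and dtypes without executing any arithmetic. A small sketch (the function `f` here is an arbitrary example of our own):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x.T) + 1.0

# Trace abstractly: no FLOPs run, only shapes and dtypes are propagated
spec = jax.ShapeDtypeStruct((128, 64), jnp.float32)
out = jax.eval_shape(f, spec)
print(out.shape, out.dtype)  # (128, 128) float32
```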
This information is essential for several reasons:

- `jit` compiles functions with XLA, which requires static shape and `dtype` information for every intermediate value.
- `vmap` needs to know how shapes change to correctly implement batching rules.
- `grad` needs to know the shapes and types to set up the backward pass correctly.

## The `abstract_eval` Method

To define the abstract evaluation rule for a custom primitive, you create a `jax.core.Primitive` instance and register the rule using its `def_abstract_eval` method.
```python
import jax
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray

# Example: A hypothetical primitive that adds 1.0 to an input
custom_add_one_p = Primitive('custom_add_one')

@custom_add_one_p.def_abstract_eval
def custom_add_one_abstract_eval(*avals, **params):
    # avals contains the abstract values (e.g., ShapedArray) of the inputs
    # params contains any parameters passed during primitive binding

    # Basic validation (optional but recommended)
    if len(avals) != 1:
        raise ValueError("custom_add_one expects exactly one input operand.")
    input_aval = avals[0]

    # Check if the input is a ShapedArray (most common case)
    if not isinstance(input_aval, ShapedArray):
        raise TypeError(f"Input must be a ShapedArray, got {type(input_aval)}")

    # Ensure dtype is floating-point for this specific operation
    if not jnp.issubdtype(input_aval.dtype, jnp.floating):
        # Or you might define specific promotion rules
        raise TypeError(f"Input dtype must be floating-point, got {input_aval.dtype}")

    # The core logic: determine output shape and dtype.
    # For 'custom_add_one', the shape and dtype are the same as the input.
    output_shape = input_aval.shape
    output_dtype = input_aval.dtype

    # Return the abstract value of the output
    return ShapedArray(output_shape, output_dtype)

# Example usage (illustrative; doesn't run without an implementation rule)
# def my_func(x):
#     # Binding the primitive with input x
#     return custom_add_one_p.bind(x)
```
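Note that the abstract rule alone is enough for JAX to trace the primitive into a jaxpr, even before any implementation or lowering rule exists. A quick sketch (the printed form may vary across JAX versions):

```python
# Tracing invokes only the abstract evaluation rule
print(jax.make_jaxpr(lambda x: custom_add_one_p.bind(x))(jnp.ones((3, 2))))
# { lambda ; a:f32[3,2]. let b:f32[3,2] = custom_add_one a in (b,) }
```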
In this example:

- We create a `Primitive` instance named `custom_add_one_p`.
- We decorate `custom_add_one_abstract_eval` with `@custom_add_one_p.def_abstract_eval`. This registers the function as the abstract evaluation rule for our primitive.
- The rule receives abstract values (`avals`) for each positional argument passed to the primitive's `bind` method, and any keyword parameters (`params`) passed to `bind`.
- We validate the inputs; checking `isinstance(aval, ShapedArray)` is common for array operations. We also check the `dtype`.
- We determine `output_shape` and `output_dtype`. For our simple `custom_add_one`, these remain identical to the input's shape and `dtype`. More complex primitives would have logic here reflecting how they transform shapes (e.g., matrix multiplication, convolution).
- The rule returns a `ShapedArray` representing the abstract properties of the primitive's output. If a primitive returned multiple outputs, this function would return a tuple of `ShapedArray` instances.

## Handling Parameters (`params`)
Sometimes, a primitive's behavior, and thus its output shape or type, depends on parameters that are not input arrays themselves. For example, a convolution primitive needs stride and padding information, or a reduction primitive needs the axes to reduce along. These parameters are typically passed as keyword arguments to the primitive's `bind` method and are received in the `params` dictionary within the `abstract_eval` function.
```python
# Hypothetical reduction primitive
custom_reduce_sum_p = Primitive('custom_reduce_sum')

@custom_reduce_sum_p.def_abstract_eval
def custom_reduce_sum_abstract_eval(aval, *, axis):  # axis passed as a kwarg to bind
    if not isinstance(aval, ShapedArray):
        raise TypeError("Input must be an array")

    # Calculate output shape based on input shape and the axis parameter
    output_shape = tuple(d for i, d in enumerate(aval.shape) if i not in axis)
    output_dtype = aval.dtype  # Assuming dtype doesn't change

    return ShapedArray(output_shape, output_dtype)

# Usage (needs an implementation rule before it can execute on real data):
# result = custom_reduce_sum_p.bind(my_array, axis=(0,))
```
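As a quick check, abstract tracing again suffices to see the `axis`-dependent output shape. A sketch building on the definitions above:

```python
# eval_shape exercises the abstract rule without running any computation
spec = jax.ShapeDtypeStruct((4, 5), jnp.float32)
out = jax.eval_shape(lambda x: custom_reduce_sum_p.bind(x, axis=(0,)), spec)
print(out.shape, out.dtype)  # (5,) float32
```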
Here, the `axis` parameter influences the computation of `output_shape`. The `abstract_eval` rule reads this parameter from the `params` dictionary (or receives it as a keyword argument, as shown above; JAX passes `bind`'s keyword parameters through as keyword arguments, so either signature works) to determine the correct output shape.
Abstract evaluation is the first rule you need to define for a new primitive. It lays the groundwork for JAX to understand the primitive's signature. Only after defining `abstract_eval` can you proceed to implement the actual computation logic (the "lowering" rule for specific backends like CPU/GPU/TPU) and its differentiation rules (JVP/VJP), which rely on the shape and type information provided by abstract evaluation.
The abstract evaluation rule (`abstract_eval`) is invoked during JAX tracing to determine the shapes and types of a primitive's outputs, enabling the creation of a jaxpr (JAX's intermediate representation), which is subsequently used for compilation, execution via lowering rules, and defining differentiation behavior.
By carefully implementing the `abstract_eval` method, you ensure that your custom primitive integrates smoothly with JAX's tracing and transformation machinery, behaving predictably with respect to shapes and data types across different contexts like JIT compilation, vectorization, and automatic differentiation. This step is fundamental to creating robust and reusable custom operations in JAX.