One of JAX's significant advantages is its ability to run the same code seamlessly across different types of hardware accelerators, like Graphics Processing Units (GPUs) and Tensor Processing Units (TPUs), in addition to traditional Central Processing Units (CPUs). Understanding how JAX manages these devices is important for writing efficient code.
JAX uses an accelerator backend called XLA (Accelerated Linear Algebra) to compile and run your NumPy-like code on these devices. This means you typically write your code once using jax.numpy
and JAX transformations, and JAX handles the execution on the available hardware without requiring device-specific code from your side.
By default, JAX attempts to use the most capable hardware it detects on your system. The general preference order is TPU > GPU > CPU. If you have a TPU available and configured, JAX will use it. If not, it will look for a compatible GPU. If neither accelerator is found or configured correctly, JAX will fall back to using the CPU.
This automatic selection simplifies getting started. You can often write and test your code on a standard CPU setup and then run the exact same code on a machine equipped with a GPU or TPU for significant performance gains, especially when using transformations like jax.jit.
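As a small illustration of this, the sketch below defines a function once with jax.numpy and lets jax.jit compile it through XLA for whichever backend JAX selected; the function and data here are made up for the example.
import jax
import jax.numpy as jnp
# Written once with jax.numpy; XLA compiles it for the selected backend
# (TPU, GPU, or CPU) without any device-specific code.
@jax.jit
def mean_squared_error(pred, target):
    return jnp.mean((pred - target) ** 2)
pred = jnp.linspace(0.0, 1.0, 1000)
target = jnp.zeros(1000)
print(mean_squared_error(pred, target))  # runs on the default backend
print(jax.default_backend())             # e.g. "cpu", "gpu", or "tpu"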
You can inspect the devices JAX recognizes using the jax.devices()
function. This function returns a list of device objects available to the current JAX process.
import jax
# List all devices JAX can see
available_devices = jax.devices()
print(f"Available devices: {available_devices}")
# Get the name of the default backend JAX will use
default_backend = jax.default_backend()
print(f"Default backend: {default_backend}")
The output will vary depending on your hardware and JAX installation. On a CPU-only machine:
Available devices: [CpuDevice(id=0)]
Default backend: cpu
On a machine with a compatible GPU:
Available devices: [cuda(id=0)] # Or sometimes [GpuDevice(id=0)] or similar
Default backend: gpu
On a TPU host:
Available devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), ...] # TPUs often appear as multiple devices
Default backend: tpu
Knowing which devices are available is the first step towards managing computation placement.
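If you want more detail than the printed repr shows, each device object also exposes attributes such as id, platform, and device_kind that you can inspect directly; a minimal sketch:
import jax
# Print basic metadata for every device JAX can see
for device in jax.devices():
    print(device.id, device.platform, device.device_kind)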
JAX arrays, unlike standard NumPy arrays which always reside in CPU memory (RAM), exist on a specific computational device (CPU, GPU, or TPU). When you create a JAX array, JAX usually places it on the default device.
import jax
import jax.numpy as jnp
# x will typically be created on the default device (e.g., GPU if available)
x = jnp.arange(10.0)
print(f"Array x is on device: {x.device()}")
Operations involving arrays on the same device are generally efficient. However, computations involving arrays on different devices (e.g., adding a CPU array to a GPU array) may require implicit data transfers, which can introduce performance overhead. JAX handles these transfers automatically, but being mindful of data locality is beneficial for optimization.
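One common case of such an implicit transfer is mixing a host-resident NumPy array with a JAX array; the array sizes below are arbitrary, chosen only to make the point about repeated transfers.
import numpy as np
import jax.numpy as jnp
host_values = np.ones(1_000_000)      # lives in host RAM
device_values = jnp.ones(1_000_000)   # lives on the default JAX device
# The NumPy operand is copied to the device before the addition runs.
# One copy is cheap; repeating it inside a tight loop adds up.
total = device_values + host_values
print(type(total))  # a JAX array residing on the default device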
You can explicitly control device placement using jax.device_put(). This function takes a NumPy array or a JAX array and returns a new JAX array placed on the specified device.
import jax
import jax.numpy as jnp
import numpy as np

# Create a NumPy array (lives in host CPU memory)
numpy_array = np.array([1.0, 2.0, 3.0])

# Get a list of available devices
devices = jax.devices()

if devices:
    # Place the array on the first available JAX device
    jax_array_on_device0 = jax.device_put(numpy_array, devices[0])
    print(f"Array placed on: {jax_array_on_device0.device()}")

    # If multiple devices are available (e.g., multiple GPUs or TPU cores)
    if len(devices) > 1:
        # Try placing on a different device
        jax_array_on_device1 = jax.device_put(numpy_array, devices[1])
        print(f"Array placed on: {jax_array_on_device1.device()}")
    else:
        # Explicitly place on CPU if it's the only device
        cpu_device = jax.devices('cpu')[0]
        jax_array_on_cpu = jax.device_put(numpy_array, cpu_device)
        print(f"Array explicitly placed on: {jax_array_on_cpu.device()}")
else:
    print("No JAX devices found.")

# Creating a JAX array directly often places it on the default device
default_device_array = jnp.ones(5)
print(f"Default array on: {default_device_array.device()}")
While explicit placement is possible, it's often less necessary in typical workflows than understanding the implications of device placement. For instance, when using jax.jit, the JIT compilation process optimizes the function for the specific device where the computation will run. Input arrays might be automatically moved to the target device before the compiled function executes.
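A short sketch of that behavior, using a made-up function: a host-side NumPy array is passed to a jitted function, JAX transfers it to the compilation target, and the result comes back as a JAX array on that device.
import numpy as np
import jax
import jax.numpy as jnp
@jax.jit
def standardize(x):
    return (x - jnp.mean(x)) / jnp.std(x)
host_input = np.arange(8.0, dtype=np.float32)  # host (CPU) memory
result = standardize(host_input)               # moved to the default device
print(type(result), jax.default_backend())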
It's helpful to distinguish between the "host" (usually the CPU controlling the Python process) and the "device" (the accelerator like a GPU or TPU where computations primarily occur).
Transferring data between host and device memory takes time. For optimal performance, especially in iterative algorithms like training machine learning models, aim to:
- Move the data to the device once (for example with jax.device_put) rather than transferring it back and forth on every iteration.
- Keep as much of the computation as possible on the device by expressing it with JAX operations and transformations (jit, vmap, grad).
JAX's abstractions handle much of this, but keeping the host-device distinction in mind helps in diagnosing performance bottlenecks or understanding memory usage.
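A rough sketch of this advice, with a made-up linear-regression update purely for illustration: the data is placed on the device once, the jit-compiled step runs there repeatedly, and only the final parameters are brought back to the host.
import jax
import jax.numpy as jnp
# Place the (synthetic) data on the default device once
x = jax.device_put(jnp.linspace(0.0, 1.0, 1024))
y = jax.device_put(3.0 * x + 0.5)
params = jnp.zeros(2)  # (slope, intercept)
@jax.jit
def update(params, x, y, lr=0.1):
    def loss(p):
        pred = p[0] * x + p[1]
        return jnp.mean((pred - y) ** 2)
    return params - lr * jax.grad(loss)(params)
for _ in range(500):
    params = update(params, x, y)  # stays on the device every step
print(params)  # only this final result is transferred back to the host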
Later chapters on pmap
will delve into managing computations across multiple devices simultaneously, where explicit device awareness becomes even more relevant. For now, understand that JAX provides a layer that simplifies running code on accelerators, automatically selecting devices and managing data placement, while offering tools like jax.devices()
and jax.device_put()
for inspection and control when needed.
© 2025 ApX Machine Learning