JAX Testing Guide

This guide covers the current JAX workflow in Stormlog: profiling JAX code directly, tracking JAX memory usage from the CLI, and exporting artifacts for later review.

Before you start

Validate the environment:

jaxmemprof info

If you are bringing up an accelerator runtime (CUDA or TPU), confirm that a basic JAX array operation runs before attempting more complex tracking. These checks also work on CPU-backed JAX installs.
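A minimal smoke test for that first step, assuming only that `jax` is installed (it does not touch Stormlog): it runs one small matrix product and checks the result before you move on to profiling.

```python
import jax
import jax.numpy as jnp

# A 4 x 4 matrix of ones multiplied by itself: every entry should equal 4.
x = jnp.ones((4, 4))
# JAX dispatches work asynchronously; block_until_ready() waits for the
# computation to finish before the result is checked.
y = jnp.dot(x, x).block_until_ready()

assert bool(jnp.all(y == 4.0)), "basic JAX array operation returned wrong values"
print("JAX smoke test passed on backend:", jax.default_backend())
```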

Daily workflow: ML engineer

Use JAXMemoryProfiler when you want snapshots and aggregate results around a real JAX workload.

import jax.numpy as jnp
from stormlog.jax import JAXMemoryProfiler

profiler = JAXMemoryProfiler()

with profiler.profile_context("training"):
    x = jnp.ones((1000, 1000))
    y = jnp.dot(x, x)
    # JAX dispatches work asynchronously; block so the matmul completes inside the profiled context.
    y.block_until_ready()

results = profiler.get_results()
print(f"Peak memory: {results.peak_memory_mb:.2f} MB")
print(f"Snapshots captured: {len(results.snapshots)}")
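To sanity-check the reported peak, it can help to know how much memory the arrays in this example occupy on their own. The sketch below uses JAX's `nbytes` attribute on the same shapes; it assumes JAX's default 32-bit mode (float32). The profiler's peak can differ because of temporaries, allocator overhead, and runtime buffers, and this guide doesn't state whether Stormlog's "MB" means 10^6 or 2^20 bytes, so both are printed.

```python
import jax.numpy as jnp

# Same shapes as the profiled example: two 1000 x 1000 arrays.
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x).block_until_ready()

# In JAX's default 32-bit mode both arrays are float32: 4 bytes per element.
print(x.dtype, y.dtype)
per_array = x.nbytes           # 1000 * 1000 * 4 = 4,000,000 bytes
total = x.nbytes + y.nbytes    # 8,000,000 bytes for x and y together
print(f"Per array:   {per_array / 1e6:.2f} MB ({per_array / 2**20:.2f} MiB)")
print(f"Both arrays: {total / 1e6:.2f} MB ({total / 2**20:.2f} MiB)")
```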

Daily workflow: investigate sustained growth

The jaxmemprof CLI is the simplest way to capture telemetry from longer-running workloads:

jaxmemprof monitor --interval 0.5 --duration 30 --output jax_monitor.json
jaxmemprof track --interval 0.5 --output jax_track.json
jaxmemprof analyze --input jax_monitor.json --detect-leaks --optimize --report jax_report.txt
jaxmemprof diagnose --duration 0 --output ./jax_diag
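To confirm that leak detection flags real growth, it helps to have a workload that grows on purpose. The sketch below is a hypothetical helper (`leaky_workload` is not part of Stormlog) that keeps a reference to every array it allocates, so JAX can never free those buffers and memory grows roughly linearly with the step count.

```python
import time

import jax.numpy as jnp


def leaky_workload(steps: int = 20, size: int = 512, pause_s: float = 0.0) -> list:
    """Allocate one new size x size float32 array per step and retain all of them.

    Hypothetical test helper: the returned list is the "leak", since holding
    every reference prevents JAX from releasing the underlying buffers.
    """
    retained = []
    for step in range(steps):
        # Each array holds size * size float32 values (4 bytes each).
        arr = jnp.full((size, size), float(step))
        arr.block_until_ready()  # ensure the allocation has actually happened
        retained.append(arr)     # never dropped, so memory only grows
        if pause_s:
            time.sleep(pause_s)  # spread growth across sampling intervals
    return retained
```

One way to use it is inside `profiler.profile_context(...)` from the earlier example, for instance `leaky_workload(steps=60, size=1024, pause_s=0.5)` to spread growth over about 30 seconds. This guide doesn't say whether `jaxmemprof monitor` observes allocations made by a separate Python process, so treat running it alongside the CLI as an assumption to verify.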

For CPU-backed JAX, or when the accelerator backend is unavailable, jaxmemprof automatically falls back to CPU mode. To force CPU mode explicitly, pass --device cpu.
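Separately from jaxmemprof's --device flag, JAX itself can be pinned to CPU with its `JAX_PLATFORMS` environment variable. A minimal sketch for your own scripts; whether jaxmemprof reads this variable isn't covered in this guide.

```python
import os

# JAX picks up JAX_PLATFORMS from the environment when it is imported,
# so set it before `import jax` in a fresh process.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (must follow the environment change)
import jax.numpy as jnp  # noqa: E402

x = jnp.arange(8.0)
total = jnp.sum(x).block_until_ready()  # 0 + 1 + ... + 7
print("Backend:", jax.default_backend(), "| sum:", float(total))
```

From a shell, the equivalent is `JAX_PLATFORMS=cpu python your_script.py`, where `your_script.py` stands in for your own entry point.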

Common issues

jaxmemprof runs on CPU when I expected GPU/TPU

Run:

jaxmemprof info

If the CLI reports that JAX is running on CPU, install the JAX variant for your hardware (for example, jax[cuda12] or jax[tpu]).
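The same check can be scripted (for example, as a CI step) with JAX's own `jax.default_backend()` and `jax.devices()`. This only reports what JAX sees and does not involve jaxmemprof; the helper name `accelerator_devices` is hypothetical.

```python
import jax


def accelerator_devices() -> list:
    """Return the non-CPU devices visible on JAX's default backend."""
    return [d for d in jax.devices() if d.platform != "cpu"]


backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
accelerators = accelerator_devices()
print("Default backend:", backend)
if accelerators:
    print("Accelerators:", accelerators)
else:
    # Only CPU devices are visible: the installed jax build lacks GPU/TPU
    # support, or JAX could not initialize the accelerator runtime.
    print("No GPU/TPU visible to JAX; check which jax variant is installed.")
```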

Plot export fails

Install the visualization extra:

pip install "stormlog[viz]"