# Source code for stormlog.jax.context_profiler

"""JAX Context Profiling.

Provides module-level convenience functions, a high-level ``JAXProfiler``
class, and a ``ProfiledFunction`` wrapper analogous to the
``ProfiledModule``/``ProfiledLayer`` classes in the PyTorch and TensorFlow
context profilers.

Because JAX follows a functional paradigm (no ``nn.Module`` hierarchy),
``ProfiledFunction`` wraps arbitrary callables rather than layer objects.
"""

from __future__ import annotations

import functools
import logging
import threading
from contextlib import contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    TypeVar,
    Union,
    cast,
)

if TYPE_CHECKING:
    import stormlog.jax.profiler

from .jax_env import configure_jax_logging
from .profiler import (
    JAXMemoryProfiler,
    ProfileResult,
)
from .profiler import clear_global_profiler as _clear_profiler
from .profiler import clear_profiles as _clear_profiles
from .profiler import get_global_profiler as _get_profiler
from .profiler import get_profile_summaries as _get_summaries
from .profiler import set_global_profiler as _set_profiler

# Optional JAX import: this module stays importable without JAX installed.
# JAX-dependent entry points (e.g. ``JAXProfiler.profile_training``) check
# ``JAX_AVAILABLE`` before doing any work.
jax: Any
try:
    import jax as _jax  # noqa: F401
except ImportError:
    jax = None
    JAX_AVAILABLE = False
else:
    jax = _jax
    JAX_AVAILABLE = True

# Apply the package's JAX logging configuration as an import-time side effect.
configure_jax_logging()

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Module-level global profiler state
# ---------------------------------------------------------------------------

# NOTE(review): ``_global_profiler`` and ``_profiler_lock`` are not referenced
# by the functions visible in this module (the global singleton is delegated
# to ``stormlog.jax.profiler``); kept in case external code imports them.
_profiler_lock = threading.Lock()
_global_profiler: Optional[JAXMemoryProfiler] = None

# Type variable used by ``profile_function`` to preserve the decorated
# callable's type.
F = TypeVar("F", bound=Callable[..., Any])


# ---------------------------------------------------------------------------
# Module-level convenience functions
# ---------------------------------------------------------------------------


def get_global_profiler() -> JAXMemoryProfiler:
    """Return the shared :class:`~stormlog.jax.profiler.JAXMemoryProfiler`.

    The singleton is owned by :mod:`stormlog.jax.profiler`. This function
    simply forwards to it, and that module creates the instance on first use.

    Returns:
        The process-wide :class:`~stormlog.jax.profiler.JAXMemoryProfiler`.
    """
    shared_profiler = _get_profiler()
    return shared_profiler


def set_global_profiler(profiler: JAXMemoryProfiler) -> None:
    """Install *profiler* as the global profiler singleton.

    Forwards to :mod:`stormlog.jax.profiler`, which owns the singleton.

    Args:
        profiler: The :class:`~stormlog.jax.profiler.JAXMemoryProfiler`
            that subsequent global lookups should return.
    """
    _set_profiler(profiler)


def clear_global_profiler() -> None:
    """Reset and drop the global profiler singleton.

    Forwards to :mod:`stormlog.jax.profiler`. The next call to
    :func:`get_global_profiler` will build a fresh instance.
    """
    _clear_profiler()


def clear_profiles() -> None:
    """Discard recorded profiling data but keep the global profiler itself.

    Forwards to :mod:`stormlog.jax.profiler`.
    """
    _clear_profiles()


def get_profile_summaries(limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Fetch aggregated profiling summaries from the global profiler.

    Forwards to :mod:`stormlog.jax.profiler`. According to that module's
    contract, the result has one dict per profiled function/context block,
    ordered by peak memory (descending).

    Args:
        limit: Upper bound on the number of summaries returned (*None* is
            passed through unchanged).

    Returns:
        The list of summary dicts.
    """
    summaries = _get_summaries(limit)
    return summaries


# ---------------------------------------------------------------------------
# Decorator & context-manager helpers
# ---------------------------------------------------------------------------


def profile_function(
    func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    profiler: Optional[JAXMemoryProfiler] = None,
) -> Union[Callable[[F], F], F]:
    """Decorator to profile a function's JAX device-memory usage.

    Can be used bare (``@profile_function``) or with keyword arguments
    (``@profile_function(name="custom")``).

    Args:
        func: Function to profile (when used as ``@profile_function``).
        name: Custom name for the profiled function. Defaults to
            ``func.__name__`` (or ``"unknown_function"`` if it has none).
        profiler: Explicit :class:`~stormlog.jax.profiler.JAXMemoryProfiler`
            to use. Falls back to the global profiler only when *None*. The
            fallback is resolved on every call, not at decoration time.

    Returns:
        Decorated callable (or decorator factory when called with keyword
        arguments).
    """

    def decorator(f: F) -> F:
        profiled_name = name or getattr(f, "__name__", "unknown_function")

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Compare against None instead of relying on truthiness: an
            # explicitly supplied profiler must never be swapped for the
            # global one just because it happens to evaluate as falsy.
            prof = profiler if profiler is not None else get_global_profiler()
            with prof.profile_context(profiled_name or "profiled_func"):
                return f(*args, **kwargs)

        # Attach profiling metadata for introspection. These are set after
        # functools.wraps so they win over anything copied from f.__dict__.
        wrapper._is_profiled = True  # type: ignore[attr-defined]
        wrapper._profile_name = profiled_name  # type: ignore[attr-defined]
        return cast(F, wrapper)

    if func is None:
        # Called as ``@profile_function(name=...)``
        return decorator
    # Called as ``@profile_function``
    return decorator(func)


@contextmanager
def profile_context(
    name: str = "context",
    profiler: Optional[JAXMemoryProfiler] = None,
) -> Iterator[JAXMemoryProfiler]:
    """Context manager for profiling a block of code.

    Args:
        name: Label for the profiled block.
        profiler: Explicit profiler. Falls back to the global profiler only
            when *None*; a supplied profiler is always used as-is.

    Yields:
        The :class:`~stormlog.jax.profiler.JAXMemoryProfiler` being used.

    Example::

        with profile_context("matmul") as prof:
            result = jax.numpy.dot(a, b)
    """
    # ``is not None`` rather than ``or``: truthiness would silently replace an
    # explicitly passed profiler that defines ``__bool__``/``__len__`` and
    # currently evaluates as falsy.
    prof = profiler if profiler is not None else get_global_profiler()
    with prof.profile_context(name):
        yield prof


# ---------------------------------------------------------------------------
# ProfiledFunction – wraps arbitrary callables
# ---------------------------------------------------------------------------


class ProfiledFunction:
    """Wrapper that automatically profiles every call to a function.

    JAX does not use an ``nn.Module`` class hierarchy, so instead of
    wrapping a layer/module this class wraps any callable (pure functions,
    closures, ``jax.jit``-compiled functions, etc.).

    Args:
        func: The callable to profile.
        profiler: Explicit :class:`~stormlog.jax.profiler.JAXMemoryProfiler`.
            Falls back to the global profiler only when *None*. The fallback
            is resolved once, at construction time.
        name: Label used in profiling output. Defaults to the callable's
            ``__name__`` or ``__class__.__name__``.

    Example::

        profiled_forward = ProfiledFunction(forward_fn, name="forward")
        output = profiled_forward(params, batch)
    """

    def __init__(
        self,
        func: Callable[..., Any],
        profiler: Optional["stormlog.jax.profiler.JAXMemoryProfiler"] = None,
        name: Optional[str] = None,
    ) -> None:
        # Copy introspection attributes (__name__, __doc__, __wrapped__, ...)
        # FIRST. update_wrapper also merges ``func.__dict__`` into this
        # instance's __dict__, so running it last would let a wrapped callable
        # that carries its own ``func``/``profiler``/``name`` attributes (for
        # example another ProfiledFunction) overwrite this wrapper's settings.
        functools.update_wrapper(self, func)
        self.func = func
        # ``is not None``: never replace an explicit profiler that is falsy.
        self.profiler = profiler if profiler is not None else get_global_profiler()
        self.name = name or getattr(func, "__name__", func.__class__.__name__)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped callable under memory profiling."""
        with self.profiler.profile_context(self.name or "profiled_func"):
            return self.func(*args, **kwargs)

    def __repr__(self) -> str:
        return f"ProfiledFunction({self.name!r})"


# ---------------------------------------------------------------------------
# JAXProfiler – high-level profiling interface
# ---------------------------------------------------------------------------


class JAXProfiler:
    """High-level JAX profiling interface.

    Provides convenience methods for profiling training loops and inference
    passes, analogous to :class:`stormlog.context_profiler.MemoryProfiler`
    (PyTorch) and
    :class:`stormlog.tensorflow.context_profiler.TensorFlowProfiler`.

    Constructing a ``JAXProfiler`` creates a new
    :class:`~stormlog.jax.profiler.JAXMemoryProfiler` and installs it as the
    global profiler.

    Args:
        device_index: Index of the JAX device to monitor (default ``0``).

    Example::

        jp = JAXProfiler()
        jp.profile_training(train_step, dataset, epochs=3)
        result = jp.get_results()
    """

    def __init__(self, device_index: int = 0) -> None:
        self.profiler = JAXMemoryProfiler(device_index=device_index)
        # Install as the global profiler so helpers that fall back to it
        # (e.g. ``profile_context`` without an explicit profiler) use it.
        set_global_profiler(self.profiler)

    # -- Training profiling ------------------------------------------------

    def profile_training(
        self,
        train_step_fn: Callable[..., Any],
        dataset: Any,
        epochs: int = 1,
        steps_per_epoch: Optional[int] = None,
    ) -> None:
        """Profile a JAX training loop.

        ``train_step_fn`` is invoked once per batch, with the batch as its
        only positional argument (additional state can be closed over).
        Profiling contexts are nested as ``training`` > ``epoch_{i}`` >
        ``step_{j}``.

        Args:
            train_step_fn: A callable ``(batch) -> Any`` that executes a
                single training step.
            dataset: An iterable of batches.
                * Re-iterable containers (lists, tuples, ...) are iterated
                  afresh each epoch.
                * A single-use iterator (e.g. a generator) is streamed
                  directly when ``epochs <= 1``.
                * For multi-epoch runs over a single-use iterator, one
                  epoch's worth of batches (all of them, or the first
                  ``steps_per_epoch``) is buffered once and replayed every
                  epoch.
            epochs: Number of epochs to profile.
            steps_per_epoch: Optional cap on the number of steps per epoch.
                Batches beyond the cap are never pulled from the dataset,
                so infinite streams are supported when this is set.

        Raises:
            ImportError: If JAX is not installed.
        """
        if not JAX_AVAILABLE:
            raise ImportError("JAX is required for JAXProfiler.profile_training")

        import itertools

        # ``islice(..., None)`` yields everything; a negative cap runs zero
        # steps, matching a ``step >= steps_per_epoch`` cut-off.
        limit = None if steps_per_epoch is None else max(0, steps_per_epoch)

        if iter(dataset) is dataset and epochs > 1:
            # A single-use iterator would be exhausted after the first epoch,
            # so buffer exactly one epoch's worth of batches and replay it.
            # Buffering only what is needed keeps infinite streams usable and
            # avoids pinning unused batches in memory while profiling.
            dataset = list(itertools.islice(dataset, limit))

        with self.profiler.profile_context("training"):
            for epoch in range(epochs):
                with self.profiler.profile_context(f"epoch_{epoch}"):
                    # islice stops *before* fetching the batch past the cap.
                    batches = itertools.islice(dataset, limit)
                    for step_count, batch in enumerate(batches):
                        with self.profiler.profile_context(f"step_{step_count}"):
                            train_step_fn(batch)

    # -- Inference profiling -----------------------------------------------

    def profile_inference(
        self,
        inference_fn: Callable[..., Any],
        data: Any,
        batch_size: int = 32,
    ) -> None:
        """Profile a JAX inference pass.

        If *data* is iterable and has no ``shape`` attribute, it is treated
        as an iterable of batches and consumed directly. Otherwise it is
        treated as a single array-like (converted with ``jnp.asarray`` when
        it has no ``shape``) and sliced along its leading dimension into
        batches of *batch_size*.

        Args:
            inference_fn: A callable ``(batch) -> Any`` that runs inference
                on a single batch.
            data: Input data – either an iterable of batches or a single
                array-like with a leading batch dimension.
            batch_size: Batch size used when *data* must be sliced.

        Raises:
            ImportError: If JAX is not installed.
            ValueError: If *data* must be sliced and ``batch_size < 1``.
        """
        if not JAX_AVAILABLE:
            raise ImportError("JAX is required for JAXProfiler.profile_inference")

        # Objects exposing ``shape`` are treated as one array even though
        # arrays are iterable; anything else iterable is assumed to yield
        # ready-made batches (e.g. a list of batches).
        pre_batched = hasattr(data, "__iter__") and not hasattr(data, "shape")
        if not pre_batched and batch_size < 1:
            # Guard the slicing below: 0 would divide by zero and a negative
            # size would silently produce no batches.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        with self.profiler.profile_context("inference"):
            if pre_batched:
                for i, batch in enumerate(data):
                    with self.profiler.profile_context(f"inference_batch_{i}"):
                        inference_fn(batch)
                return

            # Fall back to manual batching over an array-like.
            import jax.numpy as jnp

            if not hasattr(data, "shape"):
                data = jnp.asarray(data)

            num_samples = data.shape[0]
            for i, start_idx in enumerate(range(0, num_samples, batch_size)):
                end_idx = min(start_idx + batch_size, num_samples)
                batch = data[start_idx:end_idx]
                with self.profiler.profile_context(f"inference_batch_{i}"):
                    inference_fn(batch)

    # -- Results / lifecycle -----------------------------------------------

    def get_results(self) -> ProfileResult:
        """Return the aggregated :class:`ProfileResult`."""
        return self.profiler.get_results()

    def reset(self) -> None:
        """Clear all captured profiling data."""
        self.profiler.reset()