# Source code for stormlog.jax.profiler

"""JAX Memory Profiler.

Provides snapshot-based memory profiling, function/context profiling,
and a global profiler singleton for JAX workloads.
"""

from __future__ import annotations

import functools
import logging
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    TypeVar,
    cast,
)

from .jax_env import configure_jax_logging

configure_jax_logging()

import jax  # noqa: E402

JAX_AVAILABLE = True

try:
    import psutil

    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False
    psutil = None

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])

# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------


@dataclass
class MemorySnapshot:
    """Memory readings for one JAX device, plus host RSS, at a single instant.

    Any negative byte count handed to the constructor is clamped to 0 in
    ``__post_init__``.
    """

    # Bytes per mebibyte for the ``*_mb`` convenience properties. It is left
    # unannotated on purpose so that ``@dataclass`` does not treat it as a field.
    _BYTES_PER_MB = 1024 * 1024

    # Capture time; JAXMemoryProfiler.capture_snapshot fills it with time.time().
    timestamp: float
    # Human-readable label for the snapshot.
    name: str
    # Device bytes currently in use (``bytes_in_use`` from memory_stats()).
    device_memory_bytes: int
    # Process resident set size in bytes (0 when it could not be read).
    cpu_memory_bytes: int
    # Index of the local JAX device the reading belongs to.
    device_id: int
    # Device bytes reserved by the allocator.
    device_memory_reserved_bytes: int = 0
    # Raw copy of the backend's memory_stats() mapping.
    memory_stats: Dict[str, Any] = field(default_factory=dict)
    # Operation being profiled when the snapshot was taken, if any.
    operation_name: Optional[str] = None

    def __post_init__(self) -> None:
        """Clamp negative byte counters to zero."""
        for attr in (
            "device_memory_bytes",
            "device_memory_reserved_bytes",
            "cpu_memory_bytes",
        ):
            if getattr(self, attr) < 0:
                setattr(self, attr, 0)

    @property
    def device_memory_mb(self) -> float:
        """Device bytes in use, expressed in MiB."""
        return self.device_memory_bytes / self._BYTES_PER_MB

    @property
    def device_memory_reserved_mb(self) -> float:
        """Reserved device bytes, expressed in MiB."""
        return self.device_memory_reserved_bytes / self._BYTES_PER_MB

    @property
    def cpu_memory_mb(self) -> float:
        """Host RSS, expressed in MiB."""
        return self.cpu_memory_bytes / self._BYTES_PER_MB
@dataclass
class ProfileResult:
    """Summary statistics over every snapshot captured in a profiling session."""

    # Bytes per mebibyte for the ``*_mb`` convenience properties. It is left
    # unannotated on purpose so that ``@dataclass`` does not treat it as a field.
    _BYTES_PER_MB = 1024 * 1024

    start_time: float
    end_time: float
    peak_memory_bytes: int
    average_memory_bytes: int
    min_memory_bytes: int
    snapshots: List[MemorySnapshot] = field(default_factory=list)
    # Per-name aggregates keyed by profiled function/context name.
    function_profiles: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    @property
    def duration(self) -> float:
        """Elapsed time between ``start_time`` and ``end_time``."""
        return self.end_time - self.start_time

    @property
    def peak_memory_mb(self) -> float:
        """Peak device memory, expressed in MiB."""
        return self.peak_memory_bytes / self._BYTES_PER_MB

    @property
    def average_memory_mb(self) -> float:
        """Mean device memory, expressed in MiB."""
        return self.average_memory_bytes / self._BYTES_PER_MB

    @property
    def min_memory_mb(self) -> float:
        """Minimum device memory, expressed in MiB."""
        return self.min_memory_bytes / self._BYTES_PER_MB

    @property
    def memory_growth_rate(self) -> float:
        """Memory growth rate in bytes/second.

        Computed as ``(peak - min) / duration``, i.e. the observed memory
        range spread over the session rather than a first-to-last trend.
        Returns 0.0 when the duration is not positive.
        """
        elapsed = self.duration
        if elapsed <= 0:
            return 0.0
        spread = self.peak_memory_bytes - self.min_memory_bytes
        return spread / elapsed
# ---------------------------------------------------------------------------
# Core profiler
# ---------------------------------------------------------------------------
class JAXMemoryProfiler:
    """JAX memory profiler with snapshot capture and function profiling.

    Provides:

    * ``capture_snapshot()`` – point-in-time device memory reading
    * ``profile_function()`` – decorator-based before/after profiling
    * ``profile_context()`` – with-block profiling
    * ``start_continuous_profiling()`` / ``stop_continuous_profiling()``
    * ``get_results()`` – aggregate into :class:`ProfileResult`

    Snapshots and per-name aggregates are stored on the instance. Mutations
    are guarded by an internal lock because snapshots may be appended both by
    callers and by the background continuous-profiling thread.

    Example::

        profiler = JAXMemoryProfiler()
        with profiler:
            s = profiler.capture_snapshot("after_init")
        result = profiler.get_results()
    """

    # Seconds stop_continuous_profiling() waits for the background thread.
    _STOP_JOIN_TIMEOUT = 5.0

    def __init__(self, device_index: int = 0) -> None:
        """Bind the profiler to the local JAX device at *device_index*.

        Args:
            device_index: Index into ``jax.local_devices()``. If it does not
                resolve to a device, snapshots report 0 device bytes (host
                RSS is still recorded when psutil is available).
        """
        self._device_index = device_index
        self._device: Any = None
        self._lock = threading.Lock()
        self._snapshots: List[MemorySnapshot] = []
        self._function_profiles: Dict[str, Dict[str, Any]] = {}
        self._continuous_thread: Optional[threading.Thread] = None
        # Replaced by a fresh Event on every start, so a thread that failed to
        # stop within the join timeout cannot be revived by a later start.
        self._continuous_stop = threading.Event()
        self._start_time: Optional[float] = None
        self._end_time: Optional[float] = None

        if JAX_AVAILABLE:
            try:
                devices = jax.local_devices()
                if device_index < len(devices):
                    self._device = devices[device_index]
                else:
                    logger.warning(
                        "JAX device index %d out of range (%d local devices); "
                        "device memory will be reported as 0",
                        device_index,
                        len(devices),
                    )
            except Exception as exc:
                logger.debug("Could not resolve JAX device %d: %s", device_index, exc)

        # Small array committed to the target device. It is the input of the
        # tiny computation _sync_device() dispatches before each reading.
        self._sync_sentinel: Any = None
        if self._device is not None:
            try:
                self._sync_sentinel = jax.numpy.zeros((), device=self._device)
            except Exception as exc:
                logger.debug(
                    "Could not create sync sentinel on %s: %s", self._device, exc
                )

    # -- Snapshot capture --------------------------------------------------

    def _sync_device(self) -> None:
        """Wait for previously dispatched device work before reading memory.

        JAX dispatches operations asynchronously, so ``memory_stats()`` may
        otherwise miss memory held by computations still in flight. A *new*
        computation is dispatched and waited on every time. Blocking on an
        already-materialised array, such as the cached sentinel itself,
        returns immediately and provides no barrier at all.

        NOTE(review): this relies on the device executing dispatched work in
        order (XLA's per-device compute stream); confirm for other backends.
        The cost is one tiny allocation and a forced sync per snapshot.
        """
        if self._sync_sentinel is not None:
            (self._sync_sentinel + 1).block_until_ready()
        else:
            jax.numpy.zeros((), device=self._device).block_until_ready()

    def capture_snapshot(
        self,
        name: str = "snapshot",
        *,
        operation_name: Optional[str] = None,
    ) -> MemorySnapshot:
        """Capture a point-in-time memory snapshot.

        Failures to read device or host memory are logged at debug level,
        and the corresponding counters are left at 0.

        Args:
            name: Human-readable label for this snapshot.
            operation_name: Optional operation being profiled.

        Returns:
            A :class:`MemorySnapshot`, which is also appended to this
            profiler's history.
        """
        device_bytes = 0
        reserved_bytes = 0
        memory_stats: Dict[str, Any] = {}

        if self._device is not None:
            try:
                self._sync_device()
                raw = self._device.memory_stats()
                if raw is not None:
                    memory_stats = dict(raw)
                    device_bytes = int(memory_stats.get("bytes_in_use", 0))
                    reserved_bytes = int(
                        memory_stats.get("bytes_reserved", device_bytes)
                    )
            except Exception as exc:
                logger.debug("Snapshot memory_stats failed: %s", exc)

        cpu_bytes = 0
        if PSUTIL_AVAILABLE and psutil is not None:
            try:
                cpu_bytes = psutil.Process().memory_info().rss
            except Exception as exc:
                logger.debug("CPU memory read failed: %s", exc)

        snap = MemorySnapshot(
            timestamp=time.time(),
            name=name,
            device_memory_bytes=device_bytes,
            device_memory_reserved_bytes=reserved_bytes,
            cpu_memory_bytes=cpu_bytes,
            device_id=self._device_index,
            memory_stats=memory_stats,
            operation_name=operation_name,
        )
        with self._lock:
            self._snapshots.append(snap)
        return snap

    # -- Function profiling ------------------------------------------------

    def _record_profile(
        self,
        name: str,
        before: MemorySnapshot,
        after: MemorySnapshot,
        elapsed: float,
    ) -> None:
        """Fold one before/after measurement into the aggregate for *name*.

        ``peak_memory_bytes`` is the larger of the two endpoint readings. A
        transient peak between the two snapshots is not observed.
        """
        delta = after.device_memory_bytes - before.device_memory_bytes
        peak = max(before.device_memory_bytes, after.device_memory_bytes)
        with self._lock:
            entry = self._function_profiles.setdefault(
                name,
                {
                    "calls": 0,
                    "total_duration": 0.0,
                    "peak_memory_bytes": 0,
                    "total_memory_delta": 0,
                    "last_memory_delta": 0,
                },
            )
            entry["calls"] += 1
            entry["total_duration"] += elapsed
            entry["peak_memory_bytes"] = max(entry["peak_memory_bytes"], peak)
            entry["total_memory_delta"] += delta
            entry["last_memory_delta"] = delta

    def profile_function(
        self,
        func: Optional[F] = None,
        *,
        name: Optional[str] = None,
    ) -> Any:
        """Decorator that profiles a function's memory impact.

        Each call captures ``<name>_before`` / ``<name>_after`` snapshots and
        updates the per-name aggregate. The aggregate is updated even if the
        wrapped function raises.

        Usage::

            @profiler.profile_function
            def train_step():
                ...

            # or with options:
            @profiler.profile_function(name="custom_name")
            def train_step():
                ...
        """

        def decorator(f: F) -> F:
            # Callables such as functools.partial have no __name__.
            profiled_name = name or getattr(f, "__name__", type(f).__name__)

            @functools.wraps(f)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                before = self.capture_snapshot(
                    f"{profiled_name}_before",
                    operation_name=profiled_name,
                )
                start = time.monotonic()
                try:
                    return f(*args, **kwargs)
                finally:
                    elapsed = time.monotonic() - start
                    after = self.capture_snapshot(
                        f"{profiled_name}_after",
                        operation_name=profiled_name,
                    )
                    self._record_profile(profiled_name, before, after, elapsed)

            return cast(F, wrapper)

        if func is None:
            return decorator
        return decorator(func)

    # -- Context profiling -------------------------------------------------

    @contextmanager
    def profile_context(
        self,
        name: str = "context",
    ) -> Iterator[MemorySnapshot]:
        """Context manager that captures before/after snapshots.

        The aggregate for *name* is updated even if the block raises.

        Usage::

            with profiler.profile_context("matmul") as snap_before:
                result = jax.numpy.dot(a, b)
        """
        before = self.capture_snapshot(f"{name}_before", operation_name=name)
        start = time.monotonic()
        try:
            yield before
        finally:
            elapsed = time.monotonic() - start
            after = self.capture_snapshot(f"{name}_after", operation_name=name)
            self._record_profile(name, before, after, elapsed)

    # -- Continuous profiling ----------------------------------------------

    def start_continuous_profiling(self, interval: float = 1.0) -> None:
        """Start background snapshot capture at *interval* seconds.

        Raises:
            ValueError: If *interval* is not positive. A zero or negative
                interval would spin and accumulate snapshots without bound.
        """
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval!r}")
        thread = self._continuous_thread
        if thread is not None and thread.is_alive():
            logger.warning("Continuous profiling already running")
            return
        # A fresh Event per run: an earlier thread that outlived its join
        # timeout keeps its own (set) event and still exits.
        stop_event = threading.Event()
        self._continuous_stop = stop_event
        self._continuous_thread = threading.Thread(
            target=self._continuous_loop,
            args=(interval, stop_event),
            daemon=True,
        )
        self._continuous_thread.start()

    def stop_continuous_profiling(self) -> None:
        """Stop the background snapshot loop, if any."""
        self._continuous_stop.set()
        thread = self._continuous_thread
        if thread is not None:
            thread.join(timeout=self._STOP_JOIN_TIMEOUT)
            if thread.is_alive():
                logger.warning(
                    "Continuous profiling thread did not stop within %.1fs; "
                    "it will exit after its current snapshot",
                    self._STOP_JOIN_TIMEOUT,
                )
            self._continuous_thread = None

    def _continuous_loop(
        self,
        interval: float,
        stop_event: Optional[threading.Event] = None,
    ) -> None:
        """Capture ``continuous_<n>`` snapshots until *stop_event* is set."""
        stop = stop_event if stop_event is not None else self._continuous_stop
        counter = 0
        while not stop.is_set():
            self.capture_snapshot(f"continuous_{counter}")
            counter += 1
            stop.wait(interval)

    # -- Results -----------------------------------------------------------

    def get_results(self) -> ProfileResult:
        """Aggregate captured snapshots into a :class:`ProfileResult`.

        Statistics are computed over ``device_memory_bytes`` of all snapshots,
        ordered by timestamp. With no snapshots, the session start/end times
        (or the current time) are used and all memory figures are 0.
        """
        with self._lock:
            snapshots = list(self._snapshots)
            profiles = {k: dict(v) for k, v in self._function_profiles.items()}

        snapshots.sort(key=lambda snapshot: snapshot.timestamp)

        if not snapshots:
            now = time.time()
            return ProfileResult(
                start_time=self._start_time if self._start_time is not None else now,
                end_time=self._end_time if self._end_time is not None else now,
                peak_memory_bytes=0,
                average_memory_bytes=0,
                min_memory_bytes=0,
                snapshots=[],
                function_profiles=profiles,
            )

        memories = [s.device_memory_bytes for s in snapshots]
        return ProfileResult(
            start_time=snapshots[0].timestamp,
            end_time=snapshots[-1].timestamp,
            peak_memory_bytes=max(memories),
            average_memory_bytes=int(sum(memories) / len(memories)),
            min_memory_bytes=min(memories),
            snapshots=snapshots,
            function_profiles=profiles,
        )

    def reset(self) -> None:
        """Clear all captured data (does not stop continuous profiling)."""
        with self._lock:
            self._snapshots.clear()
            self._function_profiles.clear()
        self._start_time = None
        self._end_time = None

    # -- Context manager ---------------------------------------------------

    def __enter__(self) -> "JAXMemoryProfiler":
        """Record the session start time and a ``session_start`` snapshot."""
        self._start_time = time.time()
        self.capture_snapshot("session_start")
        return self

    def __exit__(self, *exc: Any) -> None:
        """Stop background sampling, then record ``session_end``.

        Stopping first means the background thread normally cannot append
        snapshots after ``session_end``. Exceptions raised in the ``with``
        body are not suppressed.
        """
        self.stop_continuous_profiling()
        self.capture_snapshot("session_end")
        self._end_time = time.time()
# ---------------------------------------------------------------------------
# Global profiler singleton
# ---------------------------------------------------------------------------

# Guards every read and replacement of ``_global_profiler``.
_profiler_lock = threading.Lock()

# Process-wide profiler. It is None until get_global_profiler() first creates
# one, and again after clear_global_profiler().
_global_profiler: Optional[JAXMemoryProfiler] = None
def get_global_profiler() -> JAXMemoryProfiler:
    """Return the process-wide :class:`JAXMemoryProfiler`, creating it lazily.

    Creation happens under the module lock, so concurrent first calls all
    receive the same instance.
    """
    global _global_profiler
    with _profiler_lock:
        profiler = _global_profiler
        if profiler is None:
            profiler = JAXMemoryProfiler()
            _global_profiler = profiler
        return profiler
def set_global_profiler(profiler: JAXMemoryProfiler) -> None:
    """Install *profiler* as the global profiler instance.

    If a *different* profiler was previously installed, its continuous
    profiling is stopped. Re-installing the current profiler leaves its
    background sampling running. The stop is performed after the module
    lock is released, because it may block while joining the background
    thread.
    """
    global _global_profiler
    with _profiler_lock:
        previous = _global_profiler
        _global_profiler = profiler
    if previous is not None and previous is not profiler:
        previous.stop_continuous_profiling()
def clear_global_profiler() -> None:
    """Reset and discard the global profiler.

    The profiler is detached under the module lock. Stopping its background
    thread (which may block on a join) and clearing its data happen after
    the lock is released, so other callers are not stalled meanwhile.
    """
    global _global_profiler
    with _profiler_lock:
        previous = _global_profiler
        _global_profiler = None
    if previous is not None:
        previous.stop_continuous_profiling()
        previous.reset()
def clear_profiles() -> None:
    """Clear the global profiler's captured data but keep the instance installed.

    Does nothing when no global profiler exists.
    """
    with _profiler_lock:
        current = _global_profiler
        if current is None:
            return
        current.reset()
def get_profile_summaries(
    limit: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Return aggregated profile summaries from the global profiler.

    Args:
        limit: Maximum number of summaries to return. ``None`` returns all
            of them; zero or a negative value returns an empty list.

    Returns:
        A list of dicts, one per profiled function/context block, each with a
        ``"name"`` key plus the aggregate fields. Sorted by
        ``peak_memory_bytes``, highest first. Empty when no global profiler
        exists.
    """
    with _profiler_lock:
        prof = _global_profiler
    if prof is None:
        return []

    result = prof.get_results()
    summaries: List[Dict[str, Any]] = [
        {"name": name, **entry} for name, entry in result.function_profiles.items()
    ]
    summaries.sort(key=lambda s: s.get("peak_memory_bytes", 0), reverse=True)
    if limit is not None:
        # A negative slice bound would drop items from the end instead of
        # capping the count, so clamp it to zero.
        summaries = summaries[: max(limit, 0)]
    return summaries