stormlog.jax.tracker

Real-time JAX Memory Tracking.

This module provides real-time monitoring of JAX device memory usage, integrating with Stormlog’s shared telemetry, session, and phase tracking infrastructure.

Classes

JAXMemoryTracker

alias of MemoryTracker

MemoryTracker([sampling_interval, ...])

Real-time JAX device memory tracker.

MemoryWatchdog([max_memory_mb, ...])

Automatic memory management and cleanup for JAX workloads.

TrackingResult(start_time, end_time, ...[, ...])

Results from real-time JAX memory tracking.

class stormlog.jax.tracker.TrackingResult(start_time, end_time, samples_collected, peak_memory_bytes, min_memory_bytes, average_memory_bytes, alert_count, session_summary=None, telemetry_events=<factory>, memory_usage=<factory>, timestamps=<factory>, device_memory_profile_path=None, history_window_limit=0, history_retained_samples=0, history_dropped_samples=0, history_retained_events=0, history_dropped_events=0, history_retained_alerts=0, history_dropped_alerts=0)[source]

Bases: object

Results from real-time JAX memory tracking.

Parameters:
  • start_time (float)

  • end_time (float)

  • samples_collected (int)

  • peak_memory_bytes (int)

  • min_memory_bytes (int)

  • average_memory_bytes (int)

  • alert_count (int)

  • session_summary (SessionSummary | None)

  • telemetry_events (List[Dict[str, Any]])

  • memory_usage (List[int])

  • timestamps (List[float])

  • device_memory_profile_path (str | None)

  • history_window_limit (int)

  • history_retained_samples (int)

  • history_dropped_samples (int)

  • history_retained_events (int)

  • history_dropped_events (int)

  • history_retained_alerts (int)

  • history_dropped_alerts (int)

start_time: float
end_time: float
samples_collected: int
peak_memory_bytes: int
min_memory_bytes: int
average_memory_bytes: int
alert_count: int
session_summary: SessionSummary | None = None
telemetry_events: List[Dict[str, Any]]
memory_usage: List[int]
timestamps: List[float]
device_memory_profile_path: str | None = None
history_window_limit: int = 0
history_retained_samples: int = 0
history_dropped_samples: int = 0
history_retained_events: int = 0
history_dropped_events: int = 0
history_retained_alerts: int = 0
history_dropped_alerts: int = 0
property peak_memory_mb: float

Peak memory usage in MB.

property average_memory_mb: float

Average memory usage in MB.

property duration: float

Total tracking duration in seconds.
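
The derived properties can be read as simple conversions of the stored fields. The sketch below uses a hypothetical stand-in class (`ResultSketch`), not Stormlog's implementation. It assumes that 1 MB means 1024 × 1024 bytes and that `duration` equals `end_time - start_time`; neither is stated in this reference.

```python
from dataclasses import dataclass

BYTES_PER_MB = 1024 * 1024  # assumption: binary megabytes


@dataclass
class ResultSketch:
    """Hypothetical stand-in mirroring a few TrackingResult fields."""

    start_time: float
    end_time: float
    peak_memory_bytes: int
    average_memory_bytes: int

    @property
    def peak_memory_mb(self) -> float:
        # Convert the stored byte count to megabytes.
        return self.peak_memory_bytes / BYTES_PER_MB

    @property
    def average_memory_mb(self) -> float:
        return self.average_memory_bytes / BYTES_PER_MB

    @property
    def duration(self) -> float:
        # Elapsed seconds between the two recorded timestamps.
        return self.end_time - self.start_time


result = ResultSketch(
    start_time=100.0,
    end_time=112.5,
    peak_memory_bytes=512 * BYTES_PER_MB,
    average_memory_bytes=256 * BYTES_PER_MB,
)
print(result.peak_memory_mb, result.average_memory_mb, result.duration)
```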

class stormlog.jax.tracker.MemoryTracker(sampling_interval=1.0, alert_threshold_mb=None, device_index=0, enable_logging=True, max_history=10000, job_id=None, rank=None, local_rank=None, world_size=None, telemetry_sink_config=None, save_device_profile_on_stop=False, enable_oom_flight_recorder=False, oom_dump_dir='oom_dumps', oom_buffer_size=None, oom_max_dumps=5, oom_max_total_mb=256)[source]

Bases: object

Real-time JAX device memory tracker.

Parameters:
  • sampling_interval (float)

  • alert_threshold_mb (Optional[float])

  • device_index (int)

  • enable_logging (bool)

  • max_history (int)

  • job_id (Optional[str])

  • rank (Optional[int])

  • local_rank (Optional[int])

  • world_size (Optional[int])

  • telemetry_sink_config (Optional[TelemetrySinkConfig])

  • save_device_profile_on_stop (bool)

  • enable_oom_flight_recorder (bool)

  • oom_dump_dir (str)

  • oom_buffer_size (Optional[int])

  • oom_max_dumps (int)

  • oom_max_total_mb (int)
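
As an illustration of the constructor signature, the sketch below builds a set of keyword arguments and passes them to `MemoryTracker`. The helper names `tracker_kwargs` and `make_tracker` are hypothetical, the chosen values are arbitrary, and the comment about sampling frequency assumes `sampling_interval` is expressed in seconds, which this reference does not state.

```python
from typing import Any, Dict


def tracker_kwargs(threshold_mb: float = 4096.0) -> Dict[str, Any]:
    """Build keyword arguments using only parameter names from the signature above."""
    return {
        "sampling_interval": 0.5,  # assumed to be seconds between samples
        "alert_threshold_mb": threshold_mb,
        "device_index": 0,
        "max_history": 2000,
        "enable_oom_flight_recorder": True,
        "oom_dump_dir": "oom_dumps",
    }


def make_tracker(**overrides: Any) -> Any:
    """Construct a tracker; requires Stormlog (and a working JAX install)."""
    # Imported lazily so the helper above can be used without Stormlog installed.
    from stormlog.jax.tracker import MemoryTracker

    return MemoryTracker(**{**tracker_kwargs(), **overrides})
```

Parameters that are omitted keep the defaults listed in the signature.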

get_session_summary()[source]
Return type:

SessionSummary | None

property oom_buffer_size: int

Resolved OOM ring-buffer size.

add_alert_callback(callback)[source]

Register an alert callback.

Parameters:

callback (Callable[[Dict[str, Any]], None])

Return type:

None

remove_alert_callback(callback)[source]

Remove a previously registered alert callback.

Parameters:

callback (Callable[[Dict[str, Any]], None])

Return type:

None
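
A minimal sketch of pairing `add_alert_callback` with `remove_alert_callback` follows. The helpers `make_alert_collector` and `with_alert_collection` are hypothetical. Because the keys of the alert dictionary are not documented here, the callback stores each payload whole rather than reading specific fields.

```python
from typing import Any, Callable, Dict, List

AlertCallback = Callable[[Dict[str, Any]], None]


def make_alert_collector(store: List[Dict[str, Any]]) -> AlertCallback:
    """Return a callback matching the documented Callable[[Dict[str, Any]], None]."""

    def on_alert(alert: Dict[str, Any]) -> None:
        # Copy the payload so later mutation by the caller does not affect the record.
        store.append(dict(alert))

    return on_alert


def with_alert_collection(
    tracker: Any, body: Callable[[], Any]
) -> List[Dict[str, Any]]:
    """Run body with a collector registered, then always unregister it."""
    alerts: List[Dict[str, Any]] = []
    callback = make_alert_collector(alerts)
    tracker.add_alert_callback(callback)
    try:
        body()
    finally:
        # Pass the same callable object that was registered.
        tracker.remove_alert_callback(callback)
    return alerts
```

Keeping a reference to the registered callable matters, since removal takes the callback itself as its argument.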

set_alert_threshold(threshold_mb)[source]
Parameters:

threshold_mb (float)

Return type:

None

check_alerts()[source]
Return type:

bool

start_tracking()[source]
Return type:

None

stop_tracking()[source]
Return type:

TrackingResult

get_current_memory()[source]

Get current memory usage in MB.

Return type:

float

get_statistics()[source]
Return type:

dict[str, Any]

get_tracking_results()[source]

Get current tracking results without stopping.

Return type:

TrackingResult
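
The tracking methods above suggest a start/stop lifecycle around a workload. The sketch below wraps that lifecycle in a hypothetical helper, `track_workload`, which accepts any object exposing `start_tracking()` and `stop_tracking()`. Using `try`/`finally` is a design choice of this example, not documented Stormlog behavior.

```python
from typing import Any, Callable


def track_workload(tracker: Any, workload: Callable[[], Any]) -> Any:
    """Run workload between start_tracking() and stop_tracking()."""
    tracker.start_tracking()
    try:
        workload()
    finally:
        # Stop even if the workload raised, so tracking is not left running.
        result = tracker.stop_tracking()
    return result
```

With a real tracker, the returned object would be the `TrackingResult` described earlier, for example `track_workload(tracker, train_one_epoch).peak_memory_mb`, where `train_one_epoch` is a placeholder for user code.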

enter_phase(name, *, metadata=None)[source]
Parameters:
  • name (str)

  • metadata (Dict[str, Any] | None)

Return type:

PhaseHandle

phase(name, *, metadata=None)[source]
Parameters:
  • name (str)

  • metadata (Dict[str, Any] | None)

Return type:

Iterator[PhaseHandle]
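
The `Iterator[PhaseHandle]` return type of `phase` is what a generator decorated with `contextlib.contextmanager` typically reports, so the sketch below assumes `phase` can be used in a `with` statement. That is an inference, not something this reference confirms. The helper `run_phases` and the `"order"` metadata key are hypothetical.

```python
from typing import Any, Callable, Dict, List


def run_phases(tracker: Any, steps: Dict[str, Callable[[], Any]]) -> List[str]:
    """Run each named step inside its own tracker phase, in insertion order."""
    completed: List[str] = []
    for name, step in steps.items():
        # Assumption: tracker.phase(...) returns a context manager.
        with tracker.phase(name, metadata={"order": len(completed)}):
            step()
        completed.append(name)
    return completed
```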

property last_oom_dump_path: str | None

Path to the most recent OOM dump bundle, or None.

handle_exception(exc, context=None, metadata=None)[source]

Capture OOM diagnostics for recognized OOM exceptions.

Parameters:
  • exc (BaseException)

  • context (str | None)

  • metadata (Dict[str, Any] | None)

Return type:

str | None
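
One way to use `handle_exception` is inside an `except` block that still re-raises the original error. The sketch below assumes the `str | None` return value is the path of the dump bundle when the exception is recognized as an OOM, and `None` otherwise; this reference does not spell that out. The helper `run_step` and the `"step"` metadata key are hypothetical.

```python
from typing import Any, Callable, Optional


def run_step(tracker: Any, step_fn: Callable[[], Any], step: int) -> Any:
    """Run one step; on failure, let the tracker capture diagnostics, then re-raise."""
    try:
        return step_fn()
    except Exception as exc:
        dump_path: Optional[str] = tracker.handle_exception(
            exc, context="train_step", metadata={"step": step}
        )
        if dump_path is not None:
            print(f"OOM diagnostics written to {dump_path}")
        # Never swallow the error: the caller still decides how to recover.
        raise
```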

capture_oom(context='runtime', metadata=None)[source]

Capture an OOM diagnostic bundle if the wrapped block raises an OOM.

Parameters:
  • context (str)

  • metadata (Dict[str, Any] | None)

Return type:

Iterator[None]

trigger_oom_dump(exception, context=None, metadata=None)[source]

Manually trigger an OOM diagnostic dump bundle.

Parameters:
  • exception (BaseException)

  • context (str | None)

  • metadata (Dict[str, Any] | None)

Return type:

str | None

save_device_memory_profile(output_path)[source]

Save a JAX device memory profile to the given path.

Note

This method depends on jax.profiler.save_device_memory_profile, which is available only on GPU/TPU backends with JAX >= 0.4.1. On CPU-only installs or older JAX versions, the call is a no-op and returns False. Availability is checked at runtime via hasattr guards, so no import error is raised.

Parameters:

output_path (str)

Return type:

bool
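
The runtime availability check described in the note can be sketched as follows. This is an illustration of the guard pattern, not Stormlog's code: the function name `save_profile_if_supported` and its `jax_module` parameter are hypothetical, `getattr` with a default stands in for the `hasattr` checks, and the sketch does not inspect which backend is active.

```python
import importlib
from typing import Any, Optional


def save_profile_if_supported(
    output_path: str, jax_module: Optional[Any] = None
) -> bool:
    """Call jax.profiler.save_device_memory_profile only if it is present."""
    if jax_module is None:
        try:
            jax_module = importlib.import_module("jax")
        except ImportError:
            return False  # JAX is not installed: behave as a no-op

    # getattr with a default never raises, even when an attribute is missing.
    profiler = getattr(jax_module, "profiler", None)
    save = getattr(profiler, "save_device_memory_profile", None)
    if not callable(save):
        return False  # older JAX without the profiler entry point

    save(output_path)
    return True
```

Accepting the module as a parameter keeps the guard testable without a GPU or TPU.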

save_device_memory_profile_to_dir(output_dir=None)[source]

Save a JAX device memory profile to a directory with an auto-generated filename.

Parameters:

output_dir (str | None)

Return type:

str | None

class stormlog.jax.tracker.MemoryWatchdog(max_memory_mb=8000, cleanup_threshold_mb=6000, check_interval=5.0, device_index=0)[source]

Bases: object

Automatic memory management and cleanup for JAX workloads.

Parameters:
  • max_memory_mb (float)

  • cleanup_threshold_mb (float)

  • check_interval (float)

  • device_index (int)

add_cleanup_callback(callback)[source]

Add cleanup callback function.

Parameters:

callback (Callable[[], None])

Return type:

None

start()[source]

Start memory watchdog.

Return type:

None

stop()[source]

Stop memory watchdog.

Return type:

None

force_cleanup(aggressive=False)[source]

Force immediate memory cleanup.

Parameters:

aggressive (bool) – When True, also delete all live JAX arrays reachable via jax.live_arrays() (if available) before running garbage collection. Use with caution: this can invalidate arrays still referenced by user code.

Return type:

None

stormlog.jax.tracker.JAXMemoryTracker

alias of MemoryTracker