stormlog.jax.analyzer

JAX Memory Analysis.

Advanced analysis tools for JAX memory profiling data. Provides feature parity with the PyTorch and TensorFlow analyzer modules while adapting recommendations and heuristics to JAX-specific idioms (XLA compilation, jax.jit, device memory pools, etc.).

Classes

MemoryAnalyzer([sensitivity, ...])

Advanced analyzer for JAX memory profiling data.

class stormlog.jax.analyzer.MemoryAnalyzer(sensitivity=0.05, collective_sensitivity='medium', collective_threshold_overrides=None)

Bases: object

Advanced analyzer for JAX memory profiling data.

Mirrors the public interface of the PyTorch and TensorFlow MemoryAnalyzer classes, adapting heuristics and recommendations for JAX’s XLA runtime and device memory model.

Parameters:
  • sensitivity (float)

  • collective_sensitivity (str)

  • collective_threshold_overrides (Optional[Mapping[str, Any]])

detect_memory_leaks(results)

Detect potential memory leaks in JAX telemetry.

Uses linear regression over the memory-usage series to detect sustained upward drift.

Parameters:

results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).

Returns:

List of leak-detection dicts with type, severity, description, and slope keys.

Return type:

List[Dict[str, Any]]
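The regression idea can be illustrated with a minimal, standard-library sketch. It is not the library's implementation: the helper names, the `slope_threshold` default, the severity rule, and the `"sustained_growth"` type label are all assumptions made for illustration. Only the 10-sample minimum and the four result keys come from the description above.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Sequence


def least_squares_slope(values: Sequence[float]) -> float:
    """Slope of the ordinary least-squares line through (index, value) points."""
    n = len(values)
    mean_x = (n - 1) / 2
    mean_y = sum(values) / n
    cov = sum((i - mean_x) * (v - mean_y) for i, v in enumerate(values))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


def sketch_detect_leaks(results: Any, slope_threshold: float = 0.1) -> List[Dict[str, Any]]:
    """Hypothetical leak check: flag a series whose fitted slope exceeds a threshold."""
    usage = list(getattr(results, "memory_usage", []))
    if len(usage) < 10:  # the documented minimum number of samples
        return []
    slope = least_squares_slope(usage)
    if slope <= slope_threshold:
        return []
    return [{
        "type": "sustained_growth",  # assumed label
        "severity": "high" if slope > 10 * slope_threshold else "medium",  # assumed rule
        "description": f"Memory grows by about {slope:.3f} units per sample",
        "slope": slope,
    }]


# A steadily growing series is flagged; a flat one is not.
leaky = SimpleNamespace(memory_usage=[100.0 + 2.0 * i for i in range(20)])
flat = SimpleNamespace(memory_usage=[100.0] * 20)
leak_findings = sketch_detect_leaks(leaky)
flat_findings = sketch_detect_leaks(flat)
```

Fitting a line over the whole series, rather than comparing the first and last samples, keeps a single transient spike from being reported as a leak.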

detect_patterns(results)

Detect allocation patterns in JAX telemetry.

Performs simplified autocorrelation analysis to identify periodic memory usage behaviour (e.g. per-step allocations in a training loop).

Parameters:

results (Any) – Object with a memory_usage attribute (numeric sequence of at least 10 samples).

Returns:

List of detected pattern dicts.

Return type:

List[Dict[str, Any]]
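A rough sketch of how autocorrelation can expose a periodic allocation pattern is shown below. The function names, the lag search range, and the `min_corr` cut-off are assumptions; the library's own simplified analysis may differ.

```python
from typing import Optional, Sequence


def autocorrelation(values: Sequence[float], lag: int) -> float:
    """Normalised autocorrelation of a series at a given lag."""
    n = len(values)
    mean = sum(values) / n
    centered = [v - mean for v in values]
    denom = sum(c * c for c in centered)
    if denom == 0:  # constant series: no structure to correlate
        return 0.0
    num = sum(centered[i] * centered[i + lag] for i in range(n - lag))
    return num / denom


def dominant_period(values: Sequence[float], min_corr: float = 0.5) -> Optional[int]:
    """Return the lag with the strongest autocorrelation above min_corr, if any."""
    n = len(values)
    if n < 10:  # the documented minimum number of samples
        return None
    best_lag, best_corr = None, min_corr
    for lag in range(2, n // 2 + 1):
        corr = autocorrelation(values, lag)
        if corr > best_corr:
            best_lag, best_corr = lag, corr
    return best_lag


# A sawtooth that repeats every 4 samples, like a per-step allocate/free cycle.
sawtooth = [0.0, 1.0, 2.0, 3.0] * 6
period = dominant_period(sawtooth)
no_period = dominant_period([5.0] * 24)
```

For the sawtooth, the correlation peaks at lag 4 and decays at its multiples, so the first peak is reported as the period.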

analyze_fragmentation(profile_result)

Analyse memory fragmentation patterns.

Computes fragmentation as 1 - (used / reserved) across profiling snapshots.

Parameters:

profile_result (Any) – Profiling result with a snapshots attribute where each snapshot exposes device_memory_mb and device_memory_reserved_mb.

Returns:

Dictionary with fragmentation_score, trend, max_fragmentation, and min_fragmentation.

Return type:

Dict[str, float]
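The per-snapshot ratio can be sketched as follows. The snapshot attribute names and result keys come from the description above. The aggregation is assumed: this sketch uses the mean for `fragmentation_score`, the last-minus-first difference for `trend`, and skips snapshots with no reserved memory, none of which is confirmed by the source.

```python
from types import SimpleNamespace
from typing import Any, Dict


def sketch_fragmentation(profile_result: Any) -> Dict[str, float]:
    """Hypothetical aggregation of 1 - used/reserved over snapshots."""
    scores = []
    for snap in getattr(profile_result, "snapshots", []):
        reserved = snap.device_memory_reserved_mb
        if reserved > 0:  # skip snapshots with nothing reserved (assumed behaviour)
            scores.append(1.0 - snap.device_memory_mb / reserved)
    if not scores:
        return {"fragmentation_score": 0.0, "trend": 0.0,
                "max_fragmentation": 0.0, "min_fragmentation": 0.0}
    return {
        "fragmentation_score": sum(scores) / len(scores),  # assumed: mean
        "trend": scores[-1] - scores[0],                   # assumed: last minus first
        "max_fragmentation": max(scores),
        "min_fragmentation": min(scores),
    }


profile = SimpleNamespace(snapshots=[
    SimpleNamespace(device_memory_mb=80.0, device_memory_reserved_mb=100.0),
    SimpleNamespace(device_memory_mb=50.0, device_memory_reserved_mb=100.0),
])
fragmentation = sketch_fragmentation(profile)
```

Here the first snapshot uses 80% of the reserved pool (fragmentation 0.2) and the second uses 50% (fragmentation 0.5), so the sketch reports a rising trend.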

analyze_efficiency(profile_result)

Analyse memory usage efficiency.

Returns a score on a 0.0–1.0 scale (1.0 = excellent). The score starts at 1.0 and is reduced by penalties for high peak memory, high growth rate, fragmentation, and detected leaks.

Parameters:

profile_result (Any) – Profiling result with peak_memory_mb and optionally memory_growth_rate, snapshots, and memory_usage.

Returns:

Efficiency score in [0.0, 1.0].

Return type:

float
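The penalty scheme might look roughly like the sketch below. Every threshold and weight here (the `budget_mb` default, the 0.9 peak ratio, the penalty sizes) is invented for illustration and is not taken from the library. Only the start-at-1.0, subtract-penalties, stay-in-[0.0, 1.0] shape follows the description above.

```python
def sketch_efficiency(
    peak_memory_mb: float,
    growth_rate: float = 0.0,
    fragmentation: float = 0.0,
    leak_count: int = 0,
    *,
    budget_mb: float = 16000.0,  # hypothetical device memory budget
) -> float:
    """Hypothetical penalty-based efficiency score in [0.0, 1.0]."""
    score = 1.0
    if peak_memory_mb > 0.9 * budget_mb:  # high peak memory
        score -= 0.3
    if growth_rate > 1.0:                 # high growth rate
        score -= 0.2
    score -= 0.3 * fragmentation          # fragmentation expected in [0, 1]
    score -= 0.1 * leak_count             # one penalty per detected leak
    return max(0.0, min(1.0, score))      # clamp to the documented range


clean_score = sketch_efficiency(1000.0)
worst_score = sketch_efficiency(15000.0, growth_rate=2.0, fragmentation=1.0, leak_count=5)
```

Clamping at the end keeps the score within the documented range even when several penalties stack.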

correlate_with_performance(profile_result)

Correlate memory usage with performance metrics.

Analyses per-function efficiency based on memory consumption and execution duration.

Parameters:

profile_result (Any) – Profiling result with function_profiles mapping function names to dicts containing calls, total_memory_delta, and total_duration.

Returns:

Dictionary with memory_duration_correlation, function_efficiency, and recommendations.

Return type:

Dict[str, Any]
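One plausible reading of this analysis is sketched below, assuming that `memory_duration_correlation` is a Pearson coefficient across functions and that `function_efficiency` holds per-call averages. Neither assumption is confirmed by the source. The input keys (`calls`, `total_memory_delta`, `total_duration`) and the three result keys are the documented ones.

```python
import math
from typing import Any, Dict, Mapping, Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient; 0.0 when it is undefined."""
    n = len(xs)
    if n < 2:
        return 0.0
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    if sx == 0 or sy == 0:
        return 0.0
    return cov / (sx * sy)


def sketch_correlate(function_profiles: Mapping[str, Mapping[str, float]]) -> Dict[str, Any]:
    """Hypothetical per-function summary built from the documented input keys."""
    names = list(function_profiles)
    memory = [function_profiles[n]["total_memory_delta"] for n in names]
    duration = [function_profiles[n]["total_duration"] for n in names]
    per_call = {
        n: {
            "memory_per_call": p["total_memory_delta"] / max(p["calls"], 1),
            "duration_per_call": p["total_duration"] / max(p["calls"], 1),
        }
        for n, p in function_profiles.items()
    }
    return {
        "memory_duration_correlation": pearson(memory, duration),
        "function_efficiency": per_call,
        "recommendations": [],  # left empty in this sketch
    }


profiles = {
    "forward": {"calls": 10, "total_memory_delta": 10.0, "total_duration": 1.0},
    "backward": {"calls": 10, "total_memory_delta": 20.0, "total_duration": 2.0},
    "update": {"calls": 5, "total_memory_delta": 30.0, "total_duration": 3.0},
}
correlation_report = sketch_correlate(profiles)
```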

score_optimization(profile_result, events=None)

Generate an overall optimisation score with recommendations.

Combines memory efficiency, fragmentation, and per-function performance scores into a single summary.

Parameters:
  • profile_result (Any) – JAX profiling result object.

  • events (List | None) – Optional telemetry event series for gap analysis. When provided, the result includes gap_analysis and collective_attribution sections.

Returns:

Dictionary with overall_score, categories, top_recommendations, and priority_actions.

Return type:

Dict[str, Any]
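The shape of the combined result, including the sections that appear only when `events` is supplied, can be sketched as follows. The unweighted mean, the 0.5 cut-off for priority actions, and the recommendation strings are assumptions. The result keys are the documented ones.

```python
from typing import Any, Dict, List, Mapping, Optional


def sketch_overall(categories: Mapping[str, float],
                   events: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Hypothetical combination of category scores into one summary dict."""
    ranked = sorted(categories, key=categories.get)  # weakest category first
    summary: Dict[str, Any] = {
        "overall_score": sum(categories.values()) / len(categories),  # assumed: plain mean
        "categories": dict(categories),
        "top_recommendations": [f"Improve {name}" for name in ranked[:2]],
        "priority_actions": [name for name in ranked if categories[name] < 0.5],
    }
    if events is not None:
        # The documented extra sections; left empty in this sketch.
        summary["gap_analysis"] = []
        summary["collective_attribution"] = []
    return summary


scores = {"memory_efficiency": 0.9, "fragmentation": 0.4, "performance": 0.8}
without_events = sketch_overall(scores)
with_events = sketch_overall(scores, events=[])
```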

analyze_memory_gaps(events, *, phase_resolver=None)

Classify allocator-vs-device hidden memory gaps over time.

Parameters:
  • events (List) – Chronologically ordered telemetry samples.

  • phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.

Returns:

Prioritised list of gap findings (severity desc, confidence desc). Returns an empty list when the gap_analysis sub-package is not available.

Return type:

List
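The ordering and the empty-list fallback can be sketched as follows. The source does not say how findings are represented or how the optional sub-package is loaded, so this sketch assumes dict-shaped findings, a made-up severity ranking, and a hypothetical backend module exposing an `analyze` function.

```python
import importlib
from typing import Any, Dict, List

# Hypothetical severity ranking; the library's actual levels are not documented here.
SEVERITY_RANK = {"critical": 3, "high": 2, "medium": 1, "low": 0}


def prioritise(findings: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Order findings by severity (descending), then confidence (descending)."""
    return sorted(
        findings,
        key=lambda f: (SEVERITY_RANK.get(f["severity"], -1), f["confidence"]),
        reverse=True,
    )


def sketch_analyze_gaps(events: List[Any], backend_module: str) -> List[Dict[str, Any]]:
    """Return prioritised findings, or [] when the optional backend is missing."""
    try:
        backend = importlib.import_module(backend_module)
    except ImportError:
        return []  # degrade gracefully instead of raising
    return prioritise(backend.analyze(events))  # `analyze` is a hypothetical entry point


ordered = prioritise([
    {"severity": "low", "confidence": 0.9},
    {"severity": "high", "confidence": 0.5},
    {"severity": "high", "confidence": 0.8},
])
missing_backend = sketch_analyze_gaps([], "stormlog_hypothetical_missing_backend")
```

Returning an empty list when the import fails lets callers treat "no findings" and "analysis unavailable" uniformly, at the cost of not distinguishing the two cases.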

analyze_collective_attribution(events, *, phase_resolver=None)

Attribute hidden-memory spikes to collective communication phases.

Parameters:
  • events (List) – Chronologically ordered telemetry samples.

  • phase_resolver (Any | None) – Optional PhaseReplayIndex for phase attribution.

Returns:

List of CollectiveAttributionResult objects. Returns an empty list when the collective_attribution sub-package is not available.

Return type:

List