"""OOM flight recorder helpers for bounded event capture and dump artifacts."""
from __future__ import annotations
import json
import logging
import os
import re
import shutil
import threading
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from .session import SessionSummary, session_summary_to_dict
from .utils import get_system_info
logger = logging.getLogger(__name__)

# Lowercase substrings that mark an exception message as an OOM condition.
# classify_oom_exception() lowercases the message and reports the FIRST entry
# it finds, so the order of this tuple is significant.
# NOTE(review): "out of memory" comes first and is a substring of the
# "cuda out of memory" / "hip out of memory" entries, so those two more
# specific entries can never be the reported match -- confirm whether the
# specific patterns were meant to take precedence.
_OOM_MESSAGE_PATTERNS = (
    "out of memory",
    "cuda out of memory",
    "hip out of memory",
    "resource exhausted",
    "failed to allocate",
    "allocation failed",
    "memoryerror",
    "std::bad_alloc",
)
@dataclass(frozen=True)
class OOMFlightRecorderConfig:
    """Immutable settings consumed by ``OOMFlightRecorder``.

    Attributes:
        enabled: Master switch. While ``False``, ``OOMFlightRecorder.dump``
            returns ``None`` without touching the filesystem.
        dump_dir: Directory under which dump bundles are created.
        buffer_size: Capacity of the in-memory event ring buffer. The
            recorder clamps the effective capacity to at least 1.
        max_dumps: Number of newest bundles kept after a dump. Values
            ``<= 0`` disable count-based pruning.
        max_total_mb: Cap on the combined size of all bundles, in MiB.
            Values ``<= 0`` disable size-based pruning.
    """

    enabled: bool = False
    dump_dir: str = "oom_dumps"
    buffer_size: int = 10_000
    max_dumps: int = 5
    max_total_mb: int = 256
@dataclass(frozen=True)
class OOMExceptionClassification:
    """Immutable outcome of checking whether an exception is an OOM.

    Attributes:
        is_oom: ``True`` when the exception was recognised as out-of-memory.
        reason: Identifier of the rule that matched;
            ``classify_oom_exception`` supplies ``None`` when nothing matched.
    """

    is_oom: bool
    reason: Optional[str]
def classify_oom_exception(exc: BaseException) -> OOMExceptionClassification:
    """Decide whether ``exc`` represents an out-of-memory failure.

    Rules are tried in order and the first hit wins:

    1. ``isinstance`` against ``torch.cuda.OutOfMemoryError`` when torch can
       be imported and exposes that type.
    2. The exception class itself: a name ending in
       ``ResourceExhaustedError``, or a class named ``OutOfMemoryError``
       whose module path contains ``torch``.
    3. A case-insensitive substring search of ``str(exc)`` against
       ``_OOM_MESSAGE_PATTERNS``.

    Args:
        exc: The exception to inspect.

    Returns:
        A classification whose ``reason`` names the rule that matched, or
        ``OOMExceptionClassification(False, None)`` when no rule matched.
    """
    try:
        import torch

        cuda_namespace = getattr(torch, "cuda", None)
        torch_oom_cls = getattr(cuda_namespace, "OutOfMemoryError", None)
        if torch_oom_cls is not None and isinstance(exc, torch_oom_cls):
            return OOMExceptionClassification(True, "torch.cuda.OutOfMemoryError")
    except Exception:
        # Deliberate best effort: torch may be missing or unusable in this
        # runtime, in which case the name/message rules below still apply.
        pass

    exc_cls = type(exc)
    lowered_name = exc_cls.__name__.lower()
    lowered_module = exc_cls.__module__.lower()

    if lowered_name.endswith("resourceexhaustederror"):
        return OOMExceptionClassification(True, "tensorflow.ResourceExhaustedError")
    if "torch" in lowered_module and lowered_name == "outofmemoryerror":
        return OOMExceptionClassification(True, "torch.cuda.OutOfMemoryError")

    lowered_message = str(exc).strip().lower()
    matched = next(
        (
            candidate
            for candidate in _OOM_MESSAGE_PATTERNS
            if candidate in lowered_message
        ),
        None,
    )
    if matched is None:
        return OOMExceptionClassification(False, None)
    return OOMExceptionClassification(True, f"message_pattern:{matched}")
class OOMFlightRecorder:
    """Bounded recorder that writes dump bundles on OOM.

    Events are kept in a fixed-capacity in-memory ring buffer. ``dump()``
    snapshots that buffer into a new bundle directory (``events.json``,
    ``metadata.json``, ``environment.json``, ``manifest.json``) and then
    applies the count/size retention limits from the configuration.
    """

    def __init__(self, config: OOMFlightRecorderConfig) -> None:
        """Create a recorder whose buffer holds at least one event.

        Args:
            config: Recorder settings; ``buffer_size`` is coerced with
                ``int()`` and clamped to a minimum of 1.
        """
        self.config = config
        bounded_size = max(1, int(config.buffer_size))
        self._events: deque[dict[str, Any]] = deque(maxlen=bounded_size)
        self._events_lock = threading.Lock()
        self._dump_sequence = 0
        self._sequence_lock = threading.Lock()

    def record_event(self, event: dict[str, Any]) -> None:
        """Append a shallow copy of ``event`` to the in-memory ring buffer."""
        with self._events_lock:
            self._events.append(dict(event))

    def snapshot_events(self) -> list[dict[str, Any]]:
        """Return shallow copies of the buffered events, oldest first."""
        with self._events_lock:
            return [dict(event) for event in self._events]

    def clear(self) -> None:
        """Discard buffered events for the next session/run."""
        with self._events_lock:
            self._events.clear()

    def dump(
        self,
        *,
        reason: str,
        exception: BaseException,
        context: Optional[str],
        backend: str,
        metadata: Optional[dict[str, Any]] = None,
        session_summary: SessionSummary | None = None,
    ) -> Optional[str]:
        """Write an OOM diagnostic bundle and enforce retention constraints.

        Args:
            reason: Why the dump was triggered; stored in the metadata and
                manifest files.
            exception: The exception being reported; its type name, module
                and message are stored in ``metadata.json``.
            context: Optional caller-supplied context string.
            backend: Backend label; stored as-is and, sanitised, used in the
                bundle directory name.
            metadata: Optional extra values stored under
                ``custom_metadata``.
            session_summary: Optional session whose id, status and dict form
                are embedded in the metadata and manifest.

        Returns:
            The bundle directory path as a string, or ``None`` when the
            recorder is disabled. Retention runs after the bundle is
            written, so a size cap smaller than the new bundle can remove
            that bundle again before this method returns.

        Raises:
            OSError: Propagated from directory creation or file writes
                (including ``FileExistsError`` if the chosen bundle
                directory appears before it is created).
        """
        if not self.config.enabled:
            return None

        root = Path(self.config.dump_dir)
        root.mkdir(parents=True, exist_ok=True)
        bundle_dir = self._next_bundle_dir(root=root, backend=backend)
        bundle_dir.mkdir(parents=True, exist_ok=False)

        events_payload = self.snapshot_events()

        # Session details appear in both metadata.json and manifest.json;
        # build them once so the two files cannot drift apart.
        if session_summary is None:
            session_fields: dict[str, Any] = {
                "session_id": None,
                "session_status": None,
                "session": None,
            }
        else:
            session_fields = {
                "session_id": session_summary.session_id,
                "session_status": session_summary.status,
                "session": session_summary_to_dict(session_summary),
            }

        metadata_payload = {
            "reason": reason,
            "exception_type": type(exception).__name__,
            "exception_module": type(exception).__module__,
            "exception_message": str(exception),
            "context": context,
            "backend": backend,
            "captured_event_count": len(events_payload),
            **session_fields,
            "custom_metadata": dict(metadata or {}),
        }
        environment_payload = {
            "pid": os.getpid(),
            "cwd": str(Path.cwd()),
            "system": get_system_info(),
        }
        self._write_json(bundle_dir / "events.json", events_payload)
        self._write_json(bundle_dir / "metadata.json", metadata_payload)
        self._write_json(bundle_dir / "environment.json", environment_payload)

        manifest_payload = {
            "schema_version": 2,
            "bundle_name": bundle_dir.name,
            "created_at_utc": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
            "reason": reason,
            "backend": backend,
            "event_count": len(events_payload),
            **session_fields,
            "files": [
                "manifest.json",
                "events.json",
                "metadata.json",
                "environment.json",
            ],
        }
        # The manifest is written last, after the files it lists.
        self._write_json(bundle_dir / "manifest.json", manifest_payload)

        self._prune_retention(root)
        return str(bundle_dir)

    def _next_bundle_dir(self, *, root: Path, backend: str) -> Path:
        """Return a not-yet-existing bundle path under ``root``.

        The name combines a UTC timestamp, the process id, the backend label
        (characters outside ``[a-zA-Z0-9_-]`` collapsed to ``_``; an empty
        label becomes ``unknown``) and a per-recorder sequence number.
        """
        timestamp_utc = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        safe_backend = re.sub(r"[^a-zA-Z0-9_-]+", "_", backend) or "unknown"
        while True:
            with self._sequence_lock:
                self._dump_sequence += 1
                # Capture the value while still holding the lock; reading the
                # attribute afterwards could observe another thread's bump
                # and hand two callers the same directory name.
                sequence = self._dump_sequence
            candidate = root / (
                f"oom_dump_{timestamp_utc}_{os.getpid()}_{safe_backend}_{sequence}"
            )
            if not candidate.exists():
                return candidate

    def _prune_retention(self, root: Path) -> None:
        """Delete the oldest bundles until count and size limits are met.

        Count pruning keeps the ``max_dumps`` newest bundles (skipped when
        ``max_dumps <= 0``). Size pruning then removes bundles oldest-first
        while the combined size exceeds ``max_total_mb`` MiB (skipped when
        the cap is ``<= 0``). Deletions are best effort.
        """
        # Retention consistently uses oldest->newest ordering.
        bundles = self._list_bundles_oldest_first(root)
        max_dumps = self.config.max_dumps
        if max_dumps > 0 and len(bundles) > max_dumps:
            for stale in bundles[:-max_dumps]:
                shutil.rmtree(stale, ignore_errors=True)

        max_total_bytes = self.config.max_total_mb * 1024 * 1024
        if max_total_bytes <= 0:
            return

        # Re-list so bundles that could not be removed above still count.
        # Each bundle is measured once and the measurement reused.
        sized = deque(
            (path, self._bundle_size_bytes(path))
            for path in self._list_bundles_oldest_first(root)
        )
        total_bytes = sum(size for _, size in sized)
        while sized and total_bytes > max_total_bytes:
            oldest, bundle_size = sized.popleft()
            shutil.rmtree(oldest, ignore_errors=True)
            total_bytes -= bundle_size

    @staticmethod
    def _write_json(path: Path, payload: Any) -> None:
        """Write ``payload`` to ``path`` as indented UTF-8 JSON.

        Values that ``json`` cannot encode natively are stringified via
        ``default=str``.
        """
        with path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, default=str)

    @staticmethod
    def _list_bundles_oldest_first(root: Path) -> list[Path]:
        """Return ``oom_dump_*`` directories in ``root`` sorted by mtime."""
        stamped: list[tuple[float, Path]] = []
        for path in root.iterdir():
            if not (path.is_dir() and path.name.startswith("oom_dump_")):
                continue
            try:
                stamped.append((path.stat().st_mtime, path))
            except OSError:
                # The bundle disappeared between listing and stat (for
                # example removed by another pruner); retention is best
                # effort, so skip it instead of failing the dump.
                continue
        stamped.sort(key=lambda item: item[0])
        return [path for _, path in stamped]

    @staticmethod
    def _bundle_size_bytes(bundle_dir: Path) -> int:
        """Return the summed size of regular files under ``bundle_dir``.

        Files or directories that vanish while being measured are ignored.
        """
        total = 0
        try:
            for file_path in bundle_dir.rglob("*"):
                if file_path.is_file():
                    total += file_path.stat().st_size
        except OSError:
            # Best effort: report what could be measured before the tree
            # changed underneath us.
            pass
        return total
# Names exported by a wildcard import of this module.
__all__ = [
    "OOMFlightRecorder",
    "OOMFlightRecorderConfig",
    "OOMExceptionClassification",
    "classify_oom_exception",
]