"""CPU-only memory profiler and tracker."""
from __future__ import annotations
import csv
import json
import logging
import os
import socket
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import psutil
from stormlog.phases import PhaseHandle, PhaseRecorder, PhaseToken
from stormlog.session import (
SESSION_STATUS_COMPLETED,
SESSION_STATUS_INCOMPLETE,
SESSION_STATUS_RUNNING,
SessionSummary,
create_session_summary,
finalize_session_summary,
now_ns,
update_session_summary,
)
from stormlog.telemetry import (
resolve_distributed_identity,
telemetry_event_from_record,
telemetry_event_to_dict,
)
from stormlog.telemetry_sink import AppendOnlyTelemetrySink, TelemetrySinkConfig
logger = logging.getLogger(__name__)

# Reuse the shared TrackingEvent from stormlog.tracker so CPU and GPU events
# share one type. Static type checkers always resolve it from there.
if TYPE_CHECKING:
    from stormlog.tracker import TrackingEvent
else:
    try:
        from stormlog.tracker import TrackingEvent
    except ImportError:
        # NOTE(review): presumably stormlog.tracker fails to import when its GPU
        # dependencies are missing — confirm. Keep the field order compatible
        # with the real TrackingEvent in case callers construct it positionally.
        @dataclass
        class TrackingEvent:
            """Fallback CPU tracking event used when GPU tracker imports are unavailable."""

            # Wall-clock time the event was recorded.
            timestamp: float
            # Event kind, e.g. "sample", "peak", "allocation", "start".
            event_type: str
            # Memory figures; the CPU tracker stores process RSS in both.
            memory_allocated: int
            memory_reserved: int
            # Change associated with the event (same units as the figures above).
            memory_change: int
            # Device index; the CPU tracker passes -1.
            device_id: int
            session_id: Optional[str] = None
            # Human-readable description of the event.
            context: Optional[str] = None
            # Distributed-run identity.
            job_id: Optional[str] = None
            rank: int = 0
            local_rank: int = 0
            world_size: int = 1
            # Free-form extra attributes.
            metadata: Optional[Dict[str, Any]] = None
            # Optional allocator/device figures; not populated by the CPU tracker.
            active_memory: Optional[int] = None
            inactive_memory: Optional[int] = None
            device_used: Optional[int] = None
            device_free: Optional[int] = None
            device_total: Optional[int] = None
            backend: str = "cpu"
@dataclass
class CPUMemorySnapshot:
    """Point-in-time CPU memory snapshot.

    Attributes:
        timestamp: Wall-clock time at which the reading was taken.
        rss: Resident set size of the process, in bytes (psutil ``memory_info``).
        vms: Virtual memory size of the process, in bytes (psutil ``memory_info``).
        cpu_percent: Process CPU utilisation as reported by psutil.
    """

    timestamp: float
    rss: int
    vms: int
    cpu_percent: float

    def to_dict(self) -> Dict[str, Any]:
        """Return the snapshot as a flat dictionary keyed by field name."""
        ordered_fields = ("timestamp", "rss", "vms", "cpu_percent")
        return {field_name: getattr(self, field_name) for field_name in ordered_fields}
@dataclass
class CPUProfileResult:
    """Outcome of profiling one CPU function call or ``with`` block.

    Attributes:
        name: Label of the profiled function/context.
        duration: Elapsed seconds for the profiled code.
        snapshot_before: Snapshot taken before the code ran.
        snapshot_after: Snapshot taken after the code ran.
        peak_rss: Highest RSS attributed to the run.
    """

    name: str
    duration: float
    snapshot_before: CPUMemorySnapshot
    snapshot_after: CPUMemorySnapshot
    peak_rss: int

    def memory_diff(self) -> int:
        """Return RSS after the run minus RSS before it."""
        rss_start = self.snapshot_before.rss
        rss_end = self.snapshot_after.rss
        return rss_end - rss_start

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the result, nesting both snapshots as dictionaries."""
        payload: Dict[str, Any] = {"name": self.name, "duration": self.duration}
        payload["memory_diff"] = self.memory_diff()
        payload["peak_rss"] = self.peak_rss
        payload["before"] = self.snapshot_before.to_dict()
        payload["after"] = self.snapshot_after.to_dict()
        return payload
class CPUMemoryProfiler:
    """Lightweight CPU memory profiler mirroring the GPU API.

    Measures the current process through ``psutil``: one-off profiling via
    :meth:`profile_function` / :meth:`profile_context`, plus optional
    background sampling into ``self.snapshots`` via :meth:`start_monitoring`.
    """

    def __init__(self) -> None:
        self.process = psutil.Process()
        self.snapshots: List[CPUMemorySnapshot] = []
        self.results: List[CPUProfileResult] = []
        self._monitoring = False
        self._monitor_thread: Optional[threading.Thread] = None
        # Stop signal for the running monitor thread. A fresh Event is created
        # on every start so a thread that outlived stop_monitoring()'s join
        # timeout can never be revived by a later restart.
        self._monitor_stop = threading.Event()
        self._monitor_interval = 0.1
        self._baseline_snapshot = self._take_snapshot()

    def _take_snapshot(self) -> CPUMemorySnapshot:
        """Read RSS, VMS and CPU percent for this process in one psutil pass."""
        with self.process.oneshot():
            mem = self.process.memory_info()
            # Non-blocking form; see psutil docs for interval=None semantics.
            cpu_pct = self.process.cpu_percent(interval=None)
        return CPUMemorySnapshot(
            timestamp=time.time(),
            rss=mem.rss,
            vms=mem.vms,
            cpu_percent=cpu_pct,
        )

    def start_monitoring(self, interval: float = 0.1) -> None:
        """Start appending a snapshot to ``self.snapshots`` every *interval* seconds.

        No-op if monitoring is already active.

        Raises:
            ValueError: If *interval* is not positive.
        """
        if self._monitoring:
            return
        if interval <= 0:
            raise ValueError("interval must be > 0")
        self._monitoring = True
        self._monitor_interval = interval
        stop_event = threading.Event()
        self._monitor_stop = stop_event
        self._monitor_thread = threading.Thread(
            target=self._monitor_loop, args=(stop_event,), daemon=True
        )
        self._monitor_thread.start()

    def _monitor_loop(self, stop_event: Optional[threading.Event] = None) -> None:
        """Take a snapshot, then wait one interval, until *stop_event* is set.

        Args:
            stop_event: Event owned by this thread; defaults to the profiler's
                current stop event.
        """
        if stop_event is None:
            stop_event = self._monitor_stop
        while not stop_event.is_set():
            self.snapshots.append(self._take_snapshot())
            # Unlike time.sleep, Event.wait returns as soon as stop is requested.
            stop_event.wait(self._monitor_interval)

    def stop_monitoring(self) -> None:
        """Signal the monitor thread to stop and wait up to 1 s for it to exit."""
        self._monitoring = False
        self._monitor_stop.set()
        if self._monitor_thread:
            self._monitor_thread.join(timeout=1.0)
            self._monitor_thread = None

    def _peak_rss_between(
        self, before: CPUMemorySnapshot, after: CPUMemorySnapshot
    ) -> int:
        """Return the highest RSS seen from *before* to *after* (inclusive).

        Besides the two endpoints, any monitoring snapshot whose timestamp falls
        inside the window is considered, so the peak can exceed both endpoints
        while :meth:`start_monitoring` is active.
        """
        endpoint_peak = max(before.rss, after.rss)
        return max(
            (
                snap.rss
                for snap in list(self.snapshots)  # copy: monitor thread may append
                if before.timestamp <= snap.timestamp <= after.timestamp
            ),
            default=endpoint_peak,
        ) if self.snapshots else endpoint_peak

    def profile_function(
        self, func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> CPUProfileResult:
        """Call ``func(*args, **kwargs)`` and record its memory/time profile.

        The function's own return value is discarded; the profile is appended
        to ``self.results`` and returned. If *func* raises, nothing is recorded
        and the exception propagates.
        """
        before = self._take_snapshot()
        started = time.perf_counter()
        func(*args, **kwargs)
        duration = time.perf_counter() - started
        after = self._take_snapshot()
        profile = CPUProfileResult(
            name=getattr(func, "__name__", "cpu_function"),
            duration=duration,
            snapshot_before=before,
            snapshot_after=after,
            peak_rss=max(
                max(before.rss, after.rss), self._peak_rss_between(before, after)
            ),
        )
        self.results.append(profile)
        return profile

    def profile_context(self, name: str = "context") -> Any:
        """Return a context manager that profiles the enclosed block as *name*.

        The result is appended to ``self.results`` on exit, including when the
        block raises; exceptions are not suppressed.
        """

        class _Context:
            def __init__(self, outer: CPUMemoryProfiler, label: str) -> None:
                self.outer = outer
                self.label = label

            def __enter__(self) -> Any:
                self.before = self.outer._take_snapshot()
                # Wall-clock start kept for callers; duration uses perf_counter.
                self.start = time.time()
                self._perf_start = time.perf_counter()
                return self

            def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                duration = time.perf_counter() - self._perf_start
                after = self.outer._take_snapshot()
                profile = CPUProfileResult(
                    name=self.label,
                    duration=duration,
                    snapshot_before=self.before,
                    snapshot_after=after,
                    peak_rss=max(
                        max(self.before.rss, after.rss),
                        self.outer._peak_rss_between(self.before, after),
                    ),
                )
                self.outer.results.append(profile)

        return _Context(self, name)

    def clear_results(self) -> None:
        """Drop snapshots and results and take a fresh baseline snapshot."""
        self.snapshots.clear()
        self.results.clear()
        self._baseline_snapshot = self._take_snapshot()

    def get_summary(self) -> Dict[str, Any]:
        """Summarise monitoring snapshots relative to the baseline.

        ``peak_memory_usage`` is the highest monitored RSS (the baseline RSS
        when nothing was monitored). ``memory_change_from_baseline`` is the
        latest monitored RSS minus the baseline RSS (0 when nothing was
        monitored).
        """
        baseline_rss = self._baseline_snapshot.rss
        rss_values = [snap.rss for snap in list(self.snapshots)]
        if rss_values:
            peak = max(rss_values)
            change = rss_values[-1] - baseline_rss
        else:
            peak = baseline_rss
            change = 0
        return {
            "mode": "cpu",
            "snapshots_collected": len(rss_values),
            "peak_memory_usage": peak,
            "memory_change_from_baseline": change,
            "baseline_rss": baseline_rss,
        }
class CPUMemoryTracker:
    """CPU tracker offering a superset of the GPU tracker interface.

    A background thread samples the current process's RSS (via ``psutil``)
    every ``sampling_interval`` seconds and records ``TrackingEvent`` objects
    (``sample``, ``allocation``, ``deallocation``, ``peak``, plus start/stop
    and phase events) into a bounded history of ``max_events`` entries.
    Events are optionally mirrored to an append-only telemetry sink; a sink
    failure disables the sink and marks the session incomplete instead of
    interrupting tracking.
    """

    def __init__(
        self,
        sampling_interval: float = 0.5,
        max_events: int = 10_000,
        enable_alerts: bool = True,
        job_id: Optional[str] = None,
        rank: Optional[int] = None,
        local_rank: Optional[int] = None,
        world_size: Optional[int] = None,
        telemetry_sink_config: Optional[TelemetrySinkConfig] = None,
    ) -> None:
        """Configure the tracker; sampling starts only in :meth:`start_tracking`.

        Args:
            sampling_interval: Seconds between RSS samples; must be > 0.
            max_events: Capacity of the in-memory event history; must be >= 1.
                When full, the oldest event is discarded for each new one.
            enable_alerts: Stored on the instance; not consulted by this class.
            job_id, rank, local_rank, world_size: Forwarded, together with
                ``os.environ``, to ``resolve_distributed_identity``.
            telemetry_sink_config: When given, events are also appended to an
                ``AppendOnlyTelemetrySink`` built from this config.

        Raises:
            ValueError: If ``sampling_interval`` or ``max_events`` is not positive.
        """
        if sampling_interval <= 0:
            raise ValueError("sampling_interval must be > 0")
        if max_events <= 0:
            raise ValueError("max_events must be >= 1")
        self.process = psutil.Process()
        self.sampling_interval = sampling_interval
        self.max_events = max_events
        self.events: deque[TrackingEvent] = deque(maxlen=max_events)
        self._events_lock = threading.Lock()
        self.is_tracking = False
        # Replaced by a fresh Event on every start_tracking() call so a sampling
        # thread that outlived a previous stop_tracking() cannot be revived.
        self._stop_event = threading.Event()
        self._tracking_thread: Optional[threading.Thread] = None
        self.enable_alerts = enable_alerts
        self._telemetry_sink_config = telemetry_sink_config
        self._telemetry_sink = (
            AppendOnlyTelemetrySink(telemetry_sink_config)
            if telemetry_sink_config is not None
            else None
        )
        self.distributed_identity = resolve_distributed_identity(
            job_id=job_id,
            rank=rank,
            local_rank=local_rank,
            world_size=world_size,
            env=os.environ,
        )
        self.session_source = "stormlog.cpu_profiler"
        self._session_summary: Optional[SessionSummary] = None
        self._history_dropped_events = 0
        self._last_sink_diagnostics: Dict[str, int] = self._empty_sink_diagnostics()
        self._phase_state = PhaseRecorder()
        self.stats: Dict[str, Any] = {
            "tracking_start_time": None,
            "peak_memory": 0,
            "total_events": 0,
            "alert_count": 0,
        }

    @staticmethod
    def _empty_sink_diagnostics() -> Dict[str, int]:
        """Return zeroed sink diagnostics (used before/without a sink)."""
        return {
            "rollover_count": 0,
            "pruned_segment_count": 0,
            "pruned_bytes": 0,
            "final_retained_files": 0,
            "final_retained_bytes": 0,
        }

    def _current_rss(self) -> int:
        """Return this process's current resident set size in bytes."""
        with self.process.oneshot():
            return int(self.process.memory_info().rss)

    def _open_session(self) -> SessionSummary:
        """Return the current session summary, creating (and starting) it if needed.

        When a sink exposing ``start_session`` is attached, the summary it
        returns replaces the locally created one.
        """
        if self._session_summary is not None:
            return self._session_summary
        summary = create_session_summary(
            source=self.session_source,
            status=SESSION_STATUS_RUNNING,
            started_at_ns=now_ns(),
            host=socket.gethostname(),
            pid=os.getpid(),
            job_id=self.distributed_identity["job_id"],
            rank=self.distributed_identity["rank"],
            local_rank=self.distributed_identity["local_rank"],
            world_size=self.distributed_identity["world_size"],
        )
        self._session_summary = summary
        if self._telemetry_sink is not None and hasattr(
            self._telemetry_sink, "start_session"
        ):
            self._session_summary = self._telemetry_sink.start_session(summary)
        return self._session_summary

    def get_session_summary(self) -> Optional[SessionSummary]:
        """Return the current session summary, or ``None`` if none was opened."""
        return self._session_summary

    def _ensure_telemetry_sink(self) -> None:
        """Recreate the sink from its config if it was closed or disabled."""
        if self._telemetry_sink is None and self._telemetry_sink_config is not None:
            self._telemetry_sink = AppendOnlyTelemetrySink(self._telemetry_sink_config)

    def start_tracking(self) -> None:
        """Start background RSS sampling in a new session.

        Resets the event history, statistics and phase state, then emits a
        ``start`` event. No-op while already tracking.
        """
        if self.is_tracking:
            return
        self._session_summary = None
        self._phase_state.reset()
        self._ensure_telemetry_sink()
        # Fresh Event per run: a previous sampling thread that did not finish
        # within stop_tracking()'s join timeout keeps its own (already set)
        # event and still exits, instead of being revived by .clear().
        stop_event = threading.Event()
        self._stop_event = stop_event
        with self._events_lock:
            self.events.clear()
            self.stats["peak_memory"] = 0
            self.stats["total_events"] = 0
            self.stats["alert_count"] = 0
            self.stats["tracking_start_time"] = time.time()
            self._history_dropped_events = 0
        self._last_sink_diagnostics = self._empty_sink_diagnostics()
        self._open_session()
        self._tracking_thread = threading.Thread(
            target=self._tracking_loop, args=(stop_event,), daemon=True
        )
        self._tracking_thread.start()
        self.is_tracking = True
        self._add_event("start", 0, "CPU memory tracking started")

    def stop_tracking(self) -> None:
        """Stop sampling, emit a ``stop`` event, close the sink and finalise the session."""
        if not self.is_tracking:
            return
        self.is_tracking = False
        self._stop_event.set()
        if self._tracking_thread:
            self._tracking_thread.join(timeout=1.0)
        self._add_event("stop", 0, "CPU memory tracking stopped")
        self._close_telemetry_sink()
        if self._session_summary is not None:
            self._session_summary = finalize_session_summary(
                self._session_summary,
                ended_at_ns=now_ns(),
            )
        self._phase_state.reset()

    def _tracking_loop(self, stop_event: Optional[threading.Event] = None) -> None:
        """Sample RSS every ``sampling_interval`` seconds until *stop_event* is set.

        Each sample emits a ``sample`` event, plus ``peak`` on a new maximum
        and ``allocation``/``deallocation`` when RSS moved since the previous
        sample. The sink is flushed once when the loop ends.

        Args:
            stop_event: Event that ends this loop; defaults to ``self._stop_event``.
        """
        if stop_event is None:
            stop_event = self._stop_event
        last_rss = self._current_rss()
        while not stop_event.wait(self.sampling_interval):
            try:
                current_rss = self._current_rss()
            except Exception as exc:
                logger.debug("Error sampling RSS in tracking loop: %s", exc)
                continue
            change = current_rss - last_rss
            is_new_peak = False
            with self._events_lock:
                self.stats["total_events"] += 1
                if current_rss > self.stats["peak_memory"]:
                    self.stats["peak_memory"] = current_rss
                    is_new_peak = True
            if is_new_peak:
                self._add_event(
                    "peak",
                    change,
                    f"New CPU peak RSS: {self._format_bytes(current_rss)}",
                    rss=current_rss,
                )
            if change > 0:
                self._add_event(
                    "allocation",
                    change,
                    f"RSS increased by {self._format_bytes(change)}",
                    rss=current_rss,
                )
            elif change < 0:
                self._add_event(
                    "deallocation",
                    change,
                    f"RSS decreased by {self._format_bytes(abs(change))}",
                    rss=current_rss,
                )
            self._add_event(
                "sample",
                0,
                "Collected CPU telemetry sample.",
                rss=current_rss,
            )
            last_rss = current_rss
        self._flush_telemetry_sink()

    def _add_event(
        self,
        event_type: str,
        memory_change: int,
        context: str,
        metadata: Optional[Dict[str, Any]] = None,
        *,
        rss: int | None = None,
    ) -> None:
        """Record one event in the bounded history and mirror it to the sink.

        Args:
            event_type: Event kind, e.g. ``"sample"`` or ``"peak"``.
            memory_change: RSS delta associated with the event.
            context: Human-readable description.
            metadata: Extra attributes; copied before being stored.
            rss: RSS to record; sampled now when omitted.
        """
        if rss is None:
            rss = self._current_rss()
        event = TrackingEvent(
            timestamp=time.time(),
            event_type=event_type,
            # RSS fills both fields; no separate reserved figure is sampled on CPU.
            memory_allocated=rss,
            memory_reserved=rss,
            memory_change=memory_change,
            # NOTE(review): -1 presumably means "no GPU device" — confirm.
            device_id=-1,
            session_id=self._open_session().session_id,
            context=context,
            job_id=self.distributed_identity["job_id"],
            rank=self.distributed_identity["rank"],
            local_rank=self.distributed_identity["local_rank"],
            world_size=self.distributed_identity["world_size"],
            metadata=dict(metadata or {}),
        )
        with self._events_lock:
            # The deque evicts the oldest entry when full; count that loss.
            if len(self.events) == self.max_events:
                self._history_dropped_events += 1
            self.events.append(event)
        self._append_to_telemetry_sink(event)

    def enter_phase(
        self, name: str, *, metadata: Optional[Dict[str, Any]] = None
    ) -> PhaseHandle:
        """Enter one structured CPU tracking phase.

        Emits the phase-enter boundary event and returns a handle whose
        ``close()`` emits the matching exit event.

        Raises:
            RuntimeError: If tracking is not active.
        """
        if not self.is_tracking:
            raise RuntimeError("Tracking must be active before entering a phase.")
        session = self._open_session()
        token, boundary = self._phase_state.enter(
            session_id=session.session_id,
            rank=self.distributed_identity["rank"],
            name=name,
            attrs=metadata,
        )
        self._add_event(
            boundary.event_type,
            0,
            boundary.context,
            metadata=boundary.metadata,
        )
        return PhaseHandle(
            scope_id=boundary.scope_id,
            name=name,
            path=boundary.path,
            close_callback=lambda: self._emit_phase_exit(token),
        )

    @contextmanager
    def phase(self, name: str, *, metadata: Optional[Dict[str, Any]] = None) -> Any:
        """Context manager that emits structured CPU phase telemetry."""
        handle = self.enter_phase(name, metadata=metadata)
        try:
            yield handle
        finally:
            handle.close()

    def _emit_phase_exit(self, token: PhaseToken) -> None:
        """Emit the phase-exit boundary event for *token*."""
        boundary = self._phase_state.exit(token)
        self._add_event(
            boundary.event_type,
            0,
            boundary.context,
            metadata=boundary.metadata,
        )

    def _telemetry_record_from_event(self, event: TrackingEvent) -> Dict[str, Any]:
        """Convert *event* into a telemetry record dictionary."""
        sampling_interval_ms = int(round(self.sampling_interval * 1000))
        record = {
            "session_id": event.session_id or self._open_session().session_id,
            "timestamp": event.timestamp,
            "event_type": event.event_type,
            "memory_allocated": event.memory_allocated,
            "memory_reserved": event.memory_reserved,
            "memory_change": event.memory_change,
            "device_id": event.device_id,
            "context": event.context,
            "job_id": event.job_id,
            "rank": event.rank,
            "local_rank": event.local_rank,
            "world_size": event.world_size,
            "metadata": dict(event.metadata or {}),
            "collector": "stormlog.cpu_tracker",
            "sampling_interval_ms": sampling_interval_ms,
            "pid": os.getpid(),
            "host": socket.gethostname(),
        }
        return telemetry_event_to_dict(
            telemetry_event_from_record(
                record,
                default_collector="stormlog.cpu_tracker",
                default_sampling_interval_ms=sampling_interval_ms,
                default_session_id=event.session_id,
            )
        )

    def _append_to_telemetry_sink(self, event: TrackingEvent) -> None:
        """Append *event* to the sink; disable the sink on any failure."""
        # Read the attribute once: another thread may disable the sink concurrently.
        sink = self._telemetry_sink
        if sink is None:
            return
        try:
            sink.append(self._telemetry_record_from_event(event))
            self._last_sink_diagnostics = sink.get_diagnostics()
        except Exception as exc:
            self._disable_telemetry_sink("append", exc)

    def _flush_telemetry_sink(self, *, force: bool = False) -> None:
        """Flush the sink; disable it on any failure."""
        sink = self._telemetry_sink
        if sink is None:
            return
        try:
            sink.flush(force=force)
            self._last_sink_diagnostics = sink.get_diagnostics()
        except Exception as exc:
            self._disable_telemetry_sink("flush", exc)

    def _close_telemetry_sink(self) -> None:
        """Close the sink as completed; on failure, fall back to disabling it."""
        sink = self._telemetry_sink
        if sink is None:
            return
        try:
            self._close_sink_with_status(sink, SESSION_STATUS_COMPLETED)
            self._last_sink_diagnostics = sink.get_diagnostics()
        except Exception as exc:
            self._disable_telemetry_sink("close", exc)
        else:
            self._telemetry_sink = None

    def _disable_telemetry_sink(self, operation: str, exc: Exception) -> None:
        """Detach the sink after a failed *operation* and mark the session incomplete.

        Closing the detached sink is best-effort; a second failure is only
        logged at debug level.
        """
        sink = self._telemetry_sink
        if sink is None:
            return
        self._telemetry_sink = None
        logger.warning(
            "Disabling CPU telemetry sink after %s failure: %s",
            operation,
            exc,
        )
        if self._session_summary is not None:
            self._session_summary = update_session_summary(
                self._session_summary,
                status=SESSION_STATUS_INCOMPLETE,
                ended_at_ns=now_ns(),
            )
        try:
            self._close_sink_with_status(sink, SESSION_STATUS_INCOMPLETE)
            if hasattr(sink, "get_diagnostics"):
                self._last_sink_diagnostics = sink.get_diagnostics()
        except Exception as close_exc:
            logger.debug(
                "CPU telemetry sink close failed after %s error: %s",
                operation,
                close_exc,
            )

    @staticmethod
    def _close_sink_with_status(sink: Any, status: str) -> None:
        """Close *sink* with *status*, falling back to ``close()`` if unsupported."""
        try:
            sink.close(session_status=status)
        except TypeError:
            sink.close()

    def get_events(
        self,
        event_type: Optional[str] = None,
        last_n: Optional[int] = None,
        since: Optional[float] = None,
    ) -> List[TrackingEvent]:
        """
        Get tracking events with optional filtering.

        Args:
            event_type: Filter by event type
            last_n: Get last N events (0 or negative yields an empty list)
            since: Get events since timestamp

        Returns:
            List of filtered events
        """
        with self._events_lock:
            events: List[TrackingEvent] = list(self.events)
        # Filter by type
        if event_type:
            events = [e for e in events if e.event_type == event_type]
        # Filter by time
        if since is not None:
            events = [e for e in events if e.timestamp >= since]
        # Limit results; events[-0:] would return everything, so guard <= 0.
        if last_n is not None:
            events = events[-last_n:] if last_n > 0 else []
        return events

    def get_statistics(self) -> Dict[str, Any]:
        """Return tracking statistics, history-window figures and sink diagnostics.

        ``total_events`` counts every event recorded since the last start/clear,
        i.e. retained events plus those evicted from the bounded history.
        """
        rss = self._current_rss()
        with self._events_lock:
            retained_events = len(self.events)
            dropped_events = self._history_dropped_events
            peak_memory = self.stats["peak_memory"]
            tracking_start_time = self.stats.get("tracking_start_time")
        duration = 0.0
        if isinstance(tracking_start_time, (int, float)):
            duration = time.time() - float(tracking_start_time)
        summary = self._session_summary
        return {
            "mode": "cpu",
            "total_events": retained_events + dropped_events,
            "peak_memory": peak_memory,
            "current_memory_allocated": rss,
            "tracking_duration_seconds": duration,
            "history_window_limit_events": self.max_events,
            "history_retained_events": retained_events,
            "history_dropped_events": dropped_events,
            **self._last_sink_diagnostics,
            "session_id": summary.session_id if summary is not None else None,
            "session_status": summary.status if summary is not None else None,
        }

    def get_memory_timeline(self, interval: float = 1.0) -> Dict[str, List[float]]:
        """Return per-event timestamps and recorded memory values.

        ``interval`` is accepted for interface compatibility but is not used:
        every retained event contributes one point. ``reserved`` holds the same
        values as ``allocated`` (a separate list, so mutating one is safe).
        """
        with self._events_lock:
            events_snapshot = list(self.events)
        if not events_snapshot:
            return {"timestamps": [], "allocated": [], "reserved": []}
        timestamps = [event.timestamp for event in events_snapshot]
        allocated = [float(event.memory_allocated) for event in events_snapshot]
        return {
            "timestamps": timestamps,
            "allocated": allocated,
            "reserved": list(allocated),
        }

    def clear_events(self) -> None:
        """Drop retained events and reset peak/total counters and the drop count."""
        with self._events_lock:
            self.events.clear()
            self.stats["peak_memory"] = 0
            self.stats["total_events"] = 0
            self._history_dropped_events = 0

    def export_events(self, filename: str, format: str = "csv") -> None:
        """Write retained events, as telemetry records, to *filename*.

        Nothing is written when there are no retained events.

        Args:
            filename: Destination path (overwritten, UTF-8 encoded).
            format: ``"csv"`` or ``"json"``.

        Raises:
            ValueError: If *format* is unsupported, even when there are no events.
        """
        if format not in ("csv", "json"):
            raise ValueError(f"Unsupported format: {format}")
        with self._events_lock:
            events_snapshot = list(self.events)
        records = [
            self._telemetry_record_from_event(event) for event in events_snapshot
        ]
        if not records:
            return
        if format == "csv":
            # Union of keys in first-seen order so a record carrying an extra
            # field does not make DictWriter raise; missing cells are left blank.
            fieldnames = list(dict.fromkeys(key for record in records for key in record))
            with open(filename, "w", newline="", encoding="utf-8") as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(records)
        else:
            with open(filename, "w", encoding="utf-8") as jsonfile:
                json.dump(records, jsonfile, indent=2, default=str)

    def export_events_with_timestamp(self, directory: str, format: str) -> str:
        """Export events to a timestamped file in *directory* and return its path."""
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(directory, f"cpu_tracker_{stamp}.{format}")
        self.export_events(filename, format=format)
        return filename

    @staticmethod
    def _format_bytes(value: int) -> str:
        """Format a byte count with a binary (1024-based) unit, two decimals."""
        units = ["B", "KB", "MB", "GB", "TB"]
        size = float(value)
        unit_idx = 0
        # Scale by magnitude so negative deltas also get a sensible unit.
        while abs(size) >= 1024 and unit_idx < len(units) - 1:
            size /= 1024
            unit_idx += 1
        return f"{size:.2f} {units[unit_idx]}"