"""Backend-aware device memory collector abstractions."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
import torch
[docs]
@dataclass(frozen=True)
class DeviceMemorySample:
    """Immutable, backend-neutral snapshot of device memory counters.

    Every count is expressed in bytes. Fields typed ``Optional`` hold ``None``
    when the producing collector could not obtain (or does not support) that
    value for its backend.
    """

    # Bytes the backend reports as allocated (torch.cuda.memory_allocated /
    # torch.mps.current_allocated_memory in the collectors of this module).
    allocated_bytes: int
    # Bytes the backend reports as reserved / driver-allocated.
    reserved_bytes: int
    # Collectors here fill this with max(allocated_bytes, reserved_bytes).
    used_bytes: int
    # total_bytes - used_bytes clamped at 0, or None when the total is unknown.
    free_bytes: Optional[int]
    # Device capacity estimate; None when it could not be determined.
    total_bytes: Optional[int]
    # Allocator "active" bytes; None on backends that do not expose it (MPS).
    active_bytes: Optional[int]
    # Allocator inactive-split bytes; None on backends that do not expose it.
    inactive_bytes: Optional[int]
    # Index of the sampled device (collectors report 0 for MPS).
    device_id: int
[docs]
@dataclass(frozen=True)
class DeviceMemorySampleResult:
    """Device-memory sample plus diagnostics about partial/core collection failures.

    Attributes:
        sample: The collected sample, or ``None`` when the core counters could
            not be read (a core failure).
        partial_fields: Names of optional fields that failed to collect while
            the core sample still succeeded.
        errors: Mapping from failed field name (or ``"core_metrics"``) to the
            error message that caused the failure.
        core_error: Error message for a core failure, otherwise ``None``.
    """

    sample: Optional[DeviceMemorySample]
    partial_fields: tuple[str, ...] = ()
    # The dict is excluded from the generated __hash__: a frozen dataclass with
    # eq=True hashes every field by default, and hashing a dict raises
    # TypeError. It still participates in __eq__, so equal results keep equal
    # hashes (the hash uses a subset of the compared fields).
    errors: dict[str, str] = field(default_factory=dict, hash=False)
    core_error: Optional[str] = None

    @property
    def is_partial(self) -> bool:
        """Return True when a sample exists but some optional fields failed."""
        return self.sample is not None and bool(self.partial_fields)

    @property
    def is_core_failure(self) -> bool:
        """Return True when no sample could be produced at all."""
        return self.sample is None
[docs]
class DeviceMemoryCollector(ABC):
    """Abstract contract that every backend memory collector implements.

    Concrete collectors provide a backend name, an availability probe, a raw
    sampler and a capability description. ``sample_with_diagnostics`` wraps
    ``sample`` so that a failing read is reported rather than raised.
    """

    @abstractmethod
    def name(self) -> str:
        """Return the runtime backend identifier (cuda, rocm, mps)."""

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether this collector can take samples in this runtime."""

    @abstractmethod
    def sample(self) -> DeviceMemorySample:
        """Take one normalized memory sample."""

    def sample_with_diagnostics(self) -> DeviceMemorySampleResult:
        """Take a sample, converting any exception into a core-failure result.

        Returns:
            A result wrapping the sample on success, or a result with
            ``sample=None`` and the exception text under both
            ``errors["core_metrics"]`` and ``core_error`` on failure.
        """
        try:
            collected = self.sample()
            return DeviceMemorySampleResult(sample=collected)
        except Exception as failure:
            return DeviceMemorySampleResult(
                sample=None,
                errors={"core_metrics": str(failure)},
                core_error=str(failure),
            )

    @abstractmethod
    def capabilities(self) -> Dict[str, Any]:
        """Describe the backend's capability signals for telemetry metadata."""
def _is_mps_available() -> bool:
mps_backend = getattr(torch.backends, "mps", None)
if mps_backend is None:
return False
try:
return bool(mps_backend.is_available())
except Exception:
return False
def _is_rocm_runtime() -> bool:
hip_version = getattr(torch.version, "hip", None)
return bool(torch.cuda.is_available() and hip_version)
[docs]
def detect_torch_runtime_backend() -> str:
    """Return the active torch runtime backend: "cuda", "rocm", "mps" or "cpu".

    ROCm builds surface through ``torch.cuda``, so they are distinguished from
    CUDA only after ``torch.cuda`` has been found usable.
    """
    if not torch.cuda.is_available():
        return "mps" if _is_mps_available() else "cpu"
    if _is_rocm_runtime():
        return "rocm"
    return "cuda"
def _resolve_device(device: Union[str, int, torch.device, None]) -> torch.device:
if device is None:
backend = detect_torch_runtime_backend()
if backend in {"cuda", "rocm"}:
return torch.device(f"cuda:{torch.cuda.current_device()}")
if backend == "mps":
return torch.device("mps")
raise RuntimeError("No supported GPU backend is available")
if isinstance(device, int):
backend = detect_torch_runtime_backend()
if backend not in {"cuda", "rocm"}:
raise ValueError(
"Integer device IDs are only supported for CUDA/ROCm backends"
)
return torch.device(f"cuda:{device}")
if isinstance(device, str):
return torch.device(device)
return device
[docs]
class CudaDeviceCollector(DeviceMemoryCollector):
    """Collector for NVIDIA CUDA runtime memory counters.

    The target device is resolved once at construction. A device without an
    explicit index is sampled against ``torch.cuda.current_device()`` at
    sample time.
    """

    telemetry_collector = "stormlog.cuda_tracker"

    def __init__(self, device: Union[str, int, torch.device, None] = None) -> None:
        """Resolve and validate the device to sample.

        Raises:
            ValueError: If the resolved device is not of type ``cuda``.
        """
        self.device = _resolve_device(device)
        if self.device.type != "cuda":
            raise ValueError("CUDA collector requires a CUDA device")

    def name(self) -> str:
        """Return the backend identifier."""
        return "cuda"

    def is_available(self) -> bool:
        """Return True for a usable CUDA runtime that is not a ROCm build."""
        # ROCm builds also report torch.cuda.is_available(); exclude them here.
        return bool(torch.cuda.is_available() and not _is_rocm_runtime())

    def sample(self) -> DeviceMemorySample:
        """Return a sample, raising ``RuntimeError`` if the core read failed."""
        result = self.sample_with_diagnostics()
        if result.sample is None:
            raise RuntimeError(result.core_error or "CUDA sample collection failed")
        return result.sample

    def sample_with_diagnostics(self) -> DeviceMemorySampleResult:
        """Sample the device, recording failures instead of raising.

        A failure to resolve the device index or to read allocated/reserved
        bytes is a core failure (``sample=None``). Failures reading the device
        total or allocator stats only mark those fields as partial.
        """
        try:
            # Resolve the index inside the guarded section:
            # torch.cuda.current_device() raises when CUDA cannot be
            # initialised, and that must surface as a core failure rather than
            # escape the diagnostics API.
            device_index = (
                self.device.index
                if self.device.index is not None
                else torch.cuda.current_device()
            )
            allocated = int(torch.cuda.memory_allocated(self.device))
            reserved = int(torch.cuda.memory_reserved(self.device))
        except Exception as exc:
            message = str(exc)
            return DeviceMemorySampleResult(
                sample=None,
                errors={"core_metrics": message},
                core_error=message,
            )

        used = max(allocated, reserved)
        total: Optional[int] = None
        free: Optional[int] = None
        active: Optional[int] = None
        inactive: Optional[int] = None
        partial_fields: list[str] = []
        errors: dict[str, str] = {}

        try:
            total = int(torch.cuda.get_device_properties(self.device).total_memory)
            # Free is derived from this process's view (total - used), not
            # queried from the driver.
            free = max(total - used, 0)
        except Exception as exc:
            message = str(exc)
            partial_fields.extend(["device_total_bytes", "device_free_bytes"])
            errors["device_total_bytes"] = message
            errors["device_free_bytes"] = message

        try:
            stats = torch.cuda.memory_stats(self.device)
            active = int(stats.get("active_bytes.all.current", 0))
            inactive = int(stats.get("inactive_split_bytes.all.current", 0))
        except Exception as exc:
            message = str(exc)
            partial_fields.extend(
                ["allocator_active_bytes", "allocator_inactive_bytes"]
            )
            errors["allocator_active_bytes"] = message
            errors["allocator_inactive_bytes"] = message

        return DeviceMemorySampleResult(
            sample=DeviceMemorySample(
                allocated_bytes=allocated,
                reserved_bytes=reserved,
                used_bytes=used,
                free_bytes=free,
                total_bytes=total,
                active_bytes=active,
                inactive_bytes=inactive,
                device_id=device_index,
            ),
            # dict.fromkeys de-duplicates while preserving first-seen order.
            partial_fields=tuple(dict.fromkeys(partial_fields)),
            errors=errors,
        )

    def capabilities(self) -> Dict[str, Any]:
        """Describe the CUDA collector for telemetry metadata."""
        return {
            "backend": self.name(),
            "supports_device_total": True,
            "supports_device_free": True,
            "sampling_source": "torch.cuda.memory_allocated/reserved",
            "telemetry_collector": self.telemetry_collector,
        }
[docs]
class ROCmDeviceCollector(CudaDeviceCollector):
    """ROCm collector that reuses the CUDA sampling path.

    On ROCm builds, torch exposes HIP devices through the ``torch.cuda`` API.
    Only the identity, availability probe and capability metadata differ from
    the CUDA collector.
    """

    telemetry_collector = "stormlog.rocm_tracker"

    def name(self) -> str:
        """Return the backend identifier."""
        return "rocm"

    def is_available(self) -> bool:
        """Return True when torch is running on a HIP/ROCm build."""
        return _is_rocm_runtime()

    def capabilities(self) -> Dict[str, Any]:
        """Return the CUDA capabilities with ROCm-specific source metadata."""
        # The parent already reports backend=self.name(), which resolves to
        # "rocm" here; only the source descriptors need overriding.
        merged = super().capabilities()
        merged["sampling_source"] = "torch.cuda.memory_* (HIP runtime)"
        merged["telemetry_collector"] = self.telemetry_collector
        return merged
[docs]
class MPSDeviceCollector(DeviceMemoryCollector):
    """Collector for Apple Metal (MPS) runtime counters."""

    telemetry_collector = "stormlog.mps_tracker"

    def __init__(self, device: Union[str, int, torch.device, None] = None) -> None:
        """Resolve and validate the device to sample.

        Raises:
            ValueError: If the resolved device is not of type ``mps``.
        """
        resolved = _resolve_device(device)
        if resolved.type != "mps":
            raise ValueError("MPS collector requires an MPS device")
        self.device = resolved

    def name(self) -> str:
        """Return the backend identifier."""
        return "mps"

    def is_available(self) -> bool:
        """Return True when torch reports a usable MPS backend."""
        return _is_mps_available()

    def sample(self) -> DeviceMemorySample:
        """Return a sample, raising ``RuntimeError`` if the core read failed."""
        result = self.sample_with_diagnostics()
        if result.sample is None:
            raise RuntimeError(result.core_error or "MPS sample collection failed")
        return result.sample

    def sample_with_diagnostics(self) -> DeviceMemorySampleResult:
        """Sample MPS counters, recording failures instead of raising.

        A failure to import ``torch.mps`` or to read allocated/driver bytes is
        a core failure (``sample=None``). A failure reading the recommended
        maximum only marks the total/free fields as partial. Active/inactive
        allocator bytes are not available on MPS and are always ``None``.
        """
        try:
            # Import inside the guarded section so torch builds lacking the
            # torch.mps module report a core failure instead of raising.
            import torch.mps as torch_mps

            allocated = int(torch_mps.current_allocated_memory())
            reserved = int(torch_mps.driver_allocated_memory())
        except Exception as exc:
            message = str(exc)
            return DeviceMemorySampleResult(
                sample=None,
                errors={"core_metrics": message},
                core_error=message,
            )

        used = max(allocated, reserved)
        total: Optional[int] = None
        partial_fields: list[str] = []
        errors: dict[str, str] = {}

        if hasattr(torch_mps, "recommended_max_memory"):
            try:
                # MPS does not expose a strict physical-total API here; this is the
                # best runtime approximation currently available from torch.
                raw_total = int(torch_mps.recommended_max_memory())
                total = raw_total if raw_total > 0 else None
            except Exception as exc:
                message = str(exc)
                partial_fields.extend(["device_total_bytes", "device_free_bytes"])
                errors["device_total_bytes"] = message
                errors["device_free_bytes"] = message

        free = max(total - used, 0) if total is not None else None
        return DeviceMemorySampleResult(
            sample=DeviceMemorySample(
                allocated_bytes=allocated,
                reserved_bytes=reserved,
                used_bytes=used,
                free_bytes=free,
                total_bytes=total,
                active_bytes=None,
                inactive_bytes=None,
                device_id=0,
            ),
            # dict.fromkeys de-duplicates while preserving first-seen order.
            partial_fields=tuple(dict.fromkeys(partial_fields)),
            errors=errors,
        )

    def capabilities(self) -> Dict[str, Any]:
        """Describe the MPS collector for telemetry metadata.

        Total/free support depends on whether ``torch.mps`` exposes
        ``recommended_max_memory``. If ``torch.mps`` cannot be imported at all,
        total/free support is reported as ``False``.
        """
        try:
            import torch.mps as torch_mps
        except ImportError:
            supports_total = False
        else:
            supports_total = hasattr(torch_mps, "recommended_max_memory")
        return {
            "backend": self.name(),
            "supports_device_total": supports_total,
            "supports_device_free": supports_total,
            "sampling_source": "torch.mps.current_allocated_memory/driver_allocated_memory",
            "telemetry_collector": self.telemetry_collector,
        }
[docs]
def build_device_memory_collector(
    device: Union[str, int, torch.device, None] = None,
) -> DeviceMemoryCollector:
    """Return the collector matching ``device``'s backend.

    CUDA-typed devices get a ROCm collector on HIP builds and a CUDA collector
    otherwise. MPS devices get an MPS collector.

    Raises:
        ValueError: If the resolved device is neither CUDA nor MPS.
    """
    target = _resolve_device(device)
    kind = target.type
    if kind == "mps":
        return MPSDeviceCollector(target)
    if kind != "cuda":
        raise ValueError("Only CUDA/ROCm and MPS devices are supported for tracking")
    collector_cls = ROCmDeviceCollector if _is_rocm_runtime() else CudaDeviceCollector
    return collector_cls(target)
# Public names exported by ``from <module> import *``.
__all__ = [
    "DeviceMemoryCollector", "DeviceMemorySample", "DeviceMemorySampleResult",
    "CudaDeviceCollector", "ROCmDeviceCollector", "MPSDeviceCollector",
    "build_device_memory_collector", "detect_torch_runtime_backend",
]