Source code for stormlog.telemetry

"""Canonical telemetry event schema and legacy conversion helpers."""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterable, Literal, Mapping, Optional

from .session import (
    SESSION_STATUS_INCOMPLETE,
    SessionSummary,
    infer_session_summary_from_events,
    select_default_loaded_session,
    sort_session_summaries,
    stable_legacy_session_id,
)
from .telemetry_model import (
    ProjectedTelemetryRecord,
    project_telemetry_mapping,
    unique_projected_correlations,
    unique_projected_resources,
)
from .telemetry_sink import (
    read_telemetry_sink_manifest,
    resolve_telemetry_sink_segment_paths,
)

# Supported schema revisions; v3 is the latest.
SCHEMA_VERSION_V2: Literal[2] = 2
SCHEMA_VERSION_V3: Literal[3] = 3
SCHEMA_VERSION_LATEST: Literal[3] = SCHEMA_VERSION_V3

# Placeholder process identity values.
UNKNOWN_PID = -1
UNKNOWN_HOST = "unknown"

# Keys every v3 record must carry.
REQUIRED_V3_FIELDS = (
    "schema_version",
    "session_id",
    "timestamp_ns",
    "event_type",
    "collector",
    "sampling_interval_ms",
    "pid",
    "host",
    "device_id",
    "allocator_allocated_bytes",
    "allocator_reserved_bytes",
    "allocator_active_bytes",
    "allocator_inactive_bytes",
    "allocator_change_bytes",
    "device_used_bytes",
    "device_free_bytes",
    "device_total_bytes",
    "context",
    "metadata",
)
# Distributed-identity keys that a record may omit.
OPTIONAL_V3_FIELDS = ("job_id", "rank", "local_rank", "world_size")

# v2 requires the same keys as v3 minus the session identifier.
REQUIRED_V2_FIELDS = tuple(
    name for name in REQUIRED_V3_FIELDS if name != "session_id"
)
OPTIONAL_V2_FIELDS = OPTIONAL_V3_FIELDS

# Complete key sets per schema version.
KNOWN_V2_FIELD_SET = frozenset((*REQUIRED_V2_FIELDS, *OPTIONAL_V2_FIELDS))
KNOWN_V3_FIELD_SET = frozenset((*REQUIRED_V3_FIELDS, *OPTIONAL_V3_FIELDS))

# Metadata keys that duplicate first-class event fields.
_DISTRIBUTED_METADATA_KEYS = frozenset(OPTIONAL_V3_FIELDS)
_SESSION_METADATA_KEYS = frozenset(("session_id",))

# (rank, local_rank, world_size) environment-variable triples: the generic
# names first, then the OMPI_* and SLURM_* equivalents.
_RANK_ENV_GROUPS = (
    ("RANK", "LOCAL_RANK", "WORLD_SIZE"),
    (
        "OMPI_COMM_WORLD_RANK",
        "OMPI_COMM_WORLD_LOCAL_RANK",
        "OMPI_COMM_WORLD_SIZE",
    ),
    ("SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS"),
)
# Environment variables that may carry a job identifier.
_JOB_ID_ENV_KEYS = ("TORCHELASTIC_RUN_ID", "SLURM_JOB_ID")


@dataclass
class TelemetryEventV2:
    """Schema-version-2 telemetry event, kept for backward-compatible writes/tests.

    Has the same fields as the v3 event except that it has no ``session_id``.
    Construction serializes the instance and runs it through
    ``validate_telemetry_record``, so invalid field values surface
    immediately.
    """

    schema_version: Literal[2]
    timestamp_ns: int
    event_type: str
    collector: str
    sampling_interval_ms: int
    pid: int
    host: str
    device_id: int
    allocator_allocated_bytes: int
    allocator_reserved_bytes: int
    allocator_active_bytes: Optional[int]
    allocator_inactive_bytes: Optional[int]
    allocator_change_bytes: int
    device_used_bytes: int
    device_free_bytes: Optional[int]
    device_total_bytes: Optional[int]
    context: Optional[str]
    job_id: Optional[str] = None
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Validate through the dict form so the shared record validator applies.
        serialized = telemetry_event_to_dict(self)
        validate_telemetry_record(serialized)
@dataclass
class TelemetryEventV3:
    """Canonical (schema version 3) telemetry event used by exports and loaders.

    Construction serializes the instance and runs it through
    ``validate_telemetry_record``, so invalid field values surface
    immediately.
    """

    schema_version: Literal[3]
    session_id: str
    timestamp_ns: int
    event_type: str
    collector: str
    sampling_interval_ms: int
    pid: int
    host: str
    device_id: int
    allocator_allocated_bytes: int
    allocator_reserved_bytes: int
    allocator_active_bytes: Optional[int]
    allocator_inactive_bytes: Optional[int]
    allocator_change_bytes: int
    device_used_bytes: int
    device_free_bytes: Optional[int]
    device_total_bytes: Optional[int]
    context: Optional[str]
    job_id: Optional[str] = None
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Validate through the dict form so the shared record validator applies.
        serialized = telemetry_event_to_dict(self)
        validate_telemetry_record(serialized)


# The unversioned name always refers to the canonical v3 event.
TelemetryEvent = TelemetryEventV3
@dataclass
class LoadedTelemetrySession:
    """A single session's summary, its events, and load-time bookkeeping."""

    summary: SessionSummary
    events: list[TelemetryEvent]
    sources_loaded: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    def telemetry_records(self) -> list[ProjectedTelemetryRecord]:
        """Project this session's events into backend-neutral records."""
        projected = project_telemetry_events(self.events)
        return projected

    def resources(self) -> list[dict[str, Any]]:
        """List the distinct resources observed in this session."""
        projected = self.telemetry_records()
        return unique_projected_resources(projected)

    def correlations(self) -> list[dict[str, Any]]:
        """List the distinct correlation contexts seen in this session."""
        projected = self.telemetry_records()
        return unique_projected_correlations(projected)
def _is_int(value: Any) -> bool:
    """Return True for genuine integers; ``bool`` values do not count."""
    if isinstance(value, bool):
        return False
    return isinstance(value, int)


def _coerce_int(value: Any, field_name: str) -> int:
    """Return ``value`` as ``int`` or raise ``ValueError`` naming the field."""
    if not _is_int(value):
        raise ValueError(f"{field_name} must be an integer")
    return int(value)


def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
    """Like ``_coerce_int`` but lets ``None`` through unchanged."""
    return None if value is None else _coerce_int(value, field_name)


def _coerce_string(
    value: Any, field_name: str, *, allow_none: bool = False
) -> Optional[str]:
    """Validate that ``value`` is a string and return it unchanged.

    With ``allow_none`` set, ``None`` and blank strings are accepted;
    otherwise both raise ``ValueError``.
    """
    if value is None:
        if not allow_none:
            raise ValueError(f"{field_name} must be a non-empty string")
        return None
    if not isinstance(value, str):
        raise ValueError(f"{field_name} must be a string")
    is_blank = not value.strip()
    if is_blank and not allow_none:
        raise ValueError(f"{field_name} must be a non-empty string")
    return value


def _coerce_required_string(value: Any, field_name: str) -> str:
    """Return a non-blank string or raise ``ValueError``."""
    text = _coerce_string(value, field_name)
    if text is None:
        raise ValueError(f"{field_name} must be a non-empty string")
    return text


def _coerce_optional_non_empty_string(value: Any, field_name: str) -> Optional[str]:
    """Accept ``None``; anything else must be a non-blank string."""
    return None if value is None else _coerce_required_string(value, field_name)


def _coerce_metadata_dict(value: Any) -> dict[str, Any]:
    """Return a shallow copy of ``value``, which must be a ``dict``."""
    if isinstance(value, dict):
        return dict(value)
    raise ValueError("metadata must be an object")


def _extract_metadata(record: Mapping[str, Any]) -> dict[str, Any]:
    """Collect metadata from a legacy record.

    Starts from the ``metadata`` mapping (when present) and then folds in any
    flat ``metadata_<name>`` keys, which overwrite same-named entries.
    """
    prefix = "metadata_"
    nested = record.get("metadata")
    if nested is not None and not isinstance(nested, Mapping):
        raise ValueError("metadata must be an object when provided")
    collected: dict[str, Any] = {} if nested is None else dict(nested)
    for key, value in record.items():
        if isinstance(key, str) and key.startswith(prefix):
            collected[key[len(prefix):]] = value
    return collected


def _first_env_value(env: Mapping[str, str], keys: tuple[str, ...]) -> Optional[str]:
    """Return the first non-blank, stripped value among ``keys``, else ``None``."""
    for raw in (env.get(key) for key in keys):
        if raw is not None and (text := raw.strip()):
            return text
    return None


def _coerce_non_negative_int(value: Any, field_name: str) -> int:
    """Return ``value`` as an ``int`` that is at least zero."""
    number = _coerce_int(value, field_name)
    if number >= 0:
        return number
    raise ValueError(f"{field_name} must be >= 0")


def _coerce_positive_int(value: Any, field_name: str) -> int:
    """Return ``value`` as an ``int`` that is at least one."""
    number = _coerce_int(value, field_name)
    if number > 0:
        return number
    raise ValueError(f"{field_name} must be >= 1")


def _coerce_env_int(value: str, field_name: str) -> int:
    """Parse an environment-variable string as an integer."""
    text = value.strip()
    try:
        return int(text)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be an integer") from exc


def _infer_distributed_identity_from_env(
    env: Optional[Mapping[str, str]] = None,
) -> dict[str, Any]:
    """Derive job/rank identity from an environment mapping.

    Without an ``env`` every field is ``None``. Otherwise the first group in
    ``_RANK_ENV_GROUPS`` that supplies both a rank and a world size wins; a
    missing or blank local rank falls back to the rank value.
    """
    identity: dict[str, Any] = {
        "job_id": None,
        "rank": None,
        "local_rank": None,
        "world_size": None,
    }
    if env is None:
        return identity

    identity["job_id"] = _first_env_value(env, _JOB_ID_ENV_KEYS)
    for rank_key, local_rank_key, world_size_key in _RANK_ENV_GROUPS:
        rank_text = env.get(rank_key)
        world_size_text = env.get(world_size_key)
        # A group is usable only when both rank and world size are set.
        if rank_text is None or world_size_text is None:
            continue
        local_rank_text = env.get(local_rank_key)
        if local_rank_text is None or not local_rank_text.strip():
            local_rank_text = rank_text

        # Parse all three before range-checking so parse errors come first.
        parsed_rank = _coerce_env_int(rank_text, "rank")
        parsed_local_rank = _coerce_env_int(local_rank_text, "local_rank")
        parsed_world_size = _coerce_env_int(world_size_text, "world_size")
        identity["rank"] = _coerce_non_negative_int(parsed_rank, "rank")
        identity["local_rank"] = _coerce_non_negative_int(
            parsed_local_rank, "local_rank"
        )
        identity["world_size"] = _coerce_positive_int(parsed_world_size, "world_size")
        return identity
    return identity
def resolve_distributed_identity(
    *,
    job_id: Any = None,
    rank: Any = None,
    local_rank: Any = None,
    world_size: Any = None,
    metadata: Optional[Mapping[str, Any]] = None,
    env: Optional[Mapping[str, str]] = None,
) -> dict[str, Any]:
    """Normalize distributed identity fields from explicit, metadata, or env inputs.

    Each field is taken from the explicit argument when it is not ``None``,
    then from ``metadata``, then from ``env``. Remaining gaps default to
    ``world_size=1``, ``rank=0`` and ``local_rank=rank``.

    Args:
        job_id: Explicit job identifier, or ``None``.
        rank: Explicit global rank, or ``None``.
        local_rank: Explicit local rank, or ``None``.
        world_size: Explicit world size, or ``None``.
        metadata: Optional mapping consulted for the same four keys.
        env: Optional environment mapping consulted last.

    Returns:
        A dict with the keys ``job_id``, ``rank``, ``local_rank`` and
        ``world_size``.

    Raises:
        ValueError: If a value has the wrong type or range, or if ``rank`` or
            ``local_rank`` is not strictly below ``world_size``.
    """
    metadata_values = dict(metadata or {})
    raw_job_id = job_id if job_id is not None else metadata_values.get("job_id")
    raw_rank = rank if rank is not None else metadata_values.get("rank")
    raw_local_rank = (
        local_rank if local_rank is not None else metadata_values.get("local_rank")
    )
    raw_world_size = (
        world_size if world_size is not None else metadata_values.get("world_size")
    )

    needs_rank_env = (
        raw_rank is None or raw_local_rank is None or raw_world_size is None
    )
    if needs_rank_env:
        inferred = _infer_distributed_identity_from_env(env)
        if raw_rank is None:
            raw_rank = inferred["rank"]
        if raw_local_rank is None:
            raw_local_rank = inferred["local_rank"]
        if raw_world_size is None:
            raw_world_size = inferred["world_size"]
        if raw_job_id is None:
            raw_job_id = inferred["job_id"]
    elif raw_job_id is None and env is not None:
        # Rank fields are complete, so only the job id needs the environment.
        raw_job_id = _first_env_value(env, _JOB_ID_ENV_KEYS)

    # Single-process defaults for anything still unresolved.
    if raw_world_size is None:
        raw_world_size = 1
    if raw_rank is None:
        raw_rank = 0
    if raw_local_rank is None:
        raw_local_rank = raw_rank

    normalized_job_id = _coerce_optional_non_empty_string(raw_job_id, "job_id")
    normalized_rank = _coerce_non_negative_int(raw_rank, "rank")
    normalized_local_rank = _coerce_non_negative_int(raw_local_rank, "local_rank")
    normalized_world_size = _coerce_positive_int(raw_world_size, "world_size")

    # These two bounds also cover world_size == 1, where both must be 0.
    if normalized_rank >= normalized_world_size:
        raise ValueError("rank must be < world_size")
    if normalized_local_rank >= normalized_world_size:
        raise ValueError("local_rank must be < world_size")

    return {
        "job_id": normalized_job_id,
        "rank": normalized_rank,
        "local_rank": normalized_local_rank,
        "world_size": normalized_world_size,
    }
def _strip_distributed_identity_metadata(metadata: Mapping[str, Any]) -> dict[str, Any]:
    """Return ``metadata`` without the distributed-identity keys."""
    return {
        key: value
        for key, value in metadata.items()
        if key not in _DISTRIBUTED_METADATA_KEYS
    }


def _strip_session_metadata(metadata: Mapping[str, Any]) -> dict[str, Any]:
    """Return ``metadata`` without the session-identity keys."""
    return {
        key: value
        for key, value in metadata.items()
        if key not in _SESSION_METADATA_KEYS
    }


def _resolve_session_id(
    record: Mapping[str, Any],
    *,
    metadata: Mapping[str, Any] | None = None,
    default_session_id: str | None = None,
) -> str:
    """Pick a session id for ``record``.

    Preference order: a non-blank string ``session_id`` on the record (or in
    ``metadata`` when the record has none), then ``default_session_id``, then
    an id derived from the record's timestamp, host and pid.
    """
    raw_session_id = record.get("session_id")
    if raw_session_id is None and metadata is not None:
        raw_session_id = metadata.get("session_id")
    if isinstance(raw_session_id, str) and raw_session_id.strip():
        return raw_session_id
    if default_session_id is not None:
        return default_session_id
    timestamp_value = record.get("timestamp_ns", record.get("timestamp", "unknown"))
    host_value = record.get("host", UNKNOWN_HOST)
    pid_value = record.get("pid", UNKNOWN_PID)
    return stable_legacy_session_id(timestamp_value, host_value, pid_value)


def _legacy_timestamp_ns(record: Mapping[str, Any]) -> int:
    """Return the record timestamp in nanoseconds.

    Uses ``timestamp_ns`` when present, otherwise scales a numeric
    ``timestamp`` by 1e9.

    Raises:
        ValueError: If no usable timestamp exists, including non-finite or
            out-of-range ``timestamp`` values.
    """
    if "timestamp_ns" in record:
        return _coerce_int(record["timestamp_ns"], "timestamp_ns")
    timestamp = record.get("timestamp")
    if isinstance(timestamp, (int, float)) and not isinstance(timestamp, bool):
        try:
            return int(float(timestamp) * 1_000_000_000)
        except (OverflowError, ValueError) as exc:
            # inf/nan (accepted by json.loads) and huge ints land here;
            # report them as an invalid record, not as OverflowError.
            raise ValueError("Legacy record is missing a valid timestamp") from exc
    raise ValueError("Legacy record is missing a valid timestamp")


def _legacy_device_id(record: Mapping[str, Any]) -> int:
    """Infer a device index from a legacy record.

    Returns ``-1`` for CPU device strings and when nothing identifies a
    device; a bare ``memory_mb`` sample is attributed to device ``0``.
    """
    if "device_id" in record:
        return _coerce_int(record["device_id"], "device_id")
    device = record.get("device")
    if isinstance(device, str):
        lowered = device.lower()
        if "cpu" in lowered:
            return -1
        if ":" in device:
            # e.g. "cuda:1" -> 1
            tail = device.rsplit(":", 1)[-1]
            if tail.isdigit():
                return int(tail)
        if lowered.startswith("/gpu"):
            return 0
    if "memory_mb" in record:
        return 0
    return -1


def _legacy_allocator_allocated_bytes(record: Mapping[str, Any]) -> int:
    """Return allocated bytes from the first available legacy field.

    Lookup order: ``allocator_allocated_bytes``, ``memory_allocated``,
    ``memory_mb`` (converted from MiB), ``device_used_bytes``, else ``0``.

    Raises:
        ValueError: If the chosen field is not an integer, or ``memory_mb``
            is not a finite number.
    """
    if "allocator_allocated_bytes" in record:
        return _coerce_int(
            record["allocator_allocated_bytes"], "allocator_allocated_bytes"
        )
    if "memory_allocated" in record:
        return _coerce_int(record["memory_allocated"], "memory_allocated")
    memory_mb = record.get("memory_mb")
    if isinstance(memory_mb, (int, float)) and not isinstance(memory_mb, bool):
        try:
            return int(float(memory_mb) * (1024**2))
        except (OverflowError, ValueError) as exc:
            # Keep the module's ValueError contract for inf/nan inputs.
            raise ValueError("memory_mb must be a finite number") from exc
    if "device_used_bytes" in record:
        return _coerce_int(record["device_used_bytes"], "device_used_bytes")
    return 0


def _legacy_allocator_reserved_bytes(record: Mapping[str, Any], allocated: int) -> int:
    """Return reserved bytes, falling back to ``allocated`` when unspecified."""
    if "allocator_reserved_bytes" in record:
        return _coerce_int(
            record["allocator_reserved_bytes"], "allocator_reserved_bytes"
        )
    if "memory_reserved" in record:
        return _coerce_int(record["memory_reserved"], "memory_reserved")
    return allocated


def _legacy_allocator_change_bytes(record: Mapping[str, Any]) -> int:
    """Return the allocation delta, or ``0`` when the record has none."""
    if "allocator_change_bytes" in record:
        return _coerce_int(record["allocator_change_bytes"], "allocator_change_bytes")
    if "memory_change" in record:
        return _coerce_int(record["memory_change"], "memory_change")
    return 0


def _legacy_optional_counter(record: Mapping[str, Any], key: str) -> Optional[int]:
    """Return ``record[key]`` as an int, or ``None`` when absent or null."""
    value = record.get(key)
    if value is None:
        return None
    return _coerce_int(value, key)


def _legacy_total_memory_bytes(
    record: Mapping[str, Any], metadata: Mapping[str, Any]
) -> Optional[int]:
    """Return total device memory from the record, then from ``metadata``.

    An explicit ``None`` under the first matching key ends the search with
    ``None``.
    """
    if "device_total_bytes" in record:
        return _coerce_optional_int(
            record.get("device_total_bytes"), "device_total_bytes"
        )
    legacy_keys = ("total_memory", "device_total", "total_bytes")
    for source in (record, metadata):
        for key in legacy_keys:
            if key in source:
                value = source[key]
                if value is None:
                    return None
                return _coerce_int(value, key)
    return None


def _legacy_device_used_bytes(record: Mapping[str, Any], allocated: int) -> int:
    """Return used device bytes, falling back to ``allocated``."""
    if "device_used_bytes" in record:
        return _coerce_int(record["device_used_bytes"], "device_used_bytes")
    return allocated


def _legacy_device_free_bytes(
    record: Mapping[str, Any],
    used: int,
    total: Optional[int],
) -> Optional[int]:
    """Return free device bytes, derived as ``total - used`` (floored at 0)."""
    if "device_free_bytes" in record:
        return _coerce_optional_int(
            record.get("device_free_bytes"), "device_free_bytes"
        )
    if total is None:
        return None
    free = total - used
    return max(free, 0)


def _legacy_pid(record: Mapping[str, Any], metadata: Mapping[str, Any]) -> int:
    """Return the pid from the record, then metadata, else ``UNKNOWN_PID``."""
    if "pid" in record:
        return _coerce_int(record["pid"], "pid")
    if "pid" in metadata:
        return _coerce_int(metadata["pid"], "pid")
    return UNKNOWN_PID


def _legacy_host(record: Mapping[str, Any], metadata: Mapping[str, Any]) -> str:
    """Return the host from the record, then metadata, else ``UNKNOWN_HOST``."""
    if "host" in record:
        return _coerce_string(record["host"], "host") or UNKNOWN_HOST
    if "host" in metadata:
        return _coerce_string(metadata["host"], "host") or UNKNOWN_HOST
    return UNKNOWN_HOST


def _legacy_collector(
    record: Mapping[str, Any],
    default_collector: str,
    device_id: int,
    metadata: Mapping[str, Any],
) -> str:
    """Choose a collector name for a legacy record.

    An explicit non-blank ``collector`` wins; otherwise the name is guessed
    from the ``backend`` value or from which legacy memory fields are
    present, ending at ``default_collector``.
    """
    collector = record.get("collector")
    if isinstance(collector, str) and collector.strip():
        return collector
    backend_value = record.get("backend", metadata.get("backend"))
    if isinstance(backend_value, str):
        backend = backend_value.strip().lower()
        if backend == "mps":
            return "stormlog.mps_tracker"
        if backend == "rocm":
            return "stormlog.rocm_tracker"
        if backend == "cuda":
            return "stormlog.cuda_tracker"
        if backend == "cpu":
            return "stormlog.cpu_tracker"
    if "memory_mb" in record:
        return "stormlog.tensorflow.memory_tracker"
    if "memory_allocated" in record:
        return "stormlog.cpu_tracker" if device_id == -1 else "stormlog.cuda_tracker"
    return default_collector
def telemetry_event_to_dict(
    event: TelemetryEvent | TelemetryEventV2,
) -> dict[str, Any]:
    """Serialize a telemetry event to a plain dictionary.

    Key order is fixed: ``schema_version`` first, then ``session_id`` (v3
    events only), then the remaining fields. ``metadata`` is shallow-copied.
    """
    payload: dict[str, Any] = {"schema_version": event.schema_version}
    # v2 events have no session identifier; every other field is shared.
    if not isinstance(event, TelemetryEventV2):
        payload["session_id"] = event.session_id

    shared_fields = (
        "timestamp_ns",
        "event_type",
        "collector",
        "sampling_interval_ms",
        "pid",
        "host",
        "job_id",
        "rank",
        "local_rank",
        "world_size",
        "device_id",
        "allocator_allocated_bytes",
        "allocator_reserved_bytes",
        "allocator_active_bytes",
        "allocator_inactive_bytes",
        "allocator_change_bytes",
        "device_used_bytes",
        "device_free_bytes",
        "device_total_bytes",
        "context",
    )
    for name in shared_fields:
        payload[name] = getattr(event, name)
    payload["metadata"] = dict(event.metadata)
    return payload
def validate_telemetry_record(record: Mapping[str, Any]) -> None:
    """Validate a v2 or v3 telemetry record.

    Checks run in a fixed order: schema version, required and unknown keys,
    identity fields, allocator counters, device counters, context, metadata,
    and finally the consistency of the distributed identity.

    Raises:
        ValueError: if the record is invalid or partial.
    """
    schema_version = _coerce_int(record.get("schema_version"), "schema_version")
    layouts: dict[int, tuple[tuple[str, ...], frozenset[str], bool]] = {
        SCHEMA_VERSION_V3: (REQUIRED_V3_FIELDS, KNOWN_V3_FIELD_SET, True),
        SCHEMA_VERSION_V2: (REQUIRED_V2_FIELDS, KNOWN_V2_FIELD_SET, False),
    }
    layout = layouts.get(schema_version)
    if layout is None:
        raise ValueError(f"Unsupported schema_version: {schema_version}")
    required_fields, known_fields, require_session_id = layout

    missing = [name for name in required_fields if name not in record]
    if missing:
        raise ValueError(f"Missing required telemetry fields: {', '.join(missing)}")
    unknown = sorted(str(name) for name in record if name not in known_fields)
    if unknown:
        raise ValueError(f"Unknown telemetry fields: {', '.join(unknown)}")

    # --- identity and bookkeeping fields ---
    if require_session_id:
        _coerce_required_string(record["session_id"], "session_id")
    if _coerce_int(record["timestamp_ns"], "timestamp_ns") < 0:
        raise ValueError("timestamp_ns must be >= 0")
    for name in ("event_type", "collector"):
        _coerce_required_string(record[name], name)
    if _coerce_int(record["sampling_interval_ms"], "sampling_interval_ms") < 0:
        raise ValueError("sampling_interval_ms must be >= 0")
    # -1 is allowed as the "unknown pid" placeholder.
    if _coerce_int(record["pid"], "pid") < -1:
        raise ValueError("pid must be >= -1")
    _coerce_required_string(record["host"], "host")

    # --- optional distributed identity ---
    if "job_id" in record:
        _coerce_optional_non_empty_string(record["job_id"], "job_id")
    for name in ("rank", "local_rank"):
        if name in record:
            _coerce_non_negative_int(record[name], name)
    if "world_size" in record:
        _coerce_positive_int(record["world_size"], "world_size")
    _coerce_int(record["device_id"], "device_id")

    # --- allocator counters: type-check all, then range-check ---
    allocated = _coerce_int(
        record["allocator_allocated_bytes"], "allocator_allocated_bytes"
    )
    reserved = _coerce_int(
        record["allocator_reserved_bytes"], "allocator_reserved_bytes"
    )
    active = _coerce_optional_int(
        record["allocator_active_bytes"], "allocator_active_bytes"
    )
    inactive = _coerce_optional_int(
        record["allocator_inactive_bytes"], "allocator_inactive_bytes"
    )
    _coerce_int(record["allocator_change_bytes"], "allocator_change_bytes")
    for name, amount in (
        ("allocator_allocated_bytes", allocated),
        ("allocator_reserved_bytes", reserved),
    ):
        if amount < 0:
            raise ValueError(f"{name} must be >= 0")
    for name, maybe_amount in (
        ("allocator_active_bytes", active),
        ("allocator_inactive_bytes", inactive),
    ):
        if maybe_amount is not None and maybe_amount < 0:
            raise ValueError(f"{name} must be >= 0 when provided")

    # --- device counters ---
    used = _coerce_int(record["device_used_bytes"], "device_used_bytes")
    free = _coerce_optional_int(record["device_free_bytes"], "device_free_bytes")
    total = _coerce_optional_int(record["device_total_bytes"], "device_total_bytes")
    if used < 0:
        raise ValueError("device_used_bytes must be >= 0")
    for name, maybe_amount in (
        ("device_free_bytes", free),
        ("device_total_bytes", total),
    ):
        if maybe_amount is not None and maybe_amount < 0:
            raise ValueError(f"{name} must be >= 0 when provided")
    if total is not None:
        if used > total:
            raise ValueError("device_used_bytes cannot exceed device_total_bytes")
        if free is not None and free > total:
            raise ValueError("device_free_bytes cannot exceed device_total_bytes")

    _coerce_string(record["context"], "context", allow_none=True)
    _coerce_metadata_dict(record["metadata"])

    # Cross-field check: rank and local_rank must fit inside world_size.
    resolve_distributed_identity(
        job_id=record.get("job_id"),
        rank=record.get("rank"),
        local_rank=record.get("local_rank"),
        world_size=record.get("world_size"),
    )
def _event_from_versioned_record(
    record: Mapping[str, Any],
    default_session_id: str | None,
) -> TelemetryEvent:
    """Upgrade a record that declares ``schema_version`` (2 or 3) to a v3 event."""
    declared_version = _coerce_int(record["schema_version"], "schema_version")
    if declared_version not in {SCHEMA_VERSION_V2, SCHEMA_VERSION_V3}:
        raise ValueError(f"Unsupported schema_version: {declared_version}")

    declared_metadata = _coerce_metadata_dict(record.get("metadata", {}))
    identity = resolve_distributed_identity(
        job_id=record.get("job_id"),
        rank=record.get("rank"),
        local_rank=record.get("local_rank"),
        world_size=record.get("world_size"),
    )
    session_id = _resolve_session_id(
        record,
        metadata=declared_metadata,
        default_session_id=default_session_id,
    )

    # Validate the record as v3, with the resolved session id filled in.
    upgraded = {
        **record,
        "schema_version": SCHEMA_VERSION_V3,
        "session_id": session_id,
    }
    validate_telemetry_record(upgraded)

    return TelemetryEvent(
        schema_version=SCHEMA_VERSION_V3,
        session_id=session_id,
        timestamp_ns=_coerce_int(record["timestamp_ns"], "timestamp_ns"),
        event_type=_coerce_required_string(record["event_type"], "event_type"),
        collector=_coerce_required_string(record["collector"], "collector"),
        sampling_interval_ms=_coerce_int(
            record["sampling_interval_ms"], "sampling_interval_ms"
        ),
        pid=_coerce_int(record["pid"], "pid"),
        host=_coerce_required_string(record["host"], "host"),
        device_id=_coerce_int(record["device_id"], "device_id"),
        allocator_allocated_bytes=_coerce_int(
            record["allocator_allocated_bytes"], "allocator_allocated_bytes"
        ),
        allocator_reserved_bytes=_coerce_int(
            record["allocator_reserved_bytes"], "allocator_reserved_bytes"
        ),
        allocator_active_bytes=_coerce_optional_int(
            record["allocator_active_bytes"], "allocator_active_bytes"
        ),
        allocator_inactive_bytes=_coerce_optional_int(
            record["allocator_inactive_bytes"], "allocator_inactive_bytes"
        ),
        allocator_change_bytes=_coerce_int(
            record["allocator_change_bytes"], "allocator_change_bytes"
        ),
        device_used_bytes=_coerce_int(
            record["device_used_bytes"], "device_used_bytes"
        ),
        device_free_bytes=_coerce_optional_int(
            record["device_free_bytes"], "device_free_bytes"
        ),
        device_total_bytes=_coerce_optional_int(
            record["device_total_bytes"], "device_total_bytes"
        ),
        context=_coerce_string(record["context"], "context", allow_none=True),
        metadata=_coerce_metadata_dict(upgraded["metadata"]),
        **identity,
    )


def _event_from_legacy_record(
    record: Mapping[str, Any],
    default_collector: str,
    default_sampling_interval_ms: int,
    default_session_id: str | None,
) -> TelemetryEvent:
    """Build a v3 event from a record that has no ``schema_version``."""
    extras = _extract_metadata(record)

    timestamp_ns = _legacy_timestamp_ns(record)
    device_id = _legacy_device_id(record)
    allocated = _legacy_allocator_allocated_bytes(record)
    reserved = _legacy_allocator_reserved_bytes(record, allocated)
    change = _legacy_allocator_change_bytes(record)
    active = _legacy_optional_counter(record, "allocator_active_bytes")
    inactive = _legacy_optional_counter(record, "allocator_inactive_bytes")
    used = _legacy_device_used_bytes(record, allocated)
    total = _legacy_total_memory_bytes(record, extras)
    free = _legacy_device_free_bytes(record, used, total)

    raw_event_type = record.get("event_type", record.get("type", "sample"))
    event_type = _coerce_string(raw_event_type, "event_type") or "sample"
    sampling_interval_ms = _coerce_int(
        record.get("sampling_interval_ms", default_sampling_interval_ms),
        "sampling_interval_ms",
    )
    pid = _legacy_pid(record, extras)
    host = _legacy_host(record, extras)
    collector = _legacy_collector(record, default_collector, device_id, extras)

    identity = resolve_distributed_identity(
        job_id=record.get("job_id"),
        rank=record.get("rank"),
        local_rank=record.get("local_rank"),
        world_size=record.get("world_size"),
        metadata=extras,
    )
    # Identity keys are now first-class fields, so drop them from metadata.
    extras = _strip_distributed_identity_metadata(extras)
    session_id = _resolve_session_id(
        record,
        metadata=extras,
        default_session_id=default_session_id,
    )
    extras = _strip_session_metadata(extras)

    context = _coerce_string(
        record.get("context", record.get("message")), "context", allow_none=True
    )

    return TelemetryEvent(
        schema_version=SCHEMA_VERSION_V3,
        session_id=session_id,
        timestamp_ns=timestamp_ns,
        event_type=event_type,
        collector=collector,
        sampling_interval_ms=sampling_interval_ms,
        pid=pid,
        host=host,
        device_id=device_id,
        allocator_allocated_bytes=allocated,
        allocator_reserved_bytes=reserved,
        allocator_active_bytes=active,
        allocator_inactive_bytes=inactive,
        allocator_change_bytes=change,
        device_used_bytes=used,
        device_free_bytes=free,
        device_total_bytes=total,
        context=context,
        metadata=extras,
        **identity,
    )


def telemetry_event_from_record(
    record: Mapping[str, Any],
    permissive_legacy: bool = True,
    default_collector: str = "legacy.unknown",
    default_sampling_interval_ms: int = 0,
    default_session_id: str | None = None,
) -> TelemetryEvent:
    """Create a canonical telemetry event from v3, v2, or legacy records.

    Args:
        record: The raw event mapping.
        permissive_legacy: Whether records without ``schema_version`` may be
            converted.
        default_collector: Collector name for legacy records when none can be
            determined.
        default_sampling_interval_ms: Interval for legacy records lacking one.
        default_session_id: Session id used when the record carries none.

    Raises:
        ValueError: If ``record`` is not a mapping, is invalid, or is a legacy
            record while ``permissive_legacy`` is false.
    """
    if not isinstance(record, Mapping):
        raise ValueError("record must be a mapping")
    if "schema_version" in record:
        return _event_from_versioned_record(record, default_session_id)
    if not permissive_legacy:
        raise ValueError("Legacy record conversion is disabled")
    return _event_from_legacy_record(
        record,
        default_collector,
        default_sampling_interval_ms,
        default_session_id,
    )
def _looks_like_event_record(payload: Mapping[str, Any]) -> bool:
    """Return True when ``payload`` has at least one key typical of an event."""
    candidate_keys = {
        "schema_version",
        "event_type",
        "type",
        "memory_allocated",
        "memory_mb",
        "timestamp",
        "timestamp_ns",
    }
    return any(key in payload for key in candidate_keys)


def _group_session_events(
    events: list[TelemetryEvent],
) -> dict[str, list[TelemetryEvent]]:
    """Bucket events by ``session_id`` and sort each bucket by timestamp."""
    grouped: dict[str, list[TelemetryEvent]] = {}
    for event in events:
        grouped.setdefault(event.session_id, []).append(event)
    for session_events in grouped.values():
        session_events.sort(key=lambda event: event.timestamp_ns)
    return grouped


def _assemble_loaded_sessions(
    *,
    grouped_events: dict[str, list[TelemetryEvent]],
    manifest_summaries: list[SessionSummary] | None = None,
    sources_by_session: Mapping[str, set[str]] | None = None,
    warnings_by_session: Mapping[str, list[str]] | None = None,
    default_source: str,
    default_source_path: str,
) -> list[LoadedTelemetrySession]:
    """Combine grouped events and manifest summaries into loaded sessions.

    A session appears in the result if it has events or a manifest summary.
    Sessions without a manifest summary get one inferred from their events.
    The result follows the order given by ``sort_session_summaries``.
    """
    summary_by_id = {
        summary.session_id: summary for summary in (manifest_summaries or [])
    }
    session_ids = set(grouped_events) | set(summary_by_id)
    loaded_sessions: list[LoadedTelemetrySession] = []
    for session_id in session_ids:
        session_events = list(grouped_events.get(session_id, []))
        summary = summary_by_id.get(session_id)
        if summary is None:
            summary = infer_session_summary_from_events(
                session_id=session_id,
                events=session_events,
                source=default_source,
                fallback_status=SESSION_STATUS_INCOMPLETE,
            )
        loaded_sessions.append(
            LoadedTelemetrySession(
                summary=summary,
                events=session_events,
                sources_loaded=sorted(
                    (sources_by_session or {}).get(session_id, {default_source_path})
                ),
                warnings=list((warnings_by_session or {}).get(session_id, [])),
            )
        )
    ordered_summaries = sort_session_summaries(
        loaded.summary for loaded in loaded_sessions
    )
    order = {
        summary.session_id: index for index, summary in enumerate(ordered_summaries)
    }
    # The session id is a tie-breaker so the final order is deterministic.
    return sorted(
        loaded_sessions,
        key=lambda loaded: (
            order.get(loaded.summary.session_id, 999),
            loaded.summary.session_id,
        ),
    )


def _iter_jsonl_payloads(path: Path) -> Iterable[tuple[int, Mapping[str, Any]]]:
    """Yield ``(line_number, payload)`` for each JSON object line in ``path``.

    Blank lines are skipped. A final line that lacks a newline and does not
    parse is treated as a truncated write and ignored.

    Raises:
        ValueError: For any other unparsable line, or a line whose JSON value
            is not an object.
    """
    # Iterating the handle splits on newlines only. str.splitlines() would
    # also split on U+2028/U+2029/U+0085, which may appear unescaped inside
    # JSON strings and would cut a valid record in two.
    with path.open("r", encoding="utf-8") as handle:
        for index, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError as exc:
                # Only the last line of a file can lack a trailing newline.
                if not line.endswith("\n"):
                    return
                raise ValueError(
                    f"Malformed telemetry JSONL record in {path} at line {index}"
                ) from exc
            if not isinstance(payload, Mapping):
                raise ValueError(
                    f"Telemetry record in {path} at line {index} must be an object"
                )
            yield index, payload


def _load_jsonl_events(
    path: Path,
    *,
    permissive_legacy: bool,
    default_session_id: str | None = None,
) -> list[TelemetryEvent]:
    """Load every event from a JSONL file, converting each to a v3 event.

    Raises:
        ValueError: If a line is malformed (other than a truncated final
            line) or a record fails conversion.
    """
    return [
        telemetry_event_from_record(
            payload,
            permissive_legacy=permissive_legacy,
            default_session_id=default_session_id,
        )
        for _, payload in _iter_jsonl_payloads(path)
    ]


def _load_json_records(
    payload_path: Path,
    *,
    events_key: Optional[str],
) -> list[Mapping[str, Any]]:
    """Read raw event mappings from a JSON file.

    Accepts a top-level array, an object holding a list under ``events_key``
    (or ``"events"`` when no key is given), or a single event-like object.

    Raises:
        ValueError: If no event list can be found or an entry is not an
            object.
    """
    with payload_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    records: Any
    if isinstance(payload, list):
        records = payload
    elif isinstance(payload, Mapping):
        if events_key is not None:
            records = payload.get(events_key)
            if not isinstance(records, list):
                raise ValueError(
                    f"Top-level key '{events_key}' must contain a list of events"
                )
        elif isinstance(payload.get("events"), list):
            records = payload["events"]
        elif _looks_like_event_record(payload):
            records = [payload]
        else:
            raise ValueError("JSON payload does not contain telemetry events")
    else:
        raise ValueError("Telemetry payload must be a JSON object or array")

    normalized: list[Mapping[str, Any]] = []
    for index, record in enumerate(records):
        if not isinstance(record, Mapping):
            raise ValueError(f"Event at index {index} must be an object")
        normalized.append(record)
    return normalized
def load_telemetry_sessions(
    path: str | Path,
    permissive_legacy: bool = True,
    events_key: Optional[str] = None,
) -> list[LoadedTelemetrySession]:
    """Load grouped telemetry sessions from JSON, JSONL, or sink directories.

    When ``path`` resolves to sink segments, each segment is read as JSONL
    and events are grouped per session across segments. Otherwise ``path`` is
    read as one JSON document; ``events_key`` names the top-level list to
    use.
    """
    payload_path = Path(path)
    resolved_path = str(payload_path.resolve())
    source_label = f"artifact:{payload_path.name or payload_path.resolve()}"
    manifest = read_telemetry_sink_manifest(payload_path)
    segment_paths = resolve_telemetry_sink_segment_paths(payload_path)

    if not segment_paths:
        # Single JSON document: every event shares one fallback session id.
        raw_records = _load_json_records(payload_path, events_key=events_key)
        document_session_id = stable_legacy_session_id(
            resolved_path, events_key or "json"
        )
        document_events: list[TelemetryEvent] = []
        for raw_record in raw_records:
            document_events.append(
                telemetry_event_from_record(
                    raw_record,
                    permissive_legacy=permissive_legacy,
                    default_session_id=document_session_id,
                )
            )
        by_session = _group_session_events(document_events)
        return _assemble_loaded_sessions(
            grouped_events=by_session,
            manifest_summaries=None,
            sources_by_session={
                session_id: {resolved_path} for session_id in by_session
            },
            default_source=source_label,
            default_source_path=resolved_path,
        )

    # Sink directory: merge events from all segments, tracking their sources.
    manifest_segments = manifest.segments if manifest is not None else []
    session_hint_by_filename = {
        segment.filename: segment.session_id for segment in manifest_segments
    }
    sink_session_id = stable_legacy_session_id(resolved_path, "sink")
    merged_events: dict[str, list[TelemetryEvent]] = {}
    segment_sources: dict[str, set[str]] = {}
    for segment_path in segment_paths:
        segment_label = str(segment_path)
        session_hint = (
            session_hint_by_filename.get(segment_path.name) or sink_session_id
        )
        segment_events = _load_jsonl_events(
            segment_path,
            permissive_legacy=permissive_legacy,
            default_session_id=session_hint,
        )
        for session_id, events in _group_session_events(segment_events).items():
            merged_events.setdefault(session_id, []).extend(events)
            segment_sources.setdefault(session_id, set()).add(segment_label)
        # An empty segment is still recorded against its hinted session.
        if not segment_events and session_hint:
            segment_sources.setdefault(session_hint, set()).add(segment_label)

    # Segments were appended one after another, so restore time order.
    for events in merged_events.values():
        events.sort(key=lambda event: event.timestamp_ns)

    return _assemble_loaded_sessions(
        grouped_events=merged_events,
        manifest_summaries=manifest.sessions if manifest is not None else None,
        sources_by_session=segment_sources,
        default_source=source_label,
        default_source_path=resolved_path,
    )
def load_telemetry_events(
    path: str | Path,
    permissive_legacy: bool = True,
    events_key: Optional[str] = None,
    session_id: str | None = None,
) -> list[TelemetryEvent]:
    """Load telemetry events for one session found at ``path``.

    With ``session_id`` given, that session's events are returned; otherwise
    the session chosen by ``select_default_loaded_session`` is used. An empty
    list is returned when nothing was loaded or no default session exists.

    Raises:
        ValueError: If ``session_id`` is given but no loaded session has it.
    """
    sessions = load_telemetry_sessions(
        path,
        permissive_legacy=permissive_legacy,
        events_key=events_key,
    )
    if not sessions:
        return []

    if session_id is None:
        chosen = select_default_loaded_session(sessions)
        return [] if chosen is None else list(chosen.events)

    requested = next(
        (loaded for loaded in sessions if loaded.summary.session_id == session_id),
        None,
    )
    if requested is None:
        raise ValueError(f"Requested session_id not found: {session_id}")
    return list(requested.events)
def project_telemetry_event(
    event: TelemetryEvent | Mapping[str, Any],
) -> ProjectedTelemetryRecord:
    """Project a telemetry event, or a compatible mapping, into the shared model.

    Mappings are first converted with ``telemetry_event_from_record``; v3
    event objects are used as they are.
    """
    canonical = (
        event
        if isinstance(event, TelemetryEventV3)
        else telemetry_event_from_record(event)
    )
    serialized = telemetry_event_to_dict(canonical)
    return project_telemetry_mapping(serialized)
def project_telemetry_events(
    events: Iterable[TelemetryEvent | Mapping[str, Any]],
) -> list[ProjectedTelemetryRecord]:
    """Project each event in ``events`` into a backend-neutral record, in order."""
    projected: list[ProjectedTelemetryRecord] = []
    for event in events:
        projected.append(project_telemetry_event(event))
    return projected
# Public API of this module; ProjectedTelemetryRecord is re-exported from .telemetry_model.
__all__ = [ "SCHEMA_VERSION_V2", "SCHEMA_VERSION_V3", "SCHEMA_VERSION_LATEST", "ProjectedTelemetryRecord", "LoadedTelemetrySession", "TelemetryEvent", "TelemetryEventV2", "TelemetryEventV3", "project_telemetry_event", "project_telemetry_events", "load_telemetry_sessions", "telemetry_event_from_record", "telemetry_event_to_dict", "validate_telemetry_record", "load_telemetry_events", "resolve_distributed_identity", ]