Source code for stormlog.distributed_analysis

"""Distributed telemetry analysis helpers."""

from __future__ import annotations

from collections import Counter, defaultdict
from dataclasses import asdict, dataclass, field
from statistics import median
from typing import Any, Sequence

# Phase attribution is optional: when the sibling ``phases`` module cannot be
# imported, the two type names degrade to ``Any`` and the payload helper is
# replaced by a stub so the rest of this module keeps working.
try:
    from .phases import PhaseAttribution, PhaseReplayIndex, phase_attribution_to_payload
except ImportError:  # pragma: no cover - phase package may land in another slice
    PhaseAttribution = Any  # type: ignore[assignment,misc]
    PhaseReplayIndex = Any  # type: ignore[assignment,misc]

    def phase_attribution_to_payload(
        attribution: PhaseAttribution | None,
    ) -> dict[str, Any] | None:
        """Fallback stub: ignore ``attribution`` and always return ``None``."""
        return None


from .telemetry import TelemetryEventV2

# Floor (64 MiB) for the cumulative device-memory growth that counts as a
# spike; the spike detector uses the larger of this and 10% of a rank's peak
# device usage.
_SPIKE_MIN_BYTES = 64 * 1024**2
# A skew note is added when a rank's first-sample offset from the anchor
# exceeds this many median sampling intervals.
_SKEW_NOTE_MULTIPLIER = 5


@dataclass
class RankTimelinePoint:
    """One telemetry sample placed on the rank-aligned timeline."""

    # Rank that produced the sample.
    rank: int
    # Timestamp as recorded on the telemetry event, in nanoseconds.
    timestamp_ns: int
    # Timestamp after the rank's alignment offset has been subtracted.
    aligned_timestamp_ns: int
    # Memory counters copied from the telemetry event.
    device_used_bytes: int
    allocator_reserved_bytes: int
    allocator_allocated_bytes: int
    allocator_change_bytes: int
@dataclass
class CrossRankMergeResult:
    """Outcome of merging per-rank telemetry into one aligned timeline."""

    # Job identifier chosen for the analysis, or None when none was seen.
    job_id: str | None
    # Number of ranks the job is taken to have.
    world_size: int
    # Ranks that contributed at least one sample, in ascending order.
    participating_ranks: list[int]
    # Ranks below world_size that contributed no samples, ascending.
    missing_ranks: list[int]
    # Sample count per participating rank.
    rank_sample_counts: dict[int, int]
    # Per-rank first-sample offset relative to the anchor rank.
    alignment_offsets_ns: dict[int, int]
    # Every sample from every rank, ordered on the aligned timeline.
    merged_points: list[RankTimelinePoint]
    # Human-readable remarks collected while merging.
    notes: list[str] = field(default_factory=list)
@dataclass
class FirstCauseSuspect:
    """One ranked candidate for the rank that spiked first."""

    # Rank the spike was observed on.
    rank: int
    # Raw timestamp of the sample at which the spike threshold was crossed.
    first_spike_timestamp_ns: int
    # Same instant expressed on the aligned timeline.
    aligned_first_spike_timestamp_ns: int
    # Cumulative device-memory growth accumulated when the spike qualified.
    peak_delta_bytes: int
    # Number of consecutive growing samples that made up the spike.
    spike_window_samples: int
    # Cluster onset minus this suspect's aligned spike time (0 without onset).
    lead_over_cluster_onset_ns: int
    # Confidence label assigned by the detector.
    confidence: str
    # Supporting figures for the suspect, keyed by name.
    evidence: dict[str, int | str]
    # Optional phase attribution for the spike, when a resolver supplied one.
    phase_attribution: PhaseAttribution | None = None
@dataclass
class FirstCauseAnalysisResult:
    """Result of the distributed first-cause spike detection."""

    # Aligned timestamp of the second-earliest rank spike, or None when fewer
    # than two ranks produced a qualifying spike.
    cluster_onset_timestamp_ns: int | None
    # Suspects ordered by aligned spike time.
    suspects: list[FirstCauseSuspect]
    # Human-readable remarks collected during detection.
    notes: list[str] = field(default_factory=list)
@dataclass
class _RankSpikeCandidate:
    """Internal record of the first qualifying spike found on one rank."""

    rank: int
    session_id: str | None
    first_spike_timestamp_ns: int
    aligned_first_spike_timestamp_ns: int
    peak_delta_bytes: int
    spike_window_samples: int


def _is_sample_event(event: TelemetryEventV2) -> bool:
    """Return True when the event type is ``"sample"``, compared case-insensitively."""
    return "sample" == event.event_type.casefold()


def _group_events_by_rank(
    events: Sequence[TelemetryEventV2],
) -> dict[int, list[TelemetryEventV2]]:
    """Bucket events by rank, each bucket ordered by (timestamp, pid, host).

    Ranks appear in the returned dict in the order they were first seen.
    """
    buckets: dict[int, list[TelemetryEventV2]] = defaultdict(list)
    for item in events:
        buckets[item.rank].append(item)
    return {
        rank: sorted(
            bucket, key=lambda item: (item.timestamp_ns, item.pid, item.host)
        )
        for rank, bucket in buckets.items()
    }


def _select_job_id(events: Sequence[TelemetryEventV2]) -> str | None:
    """Return the most frequent truthy ``job_id``, or None when there is none."""
    tally = Counter(event.job_id for event in events if event.job_id)
    if not tally:
        return None
    ((winner, _occurrences),) = tally.most_common(1)
    return winner


def _select_cross_rank_analysis_events(
    events: Sequence[TelemetryEventV2],
) -> tuple[list[TelemetryEventV2], str | None, list[str]]:
    """Pick the events that take part in cross-rank analysis.

    Non-sample events are dropped. When more than one job_id is present, only
    the samples carrying the most common job_id are kept.

    Returns:
        A ``(events, job_id, notes)`` tuple; ``notes`` explains anything that
        was ignored or filtered.
    """
    notes: list[str] = []
    samples = [event for event in events if _is_sample_event(event)]
    if len(samples) != len(events):
        notes.append("Ignored non-sample events during cross-rank analysis.")
    if not samples:
        notes.append(
            "No sample telemetry events were available for distributed analysis."
        )
        return [], None, notes

    job_id = _select_job_id(samples)
    distinct_job_ids = {event.job_id for event in samples if event.job_id}
    if job_id is not None and len(distinct_job_ids) > 1:
        notes.append(
            "Multiple job_id values were observed; filtering to the most common value."
        )
        samples = [event for event in samples if event.job_id == job_id]
    return samples, job_id, notes


def _determine_world_size(
    events: Sequence[TelemetryEventV2], participating_ranks: Sequence[int]
) -> int:
    """Estimate the world size from declared sizes and the ranks observed.

    Returns 0 without events; otherwise the larger of the most commonly
    declared ``world_size`` (values above 1 only, defaulting to 1) and the
    highest participating rank plus one.
    """
    if not events:
        return 0
    declared_counts = Counter(
        event.world_size for event in events if event.world_size > 1
    )
    declared = declared_counts.most_common(1)[0][0] if declared_counts else 1
    if participating_ranks:
        observed = max(participating_ranks) + 1
    else:
        observed = declared
    # ``declared`` is never below 1, so a floor of 1 covers both the
    # single-rank and the multi-rank case identically.
    return max(declared, observed, 1)


def _median_sampling_interval_ns(grouped: dict[int, list[TelemetryEventV2]]) -> int:
    """Return the median sampling interval across all ranks, in nanoseconds.

    Positive gaps between consecutive samples of the same rank are preferred;
    if there are none, the events' positive ``sampling_interval_ms`` values
    are used instead. Returns 0 when neither source yields a value.
    """
    gaps: list[int] = []
    for rank_events in grouped.values():
        for earlier, later in zip(rank_events, rank_events[1:]):
            gap = later.timestamp_ns - earlier.timestamp_ns
            if gap > 0:
                gaps.append(gap)
    if gaps:
        return int(median(gaps))

    declared_intervals = [
        event.sampling_interval_ms * 1_000_000
        for rank_events in grouped.values()
        for event in rank_events
        if event.sampling_interval_ms > 0
    ]
    return int(median(declared_intervals)) if declared_intervals else 0
def merge_cross_rank_timelines(
    events: Sequence[TelemetryEventV2],
) -> CrossRankMergeResult:
    """Combine per-rank sample streams into one aligned timeline.

    Each rank is shifted so that its first sample coincides with the first
    sample of the anchor rank (rank 0 when present, otherwise the lowest
    participating rank).

    Args:
        events: Telemetry events from any number of ranks.

    Returns:
        The merged timeline together with rank bookkeeping and notes. When no
        usable sample events exist, an empty result carrying the explanatory
        notes is returned.
    """

    def _empty(job_id: str | None, notes: list[str]) -> CrossRankMergeResult:
        # Shared shape for the "nothing to merge" outcomes.
        return CrossRankMergeResult(
            job_id=job_id,
            world_size=0,
            participating_ranks=[],
            missing_ranks=[],
            rank_sample_counts={},
            alignment_offsets_ns={},
            merged_points=[],
            notes=notes,
        )

    if not events:
        return _empty(None, ["No telemetry events were provided."])

    analysis_events, job_id, notes = _select_cross_rank_analysis_events(events)
    if not analysis_events:
        return _empty(job_id, notes)

    per_rank = _group_events_by_rank(analysis_events)
    participating_ranks = sorted(per_rank)
    world_size = _determine_world_size(analysis_events, participating_ranks)

    # Every rank below world_size is expected to have reported something.
    missing_ranks = sorted(set(range(world_size)) - set(participating_ranks))
    rank_sample_counts = {rank: len(samples) for rank, samples in per_rank.items()}
    if missing_ranks:
        listed = ", ".join(str(rank) for rank in missing_ranks)
        notes.append("Missing rank data for ranks: " + listed + ".")

    anchor_rank = 0 if 0 in per_rank else participating_ranks[0]
    anchor_timestamp = per_rank[anchor_rank][0].timestamp_ns
    median_interval_ns = _median_sampling_interval_ns(per_rank)
    skew_limit_ns = _SKEW_NOTE_MULTIPLIER * median_interval_ns

    alignment_offsets_ns: dict[int, int] = {}
    merged_points: list[RankTimelinePoint] = []
    for rank in participating_ranks:
        samples = per_rank[rank]
        offset_ns = samples[0].timestamp_ns - anchor_timestamp
        alignment_offsets_ns[rank] = offset_ns
        # Large start skew means first-sample alignment is only a rough guide.
        if median_interval_ns > 0 and abs(offset_ns) > skew_limit_ns:
            notes.append(
                f"Rank {rank} starts {offset_ns} ns from the anchor; "
                "first-sample alignment may be approximate."
            )
        merged_points.extend(
            RankTimelinePoint(
                rank=rank,
                timestamp_ns=sample.timestamp_ns,
                aligned_timestamp_ns=sample.timestamp_ns - offset_ns,
                device_used_bytes=sample.device_used_bytes,
                allocator_reserved_bytes=sample.allocator_reserved_bytes,
                allocator_allocated_bytes=sample.allocator_allocated_bytes,
                allocator_change_bytes=sample.allocator_change_bytes,
            )
            for sample in samples
        )

    merged_points.sort(
        key=lambda point: (point.aligned_timestamp_ns, point.rank, point.timestamp_ns)
    )
    return CrossRankMergeResult(
        job_id=job_id,
        world_size=world_size,
        participating_ranks=participating_ranks,
        missing_ranks=missing_ranks,
        rank_sample_counts=rank_sample_counts,
        alignment_offsets_ns=alignment_offsets_ns,
        merged_points=merged_points,
        notes=notes,
    )
def _find_rank_spike_candidate(
    rank_events: Sequence[TelemetryEventV2], offset_ns: int
) -> _RankSpikeCandidate | None:
    """Find the first qualifying memory spike in one rank's samples.

    A spike is an unbroken run of samples with strictly growing
    ``device_used_bytes`` whose cumulative growth reaches the threshold: the
    larger of ``_SPIKE_MIN_BYTES`` and 10% of the rank's peak device usage.

    Args:
        rank_events: The rank's samples in time order.
        offset_ns: Alignment offset subtracted to obtain the aligned timestamp.

    Returns:
        The candidate for the sample at which the threshold is first reached,
        or None when there are fewer than two samples or no run qualifies.
    """
    if len(rank_events) < 2:
        return None

    peak_usage = max(event.device_used_bytes for event in rank_events)
    threshold = max(_SPIKE_MIN_BYTES, int(peak_usage * 0.10))

    run_growth = 0
    run_length = 0
    for previous, current in zip(rank_events, rank_events[1:]):
        step = current.device_used_bytes - previous.device_used_bytes
        if step <= 0:
            # A flat or shrinking sample ends the current growth run.
            run_growth = 0
            run_length = 0
            continue
        run_growth += step
        run_length += 1
        if run_growth >= threshold:
            return _RankSpikeCandidate(
                rank=current.rank,
                session_id=getattr(current, "session_id", None),
                first_spike_timestamp_ns=current.timestamp_ns,
                aligned_first_spike_timestamp_ns=current.timestamp_ns - offset_ns,
                peak_delta_bytes=run_growth,
                spike_window_samples=run_length,
            )
    return None


def _detect_first_cause_spikes(
    grouped: dict[int, list[TelemetryEventV2]],
    merge_result: CrossRankMergeResult,
    phase_resolver: PhaseReplayIndex | None = None,
) -> FirstCauseAnalysisResult:
    """Rank the ranks whose memory spiked first on the aligned timeline.

    Args:
        grouped: Sample events per rank, each list in time order.
        merge_result: Merge output supplying alignment offsets and the list of
            missing ranks.
        phase_resolver: Optional object with a ``resolve`` method used to
            attach a phase attribution to suspects that carry a session id.

    Returns:
        The cluster onset (aligned time of the second-earliest spike, when at
        least two ranks spiked), the suspects at or before that onset, and
        explanatory notes.
    """
    if not grouped:
        return FirstCauseAnalysisResult(
            cluster_onset_timestamp_ns=None,
            suspects=[],
            notes=["No telemetry events were available for distributed analysis."],
        )
    if len(grouped) == 1:
        return FirstCauseAnalysisResult(
            cluster_onset_timestamp_ns=None,
            suspects=[],
            notes=[
                "At least two ranks are required for cross-rank first-cause analysis."
            ],
        )

    offsets = merge_result.alignment_offsets_ns
    found: list[_RankSpikeCandidate] = []
    short_ranks: list[int] = []
    for rank, rank_events in grouped.items():
        if len(rank_events) < 2:
            short_ranks.append(rank)
        spike = _find_rank_spike_candidate(rank_events, offsets.get(rank, 0))
        if spike is not None:
            found.append(spike)

    notes: list[str] = []
    if short_ranks:
        listed = ", ".join(str(rank) for rank in sorted(short_ranks))
        notes.append(
            "Some ranks have fewer than two samples and cannot contribute to spike detection: "
            + listed
            + "."
        )
    if not found:
        notes.append("No qualifying cross-rank spikes were detected.")
        return FirstCauseAnalysisResult(
            cluster_onset_timestamp_ns=None,
            suspects=[],
            notes=notes,
        )

    # Earliest aligned spike first; larger spikes, then lower ranks, break ties.
    ordered = sorted(
        found,
        key=lambda item: (
            item.aligned_first_spike_timestamp_ns,
            -item.peak_delta_bytes,
            item.rank,
        ),
    )
    multi_rank = len(ordered) >= 2
    if multi_rank:
        cluster_onset_timestamp_ns: int | None = ordered[
            1
        ].aligned_first_spike_timestamp_ns
        cutoff = cluster_onset_timestamp_ns
    else:
        cluster_onset_timestamp_ns = None
        notes.append(
            "Only one rank produced a qualifying spike; confidence is limited."
        )
        cutoff = ordered[0].aligned_first_spike_timestamp_ns

    shortlisted = [
        item for item in ordered if item.aligned_first_spike_timestamp_ns <= cutoff
    ]
    median_interval_ns = _median_sampling_interval_ns(grouped)
    earliest = shortlisted[0].aligned_first_spike_timestamp_ns
    earliest_count = sum(
        1 for item in shortlisted if item.aligned_first_spike_timestamp_ns == earliest
    )
    sparse_evidence = bool(merge_result.missing_ranks or short_ranks)
    support_count = len(shortlisted)

    def _grade(aligned_ts: int, lead_ns: int) -> str:
        # Only the earliest spike of a multi-rank cluster with complete
        # evidence can rate above "low".
        if not multi_rank or aligned_ts != earliest or sparse_evidence:
            return "low"
        sole_leader = earliest_count == 1
        clear_lead = median_interval_ns <= 0 or lead_ns >= median_interval_ns
        return "high" if sole_leader and clear_lead else "medium"

    suspects: list[FirstCauseSuspect] = []
    for item in shortlisted:
        aligned_ts = item.aligned_first_spike_timestamp_ns
        if cluster_onset_timestamp_ns is None:
            lead_ns = 0
        else:
            lead_ns = cluster_onset_timestamp_ns - aligned_ts

        attribution = None
        if (
            phase_resolver is not None
            and item.session_id is not None
            and hasattr(phase_resolver, "resolve")
        ):
            attribution = phase_resolver.resolve(
                timestamp_ns=item.first_spike_timestamp_ns,
                session_id=item.session_id,
                rank=item.rank,
            )

        suspects.append(
            FirstCauseSuspect(
                rank=item.rank,
                first_spike_timestamp_ns=item.first_spike_timestamp_ns,
                aligned_first_spike_timestamp_ns=aligned_ts,
                peak_delta_bytes=item.peak_delta_bytes,
                spike_window_samples=item.spike_window_samples,
                lead_over_cluster_onset_ns=lead_ns,
                confidence=_grade(aligned_ts, lead_ns),
                evidence={
                    "device_used_delta_bytes": item.peak_delta_bytes,
                    "supporting_ranks_at_or_before_onset": support_count,
                },
                phase_attribution=attribution,
            )
        )

    return FirstCauseAnalysisResult(
        cluster_onset_timestamp_ns=cluster_onset_timestamp_ns,
        suspects=suspects,
        notes=notes,
    )
def analyze_cross_rank_events(
    events: Sequence[TelemetryEventV2],
    *,
    phase_resolver: PhaseReplayIndex | None = None,
) -> tuple[CrossRankMergeResult, FirstCauseAnalysisResult]:
    """Merge rank timelines and run first-cause spike detection on them.

    Args:
        events: Telemetry events from any number of ranks.
        phase_resolver: Optional resolver forwarded to spike detection for
            phase attribution.

    Returns:
        A ``(merge_result, first_cause_result)`` pair. When no sample events
        survive selection, the first-cause result has no suspects and carries
        the selection notes (or, failing those, a copy of the merge notes).
    """
    merged = merge_cross_rank_timelines(events)
    selected, _job_id, selection_notes = _select_cross_rank_analysis_events(events)

    if selected:
        detection = _detect_first_cause_spikes(
            _group_events_by_rank(selected),
            merged,
            phase_resolver=phase_resolver,
        )
        return merged, detection

    fallback_notes = selection_notes if selection_notes else list(merged.notes)
    return merged, FirstCauseAnalysisResult(
        cluster_onset_timestamp_ns=None,
        suspects=[],
        notes=fallback_notes,
    )
def summarize_cross_rank_analysis(
    events: Sequence[TelemetryEventV2],
    *,
    phase_resolver: PhaseReplayIndex | None = None,
) -> dict[str, Any]:
    """Run the cross-rank analysis and flatten it into a plain dict.

    Rank-keyed mappings are emitted with string keys, and the notes of both
    analysis stages are concatenated with duplicates removed (first
    occurrence wins).

    Args:
        events: Telemetry events from any number of ranks.
        phase_resolver: Optional resolver forwarded for phase attribution.

    Returns:
        A summary dict intended to be JSON-serializable.
    """
    merge_result, first_cause_result = analyze_cross_rank_events(
        events,
        phase_resolver=phase_resolver,
    )

    # A dict doubles as an insertion-ordered set for de-duplicating notes.
    unique_notes: dict[str, None] = {}
    for source in (merge_result.notes, first_cause_result.notes):
        for note in source:
            unique_notes.setdefault(note, None)

    sample_counts = {}
    for rank, count in merge_result.rank_sample_counts.items():
        sample_counts[str(rank)] = count
    offsets = {}
    for rank, offset in merge_result.alignment_offsets_ns.items():
        offsets[str(rank)] = offset
    suspects = list(
        map(_serialize_first_cause_suspect, first_cause_result.suspects)
    )

    return {
        "job_id": merge_result.job_id,
        "world_size": merge_result.world_size,
        "participating_ranks": merge_result.participating_ranks,
        "missing_ranks": merge_result.missing_ranks,
        "rank_sample_counts": sample_counts,
        "alignment_offsets_ns": offsets,
        "cluster_onset_timestamp_ns": first_cause_result.cluster_onset_timestamp_ns,
        "first_cause_suspects": suspects,
        "notes": list(unique_notes),
    }
def _serialize_first_cause_suspect(suspect: FirstCauseSuspect) -> dict[str, Any]:
    """Convert a suspect to a dict, rendering its phase attribution as a payload.

    All fields come from ``dataclasses.asdict``; ``phase_attribution`` is then
    replaced by the output of ``phase_attribution_to_payload``.
    """
    return {
        **asdict(suspect),
        "phase_attribution": phase_attribution_to_payload(suspect.phase_attribution),
    }


# Names exported as the module's public API.
__all__ = [
    "CrossRankMergeResult",
    "FirstCauseAnalysisResult",
    "FirstCauseSuspect",
    "RankTimelinePoint",
    "analyze_cross_rank_events",
    "merge_cross_rank_timelines",
    "summarize_cross_rank_analysis",
]