# Source code for stormlog.collective_attribution

"""Heuristics for attributing hidden-memory spikes to collective communication."""

from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Any, Iterable, Mapping, Sequence

import numpy as np

try:
    from .phases import PhaseAttribution, PhaseReplayIndex, merge_phase_attributions
except ImportError:  # pragma: no cover - phase package may land in another slice
    # Without the phases package, degrade to opaque placeholder types and a
    # trivial merge so the rest of this module still imports and runs.
    PhaseAttribution = Any  # type: ignore[assignment,misc]
    PhaseReplayIndex = Any  # type: ignore[assignment,misc]

    def merge_phase_attributions(
        first: PhaseAttribution | None,
        second: PhaseAttribution | None,
    ) -> PhaseAttribution | None:
        """Fallback merge: return *first* when it is truthy, otherwise *second*."""
        if first:
            return first
        return second


from .telemetry import TelemetryEventV2

# Lower-case substrings that mark an event's type/context/metadata text as
# collective-communication related. Multi-word operations appear in both
# snake_case and collapsed spellings.
_COLLECTIVE_TOKENS = (
    # Backend / generic vocabulary.
    "nccl", "collective", "communication",
    # Reduction and gather style collectives.
    "all_reduce", "allreduce",
    "all_gather", "allgather",
    "reduce_scatter", "reducescatter",
    # Fan-out and synchronization collectives.
    "broadcast", "barrier",
)


@dataclass
class CollectiveAttributionConfig:
    """Runtime knobs for collective-memory attribution heuristics.

    The dataclass defaults are the ``"medium"`` preset.
    """

    # Name of the preset these thresholds came from.
    preset: str = "medium"
    # Ranks with fewer sample events than this are not analysed at all.
    min_samples_per_rank: int = 6
    # Absolute hidden-memory gap (bytes) that counts as significant (128 MiB).
    min_gap_bytes: int = 128 * 1024**2
    # Gap relative to device total memory that counts as significant; either
    # this or ``min_gap_bytes`` is enough.
    min_gap_ratio: float = 0.04
    # Minimum robust z-score of a gap for it to be treated as a spike.
    robust_zscore_threshold: float = 2.5
    # Collective markers within +/- this many ns of a spike count as overlapping.
    marker_window_ns: int = 150_000_000
    # Padding added on both sides of each reported interval.
    interval_padding_ns: int = 120_000_000
    # Spikes on other ranks within this many ns count as synchronized.
    synchrony_window_ns: int = 120_000_000
    # Synchrony ratio at which the "cross_rank_synchrony" reason code is emitted.
    min_synchrony_ratio: float = 0.5
    # Results scoring below this confidence are dropped.
    min_confidence: float = 0.5
@dataclass
class CollectiveAttributionEvidence:
    """Evidence fields backing one collective-attribution output."""

    # Number of collective-marker events found near the spike.
    marker_hits: int
    # Distinct ranks with a spike inside the synchrony window.
    synchronized_ranks: int
    # World size the synchrony ratio was computed against.
    expected_world_size: int
    # (synchronized_ranks - 1) / (expected_world_size - 1), capped at 1.
    synchrony_ratio: float
    # Largest device-used minus allocator-reserved gap, in bytes.
    peak_gap_bytes: int
    # Gap relative to device total memory; None when the total is unknown.
    peak_gap_ratio: float | None
    # Robust z-score of the gap against the rank's gap series.
    robust_zscore: float
@dataclass
class CollectiveAttributionResult:
    """Communication-attributed hidden-memory interval."""

    rank: int
    # Interval bounds in nanoseconds (already padded and clamped).
    interval_start_ns: int
    interval_end_ns: int
    # One of "collective_confident", "collective_likely", "collective_suspect".
    classification: str
    # Score in [0, 1].
    confidence: float
    # Machine-readable explanations of which signals fired.
    reason_codes: list[str] = field(default_factory=list)
    evidence: CollectiveAttributionEvidence | None = None
    # Training-phase context, when a phase resolver was supplied.
    phase_attribution: PhaseAttribution | None = None
@dataclass(frozen=True)
class _RankSpike:
    """A single per-rank gap outlier, before cross-rank scoring."""

    # (rank, timestamp_ns, sample index within the rank): unique per spike.
    key: tuple[int, int, int]
    rank: int
    session_id: str | None
    timestamp_ns: int
    # Device-used minus allocator-reserved bytes at the sample, clamped at 0.
    peak_gap_bytes: int
    # peak_gap_bytes / device_total_bytes, or None when the total is unknown.
    peak_gap_ratio: float | None
    robust_zscore: float
    # Timestamps of collective markers close enough to the sample.
    marker_times: tuple[int, ...]


# Built-in sensitivity presets. "low" tightens thresholds and narrows the
# windows (fewer, stronger results); "high" lowers thresholds and widens them.
_PRESET_CONFIGS: dict[str, CollectiveAttributionConfig] = {
    "low": CollectiveAttributionConfig(
        preset="low",
        min_samples_per_rank=8,
        min_gap_bytes=256 * 1024 * 1024,
        min_gap_ratio=0.06,
        robust_zscore_threshold=3.0,
        marker_window_ns=120_000_000,
        interval_padding_ns=100_000_000,
        synchrony_window_ns=100_000_000,
        min_synchrony_ratio=0.6,
        min_confidence=0.7,
    ),
    # The dataclass defaults are the medium preset.
    "medium": CollectiveAttributionConfig(),
    "high": CollectiveAttributionConfig(
        preset="high",
        min_samples_per_rank=5,
        min_gap_bytes=64 * 1024 * 1024,
        min_gap_ratio=0.025,
        robust_zscore_threshold=1.8,
        marker_window_ns=180_000_000,
        interval_padding_ns=160_000_000,
        synchrony_window_ns=180_000_000,
        min_synchrony_ratio=0.34,
        min_confidence=0.35,
    ),
}

# Threshold fields callers may override; "preset" itself is not overridable.
_OVERRIDABLE_FIELDS = frozenset(
    (
        "min_samples_per_rank",
        "min_gap_bytes",
        "min_gap_ratio",
        "robust_zscore_threshold",
        "marker_window_ns",
        "interval_padding_ns",
        "synchrony_window_ns",
        "min_synchrony_ratio",
        "min_confidence",
    )
)
def resolve_collective_attribution_config(
    preset: str = "medium",
    overrides: Mapping[str, Any] | None = None,
) -> CollectiveAttributionConfig:
    """Resolve a preset config with optional per-threshold overrides.

    Args:
        preset: Name of a built-in preset. Case and surrounding whitespace are
            ignored; an empty value or ``None`` falls back to ``"medium"``.
        overrides: Optional mapping of threshold field name to replacement
            value. String values (e.g. from a CLI or config file) are parsed
            into the numeric type of the field they replace; any other value
            is applied unchanged.

    Returns:
        A fresh, validated config. The shared preset objects are never mutated.

    Raises:
        ValueError: Unknown preset, unknown override field, an unparseable
            string override, or a resolved threshold outside its valid range.
    """
    normalized_preset = (preset or "medium").strip().lower()
    if normalized_preset not in _PRESET_CONFIGS:
        known = ", ".join(sorted(_PRESET_CONFIGS))
        raise ValueError(
            f"Unknown collective attribution preset: {preset!r} (known: {known})"
        )
    # replace() with no changes returns a copy, keeping the preset table pristine.
    config = replace(_PRESET_CONFIGS[normalized_preset])
    config.preset = normalized_preset
    if overrides:
        unknown = sorted(key for key in overrides if key not in _OVERRIDABLE_FIELDS)
        if unknown:
            raise ValueError(
                "Unknown collective attribution override fields: "
                + ", ".join(unknown)
            )
        typed_overrides = {
            key: _coerce_override(key, value, getattr(config, key))
            for key, value in overrides.items()
        }
        config = replace(config, **typed_overrides)
    _validate_collective_config(config)
    return config


def _coerce_override(key: str, value: Any, current: Any) -> Any:
    """Parse a string override into the numeric type of the field it replaces.

    Non-string values are returned unchanged so programmatic callers keep exact
    control over what they pass. Integer fields also accept integral
    scientific notation such as ``"1.5e8"``.

    Raises:
        ValueError: If *value* is a string that cannot be parsed as the
            field's type (or is non-integral for an integer field).
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    try:
        if isinstance(current, int):
            try:
                return int(text)
            except ValueError:
                as_float = float(text)
                if as_float.is_integer():
                    return int(as_float)
                raise
        return type(current)(text)
    except ValueError as err:
        raise ValueError(
            f"Invalid value for collective attribution override {key!r}: {value!r}"
        ) from err
def attribute_collective_memory(
    events: Sequence[TelemetryEventV2],
    *,
    config: CollectiveAttributionConfig | None = None,
    preset: str = "medium",
    overrides: Mapping[str, Any] | None = None,
    phase_resolver: PhaseReplayIndex | None = None,
) -> list[CollectiveAttributionResult]:
    """Attribute hidden-memory spikes to communication phases using hybrid signals.

    Per-rank gap outliers are detected from ``sample`` events, cross-checked
    against collective markers and spikes on other ranks, scored, filtered by
    ``min_confidence``, and merged into non-overlapping per-rank intervals.

    Args:
        events: Telemetry events from any number of ranks, in any order.
        config: Explicit configuration. When given, ``preset`` and
            ``overrides`` are ignored, but the config is still validated.
        preset: Preset name used when ``config`` is ``None``.
        overrides: Per-threshold overrides applied on top of ``preset``.
        phase_resolver: Optional object with a ``resolve`` method used to
            attach phase context to each result.

    Returns:
        Results sorted by ``(rank, interval_start_ns)``; empty when ``events``
        is empty or nothing qualifies.

    Raises:
        ValueError: If the preset/overrides are invalid or ``config`` holds
            out-of-range thresholds.
    """
    if not events:
        return []
    if config is None:
        resolved = resolve_collective_attribution_config(preset, overrides)
    else:
        # Caller-built configs skip preset resolution, so validate them here
        # too; e.g. robust_zscore_threshold=0 would otherwise divide by zero
        # while scoring spikes.
        _validate_collective_config(config)
        resolved = config

    ordered_events = sorted(events, key=lambda item: item.timestamp_ns)
    marker_timestamps_by_rank = _collect_marker_timestamps_by_rank(ordered_events)
    grouped_samples = _group_sample_events_by_rank(ordered_events)
    if not grouped_samples:
        return []

    spikes_by_rank: dict[int, list[_RankSpike]] = {}
    for rank, rank_samples in grouped_samples.items():
        spikes = _detect_rank_spikes(
            rank=rank,
            rank_events=rank_samples,
            marker_timestamps=marker_timestamps_by_rank.get(rank, ()),
            config=resolved,
        )
        if spikes:
            spikes_by_rank[rank] = spikes
    if not spikes_by_rank:
        return []

    expected_world_size = _expected_world_size(ordered_events)
    # Reported intervals are never allowed to start before the first event.
    trace_start_ns = max(0, ordered_events[0].timestamp_ns)
    synchrony_by_spike = _build_synchrony_lookup(
        spikes_by_rank, resolved.synchrony_window_ns
    )

    results: list[CollectiveAttributionResult] = []
    for rank, spikes in spikes_by_rank.items():
        for spike in spikes:
            synchronized_ranks = synchrony_by_spike.get(spike.key, {rank})
            scored = _score_spike(
                spike=spike,
                synchronized_ranks=synchronized_ranks,
                expected_world_size=expected_world_size,
                trace_start_ns=trace_start_ns,
                config=resolved,
                phase_resolver=phase_resolver,
            )
            if scored is not None and scored.confidence >= resolved.min_confidence:
                results.append(scored)
    return _merge_rank_intervals(results)
def _validate_collective_config(config: CollectiveAttributionConfig) -> None:
    """Raise ``ValueError`` if any threshold in *config* is outside its valid range."""
    if config.min_samples_per_rank < 3:
        raise ValueError("min_samples_per_rank must be >= 3")
    if config.min_gap_bytes < 0:
        raise ValueError("min_gap_bytes must be >= 0")
    if config.min_gap_ratio < 0:
        raise ValueError("min_gap_ratio must be >= 0")
    if config.robust_zscore_threshold <= 0:
        raise ValueError("robust_zscore_threshold must be > 0")
    if config.marker_window_ns < 0:
        raise ValueError("marker_window_ns must be >= 0")
    if config.interval_padding_ns < 0:
        raise ValueError("interval_padding_ns must be >= 0")
    if config.synchrony_window_ns < 0:
        raise ValueError("synchrony_window_ns must be >= 0")
    if not 0 <= config.min_synchrony_ratio <= 1:
        raise ValueError("min_synchrony_ratio must be in [0, 1]")
    if not 0 <= config.min_confidence <= 1:
        raise ValueError("min_confidence must be in [0, 1]")


def _group_sample_events_by_rank(
    events: Iterable[TelemetryEventV2],
) -> dict[int, list[TelemetryEventV2]]:
    """Group ``sample`` events by rank, preserving input order within each rank."""
    grouped: dict[int, list[TelemetryEventV2]] = {}
    for event in events:
        if str(event.event_type).strip().lower() != "sample":
            continue
        grouped.setdefault(event.rank, []).append(event)
    return grouped


def _collect_marker_timestamps_by_rank(
    events: Iterable[TelemetryEventV2],
) -> dict[int, tuple[int, ...]]:
    """Return sorted timestamps of collective-marker events, keyed by rank."""
    grouped: dict[int, list[int]] = {}
    for event in events:
        if _event_has_collective_marker(event):
            grouped.setdefault(event.rank, []).append(event.timestamp_ns)
    return {rank: tuple(sorted(values)) for rank, values in grouped.items()}


def _detect_rank_spikes(
    *,
    rank: int,
    rank_events: Sequence[TelemetryEventV2],
    marker_timestamps: Sequence[int],
    config: CollectiveAttributionConfig,
) -> list[_RankSpike]:
    """Find samples on one rank whose hidden-memory gap is a statistical outlier.

    A sample's gap is ``device_used_bytes - allocator_reserved_bytes`` clamped
    at zero. The sample becomes a spike when the gap passes the absolute or
    relative significance test *and* its robust z-score against the rank's
    whole gap series reaches ``config.robust_zscore_threshold``.

    Returns an empty list when the rank has fewer than
    ``config.min_samples_per_rank`` samples.
    """
    if len(rank_events) < config.min_samples_per_rank:
        return []
    gaps = [
        max(0, event.device_used_bytes - event.allocator_reserved_bytes)
        for event in rank_events
    ]
    positive_gaps = np.asarray(gaps, dtype=float)
    if positive_gaps.size < config.min_samples_per_rank:
        return []
    # The z-score baseline depends only on the full series, so compute it once
    # rather than re-deriving median/MAD/std for every candidate (was O(n^2)).
    baseline = _robust_zscore_baseline(positive_gaps)

    spikes: list[_RankSpike] = []
    for sample_index, (event, gap_bytes) in enumerate(zip(rank_events, gaps)):
        gap_ratio = _compute_gap_ratio(event, gap_bytes)
        if not _is_significant_gap(
            gap_bytes=gap_bytes,
            gap_ratio=gap_ratio,
            config=config,
        ):
            continue
        robust_zscore = _robust_zscore_from_baseline(baseline, float(gap_bytes))
        if robust_zscore < config.robust_zscore_threshold:
            continue
        nearby_markers = tuple(
            ts
            for ts in marker_timestamps
            if abs(ts - event.timestamp_ns) <= config.marker_window_ns
        )
        spikes.append(
            _RankSpike(
                key=(rank, event.timestamp_ns, sample_index),
                rank=rank,
                session_id=getattr(event, "session_id", None),
                timestamp_ns=event.timestamp_ns,
                peak_gap_bytes=int(gap_bytes),
                peak_gap_ratio=gap_ratio,
                robust_zscore=round(float(robust_zscore), 4),
                marker_times=nearby_markers,
            )
        )
    return spikes


def _compute_gap_ratio(event: TelemetryEventV2, gap_bytes: int) -> float | None:
    """Return the gap as a fraction of device total memory, or None if unknown."""
    if event.device_total_bytes is None or event.device_total_bytes <= 0:
        return None
    return abs(gap_bytes) / event.device_total_bytes


def _is_significant_gap(
    *,
    gap_bytes: int,
    gap_ratio: float | None,
    config: CollectiveAttributionConfig,
) -> bool:
    """Return True when either the byte or the ratio threshold is met."""
    ratio_significant = gap_ratio is not None and gap_ratio >= config.min_gap_ratio
    bytes_significant = gap_bytes >= config.min_gap_bytes
    return ratio_significant or bytes_significant


def _robust_zscore_baseline(
    values: np.ndarray,
) -> tuple[float, float, float] | None:
    """Return ``(center, scale, multiplier)`` for z-scoring against *values*.

    Uses the median/MAD modified z-score (multiplier 0.6745) when the MAD is
    positive, falls back to mean / sample standard deviation otherwise, and
    returns ``None`` for fewer than three values or a series with no spread.
    """
    if values.size < 3:
        return None
    median = float(np.median(values))
    mad = float(np.median(np.abs(values - median)))
    if mad > 0:
        return median, mad, 0.6745
    std = float(np.std(values, ddof=1))
    if std > 0:
        return float(np.mean(values)), std, 1.0
    return None


def _robust_zscore_from_baseline(
    baseline: tuple[float, float, float] | None,
    value: float,
) -> float:
    """Score *value* against a precomputed baseline; negative scores clamp to 0."""
    if baseline is None:
        return 0.0
    center, scale, multiplier = baseline
    return max(0.0, multiplier * (value - center) / scale)


def _robust_zscore(values: np.ndarray, value: float) -> float:
    """Return the non-negative robust z-score of *value* relative to *values*."""
    return _robust_zscore_from_baseline(_robust_zscore_baseline(values), value)


def _expected_world_size(events: Sequence[TelemetryEventV2]) -> int:
    """Largest positive reported world size, else the number of distinct ranks."""
    world_sizes = [event.world_size for event in events if event.world_size > 0]
    if world_sizes:
        return max(world_sizes)
    ranks = {event.rank for event in events}
    return max(len(ranks), 1)


def _build_synchrony_lookup(
    spikes_by_rank: Mapping[int, Sequence[_RankSpike]],
    synchrony_window_ns: int,
) -> dict[tuple[int, int, int], set[int]]:
    """Map each spike key to the ranks with a spike within the synchrony window.

    Spikes are bucketed by ``timestamp // window`` so only the neighbouring
    buckets need to be scanned for each spike.
    """
    all_spikes = [spike for spikes in spikes_by_rank.values() for spike in spikes]
    if not all_spikes:
        return {}
    bucket_size = max(1, synchrony_window_ns)
    buckets: dict[int, list[_RankSpike]] = {}
    for spike in all_spikes:
        bucket = spike.timestamp_ns // bucket_size
        buckets.setdefault(bucket, []).append(spike)
    lookup: dict[tuple[int, int, int], set[int]] = {}
    for spike in all_spikes:
        bucket = spike.timestamp_ns // bucket_size
        synchronized = {
            other.rank
            for neighbor_bucket in (bucket - 1, bucket, bucket + 1)
            for other in buckets.get(neighbor_bucket, [])
            if abs(other.timestamp_ns - spike.timestamp_ns) <= synchrony_window_ns
        }
        if not synchronized:
            synchronized = {spike.rank}
        lookup[spike.key] = synchronized
    return lookup


def _score_spike(
    *,
    spike: _RankSpike,
    synchronized_ranks: set[int],
    expected_world_size: int,
    trace_start_ns: int,
    config: CollectiveAttributionConfig,
    phase_resolver: PhaseReplayIndex | None,
) -> CollectiveAttributionResult | None:
    """Turn one spike into a scored result, or ``None`` if confidence clamps to 0.

    Confidence blends 0.34 x marker overlap, 0.34 x synchrony ratio,
    0.16 x z-score strength and 0.16 x divergence strength, minus penalties
    for no markers (0.15), single-rank evidence (0.15) and an unknown device
    total (0.05), then clamps to [0, 1].
    """
    marker_overlap = bool(spike.marker_times)
    synchronized_count = max(len(synchronized_ranks), 1)
    if expected_world_size <= 1:
        synchrony_ratio = 0.0
    else:
        synchrony_ratio = min(
            1.0, (synchronized_count - 1) / (expected_world_size - 1)
        )
    zscore_strength = min(1.0, spike.robust_zscore / config.robust_zscore_threshold)
    bytes_strength = (
        min(1.0, spike.peak_gap_bytes / config.min_gap_bytes)
        if config.min_gap_bytes > 0
        else 1.0
    )
    ratio_strength = (
        min(1.0, (spike.peak_gap_ratio or 0.0) / config.min_gap_ratio)
        if config.min_gap_ratio > 0
        else 1.0
    )
    divergence_strength = max(bytes_strength, ratio_strength)

    confidence = 0.0
    confidence += 0.34 if marker_overlap else 0.0
    confidence += 0.34 * synchrony_ratio
    confidence += 0.16 * zscore_strength
    confidence += 0.16 * divergence_strength

    reason_codes: list[str] = ["gap_spike_statistical_outlier"]
    if marker_overlap:
        reason_codes.extend(["marker_collective_token", "marker_spike_overlap"])
    else:
        confidence -= 0.15
        reason_codes.append("weak_marker_signal")
    if synchrony_ratio >= config.min_synchrony_ratio:
        reason_codes.append("cross_rank_synchrony")
    if divergence_strength >= 0.5:
        reason_codes.append("allocator_device_divergence")
    if expected_world_size <= 1 or synchronized_count <= 1:
        confidence -= 0.15
        reason_codes.append("single_rank_only")
    if spike.peak_gap_ratio is None:
        confidence -= 0.05

    confidence = max(0.0, min(1.0, confidence))
    if confidence <= 0:
        return None

    # The interval covers the spike and every nearby marker, padded on both
    # sides and never starting before the trace (or before 0).
    marker_start = min(spike.marker_times) if spike.marker_times else spike.timestamp_ns
    marker_end = max(spike.marker_times) if spike.marker_times else spike.timestamp_ns
    unclamped_start = min(marker_start, spike.timestamp_ns) - config.interval_padding_ns
    interval_start = max(0, trace_start_ns, unclamped_start)
    interval_end = max(marker_end, spike.timestamp_ns) + config.interval_padding_ns

    evidence = CollectiveAttributionEvidence(
        marker_hits=len(spike.marker_times),
        synchronized_ranks=synchronized_count,
        expected_world_size=expected_world_size,
        synchrony_ratio=round(float(synchrony_ratio), 4),
        peak_gap_bytes=spike.peak_gap_bytes,
        peak_gap_ratio=(
            round(float(spike.peak_gap_ratio), 6)
            if spike.peak_gap_ratio is not None
            else None
        ),
        robust_zscore=round(float(spike.robust_zscore), 4),
    )

    if confidence >= 0.8:
        classification = "collective_confident"
    elif confidence >= 0.6:
        classification = "collective_likely"
    else:
        classification = "collective_suspect"

    phase_attribution = None
    if (
        phase_resolver is not None
        and spike.session_id is not None
        and hasattr(phase_resolver, "resolve")
    ):
        phase_attribution = phase_resolver.resolve(
            timestamp_ns=spike.timestamp_ns,
            session_id=spike.session_id,
            rank=spike.rank,
        )

    return CollectiveAttributionResult(
        rank=spike.rank,
        interval_start_ns=interval_start,
        interval_end_ns=interval_end,
        classification=classification,
        confidence=round(float(confidence), 3),
        reason_codes=sorted(set(reason_codes)),
        evidence=evidence,
        phase_attribution=phase_attribution,
    )


def _merge_rank_intervals(
    results: Sequence[CollectiveAttributionResult],
) -> list[CollectiveAttributionResult]:
    """Merge overlapping intervals on the same rank.

    Output is sorted by ``(rank, interval_start_ns)``. Merged results take
    the union interval, the maximum of each numeric evidence field and of
    confidence, the stronger classification, and the union of reason codes.
    """
    if not results:
        return []
    merged: list[CollectiveAttributionResult] = []
    for result in sorted(results, key=lambda item: (item.rank, item.interval_start_ns)):
        if not merged:
            merged.append(result)
            continue
        prev = merged[-1]
        same_rank = prev.rank == result.rank
        overlaps = result.interval_start_ns <= prev.interval_end_ns
        if not same_rank or not overlaps:
            merged.append(result)
            continue

        prev_evidence = prev.evidence
        curr_evidence = result.evidence
        if prev_evidence is None:
            merged_evidence = curr_evidence
        elif curr_evidence is None:
            merged_evidence = prev_evidence
        else:
            merged_evidence = CollectiveAttributionEvidence(
                marker_hits=max(prev_evidence.marker_hits, curr_evidence.marker_hits),
                synchronized_ranks=max(
                    prev_evidence.synchronized_ranks, curr_evidence.synchronized_ranks
                ),
                expected_world_size=max(
                    prev_evidence.expected_world_size, curr_evidence.expected_world_size
                ),
                synchrony_ratio=max(
                    prev_evidence.synchrony_ratio, curr_evidence.synchrony_ratio
                ),
                peak_gap_bytes=max(
                    prev_evidence.peak_gap_bytes, curr_evidence.peak_gap_bytes
                ),
                peak_gap_ratio=_max_optional_ratio(
                    prev_evidence.peak_gap_ratio,
                    curr_evidence.peak_gap_ratio,
                ),
                robust_zscore=max(
                    prev_evidence.robust_zscore, curr_evidence.robust_zscore
                ),
            )

        merged[-1] = CollectiveAttributionResult(
            rank=prev.rank,
            interval_start_ns=min(prev.interval_start_ns, result.interval_start_ns),
            interval_end_ns=max(prev.interval_end_ns, result.interval_end_ns),
            classification=_merge_classification(
                prev.classification, result.classification
            ),
            confidence=max(prev.confidence, result.confidence),
            reason_codes=sorted(set(prev.reason_codes + result.reason_codes)),
            evidence=merged_evidence,
            phase_attribution=merge_phase_attributions(
                prev.phase_attribution,
                result.phase_attribution,
            ),
        )
    return merged


def _max_optional_ratio(first: float | None, second: float | None) -> float | None:
    """Return the larger of two optional ratios, ignoring ``None`` values."""
    if first is None:
        return second
    if second is None:
        return first
    return max(first, second)


def _merge_classification(first: str, second: str) -> str:
    """Return the stronger of two classification labels (first wins ties)."""
    order = {
        "collective_suspect": 0,
        "collective_likely": 1,
        "collective_confident": 2,
    }
    return first if order.get(first, -1) >= order.get(second, -1) else second


def _event_has_collective_marker(event: TelemetryEventV2) -> bool:
    """Return True if the event type, context or metadata strings name a collective."""
    event_type = str(getattr(event, "event_type", ""))
    context = getattr(event, "context", None)
    metadata = getattr(event, "metadata", {})
    text_fragments: list[str] = [event_type]
    if isinstance(context, str) and context:
        text_fragments.append(context)
    text_fragments.extend(_iter_string_values(metadata))
    for fragment in text_fragments:
        if _contains_collective_token(fragment):
            return True
    return False


def _iter_string_values(value: Any) -> Iterable[str]:
    """Yield every string reachable through nested mapping values and sequences.

    Mapping keys are not yielded, only values.
    """
    if isinstance(value, str):
        yield value
        return
    if isinstance(value, Mapping):
        for nested in value.values():
            yield from _iter_string_values(nested)
        return
    if isinstance(value, (list, tuple, set)):
        for item in value:
            yield from _iter_string_values(item)


def _contains_collective_token(text: str) -> bool:
    """Case-insensitive token match that treats ``-``/``_`` separators loosely."""
    lowered = text.strip().lower()
    if not lowered:
        return False
    normalized = lowered.replace("-", "_")
    collapsed = normalized.replace("_", "")
    return any(
        token in normalized or token.replace("_", "") in collapsed
        for token in _COLLECTIVE_TOKENS
    )


__all__ = [
    "CollectiveAttributionConfig",
    "CollectiveAttributionEvidence",
    "CollectiveAttributionResult",
    "attribute_collective_memory",
    "resolve_collective_attribution_config",
]