# Source code for stormlog._wandb.core

"""Core helpers shared by Stormlog's optional W&B exporters."""

from __future__ import annotations

import json
import re
import tempfile
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from typing import Any, Mapping

from ..session import SessionSummary, session_summary_from_dict

# Message used for the ImportError raised when the optional ``wandb`` package
# itself cannot be imported; it tells the user which extra to install.
WANDB_INSTALL_GUIDANCE: str = (
    "Weights & Biases integration requires optional dependencies. "
    "Install with `pip install 'stormlog[wandb]'`."
)


@dataclass(frozen=True)
class WandbExportConfig:
    """Immutable settings describing whether and how to export to W&B.

    ``mode`` is validated at construction time: only ``"online"``,
    ``"offline"`` or ``None`` are accepted. Any other value raises
    ``ValueError``.
    """

    # Master switch for the optional W&B export.
    enabled: bool = False
    # Optional overrides forwarded to the W&B run; ``None`` means "not set".
    project: str | None = None
    entity: str | None = None
    mode: str | None = None
    run_id: str | None = None
    run_name: str | None = None
    group: str | None = None
    job_type: str | None = None
    # Toggles for which kinds of Stormlog output are logged.
    log_tables: bool = True
    log_artifacts: bool = False
    log_attribution: bool = False

    def __post_init__(self) -> None:
        """Reject unsupported ``mode`` values early."""
        accepted_modes = {"online", "offline"}
        if self.mode is not None and self.mode not in accepted_modes:
            raise ValueError("wandb mode must be 'online', 'offline', or omitted")
def wandb_config_from_namespace(args: Any) -> WandbExportConfig:
    """Create a :class:`WandbExportConfig` from an argparse-like namespace.

    Missing attributes fall back to defaults. Boolean flags are coerced with
    ``bool``. String options go through ``_normalized_optional_string``, so
    blank or non-string values end up as ``None``. ``log_tables`` is not read
    from *args* and keeps its dataclass default.
    """

    def text_option(attribute: str) -> str | None:
        return _normalized_optional_string(getattr(args, attribute, None))

    def flag_option(attribute: str) -> bool:
        return bool(getattr(args, attribute, False))

    return WandbExportConfig(
        enabled=flag_option("wandb"),
        project=text_option("wandb_project"),
        entity=text_option("wandb_entity"),
        mode=text_option("wandb_mode"),
        run_id=text_option("wandb_run_id"),
        run_name=text_option("wandb_name"),
        group=text_option("wandb_group"),
        job_type=text_option("wandb_job_type"),
        log_artifacts=flag_option("wandb_log_artifacts"),
        log_attribution=flag_option("wandb_log_attribution"),
    )
def add_wandb_arguments(parser: Any) -> None:
    """Register the ``--wandb*`` command-line options shared by Stormlog.

    Switches default to ``False`` and value options default to ``None``.
    ``--wandb-mode`` is restricted to ``online``/``offline``. *parser* must
    provide an argparse-style ``add_argument(flag, **options)`` method.
    """

    def text(help_text: str) -> dict[str, Any]:
        return {"type": str, "default": None, "help": help_text}

    def switch(help_text: str) -> dict[str, Any]:
        return {"action": "store_true", "help": help_text}

    # Registration order is preserved so ``--help`` output stays stable.
    option_table: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--wandb", switch("Log Stormlog summaries to Weights & Biases")),
        ("--wandb-project", text("W&B project name (default: stormlog)")),
        ("--wandb-entity", text("W&B entity or team name")),
        (
            "--wandb-mode",
            {
                "choices": ["online", "offline"],
                "default": None,
                "help": "W&B logging mode (default: online)",
            },
        ),
        ("--wandb-run-id", text("Existing W&B run id to resume or attach to")),
        ("--wandb-name", text("Explicit W&B run name")),
        (
            "--wandb-group",
            text("W&B group override (default: Stormlog job id)"),
        ),
        (
            "--wandb-job-type",
            text("W&B job type override (default: Stormlog command name)"),
        ),
        (
            "--wandb-log-artifacts",
            switch("Upload Stormlog output bundles as W&B artifacts"),
        ),
        (
            "--wandb-log-attribution",
            switch("Log attribution HTML and top offenders to W&B when available"),
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
def ensure_wandb_available(config: WandbExportConfig) -> None:
    """Fail fast if W&B export is enabled but ``wandb`` cannot be imported.

    Does nothing when ``config.enabled`` is false. Otherwise it delegates to
    ``import_wandb``, which raises ``ImportError`` with install guidance
    when the ``wandb`` package is missing.
    """
    if not config.enabled:
        return
    import_wandb()
def import_wandb() -> Any:
    """Import and return the optional ``wandb`` module.

    Raises:
        ImportError: carrying ``WANDB_INSTALL_GUIDANCE`` when the ``wandb``
            package itself is not installed.
        ModuleNotFoundError: re-raised unchanged when a *different* module
            is missing while importing ``wandb``.
    """
    try:
        return import_module("wandb")
    except ModuleNotFoundError as exc:
        # Only translate the error when wandb itself is absent; a missing
        # transitive dependency should surface under its real name.
        if exc.name == "wandb":
            raise ImportError(WANDB_INSTALL_GUIDANCE) from exc
        raise


def resolve_run(
    config: WandbExportConfig,
    *,
    command_name: str,
    session_summary: SessionSummary | None,
) -> tuple[Any, Any, bool]:
    """Return ``(wandb_module, run, created)`` for a Stormlog export.

    If ``wandb.run`` is already active it is reused as-is, ``config`` is
    ignored and ``created`` is ``False``. Otherwise a new run is started via
    ``wandb.init`` and ``created`` is ``True``. The new run uses these
    defaults:

    - project ``"stormlog"``;
    - a run name derived from ``command_name`` and the session id;
    - the session's job id as the group;
    - ``command_name`` as the job type.

    When ``config.run_id`` is set, it is passed as ``id`` with
    ``resume="allow"``. Options that resolve to ``None`` are omitted, so
    ``wandb.init`` applies its own defaults for them.
    """
    wandb = import_wandb()
    active_run = getattr(wandb, "run", None)
    if active_run is not None:
        return wandb, active_run, False

    init_kwargs: dict[str, Any] = {
        "project": config.project or "stormlog",
        "entity": config.entity,
        "mode": config.mode,
        "name": config.run_name or _default_run_name(command_name, session_summary),
        "group": config.group or default_group(session_summary),
        "job_type": config.job_type or command_name,
    }
    if config.run_id is not None:
        init_kwargs["id"] = config.run_id
        init_kwargs["resume"] = "allow"

    run = wandb.init(
        **{key: value for key, value in init_kwargs.items() if value is not None}
    )
    return wandb, run, True


def update_summary(run: Any, payload: Mapping[str, Any]) -> None:
    """Merge *payload* into ``run.summary``.

    This is a no-op when *payload* is empty or the run has no ``summary``.
    It uses ``summary.update`` when that method exists and falls back to
    per-key item assignment otherwise.
    """
    if not payload:
        return
    summary = getattr(run, "summary", None)
    if summary is None:
        return
    if hasattr(summary, "update"):
        summary.update(payload)
        return
    for key, value in payload.items():
        summary[key] = value


def session_summary_from_manifest(
    manifest: Mapping[str, Any] | None,
) -> SessionSummary | None:
    """Extract the ``"session"`` section of a manifest as a ``SessionSummary``.

    Returns ``None`` in three cases: *manifest* is not a mapping, its
    ``"session"`` entry is not a mapping, or ``session_summary_from_dict``
    rejects the entry with ``ValueError``.
    """
    if not isinstance(manifest, Mapping):
        return None
    session_payload = manifest.get("session")
    if not isinstance(session_payload, Mapping):
        return None
    try:
        return session_summary_from_dict(session_payload)
    except ValueError:
        return None


def session_summary_fields(summary: SessionSummary | None) -> dict[str, Any]:
    """Flatten session metadata into ``stormlog_*`` keys for a W&B summary.

    Returns an empty dict when *summary* is ``None``. ``stormlog_job_id``
    is included only when the session has a job id.
    """
    if summary is None:
        return {}
    fields: dict[str, Any] = {
        "stormlog_session_id": summary.session_id,
        "stormlog_session_source": summary.source,
        "stormlog_session_status": summary.status,
        "stormlog_rank": summary.rank,
        "stormlog_local_rank": summary.local_rank,
        "stormlog_world_size": summary.world_size,
    }
    if summary.job_id is not None:
        fields["stormlog_job_id"] = summary.job_id
    return fields


def default_group(summary: SessionSummary | None) -> str | None:
    """Return the session's job id for use as the W&B group, or ``None``."""
    if summary is None or summary.job_id is None:
        return None
    return summary.job_id


def session_slug(summary: SessionSummary | None) -> str:
    """Return a name-safe slug built from the session id.

    Each run of characters outside ``[a-zA-Z0-9._-]`` becomes ``-``. Falls
    back to ``"session"`` when there is no summary or the slug is empty.
    """
    raw = summary.session_id if summary is not None else "session"
    slug = re.sub(r"[^a-zA-Z0-9._-]+", "-", raw)
    return slug or "session"


def coerce_existing_file(value: str | Path | None) -> Path | None:
    """Return *value* as a ``Path`` if it names an existing regular file."""
    if value is None:
        return None
    path = Path(value)
    # is_file() already returns False for paths that do not exist.
    return path if path.is_file() else None


def coerce_existing_dir(value: str | Path | None) -> Path | None:
    """Return *value* as a ``Path`` if it names an existing directory."""
    if value is None:
        return None
    path = Path(value)
    # is_dir() already returns False for paths that do not exist.
    return path if path.is_dir() else None


def read_json_if_exists(path: Path) -> dict[str, Any] | None:
    """Best-effort load of a JSON object from *path*.

    Returns the parsed ``dict`` when the file holds a JSON object. Returns
    ``None`` when the file is missing or unreadable, is not valid UTF-8, is
    not valid JSON, or holds a non-object JSON value such as a list or a
    number.
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        # FileNotFoundError / IsADirectoryError are OSErrors. A binary or
        # non-UTF-8 file raises UnicodeDecodeError, which is neither an
        # OSError nor a JSONDecodeError and must be caught explicitly.
        return None
    return payload if isinstance(payload, dict) else None


def log_file_artifact(
    wandb: Any,
    run: Any,
    *,
    artifact_name: str,
    artifact_type: str,
    path: Path,
) -> None:
    """Log the single file *path* to *run* as an artifact.

    Inside the artifact the file is stored under its base name.
    """
    artifact = wandb.Artifact(name=artifact_name, type=artifact_type)
    artifact.add_file(local_path=str(path), name=path.name)
    run.log_artifact(artifact)


def log_directory_artifact(
    wandb: Any,
    run: Any,
    *,
    artifact_name: str,
    artifact_type: str,
    path: Path,
) -> None:
    """Log the directory *path* to *run* as an artifact.

    Inside the artifact the directory is stored under its base name.
    """
    artifact = wandb.Artifact(name=artifact_name, type=artifact_type)
    artifact.add_dir(local_path=str(path), name=path.name)
    run.log_artifact(artifact)


def materialize_html_file(
    *,
    html_text: str,
    file_name: str,
    output_root: Path | None,
) -> Path:
    """Write *html_text* to disk as UTF-8 and return the file's path.

    With *output_root* the file is written to ``output_root / file_name``,
    and *output_root* is created if needed. Without it, the file goes into
    a fresh ``tempfile.mkdtemp`` directory and is chmod-ed to ``0o600``.
    That temporary directory is not removed here.
    NOTE(review): confirm the caller cleans it up if that matters.
    """
    if output_root is not None:
        output_root.mkdir(parents=True, exist_ok=True)
        target_path = output_root / file_name
        target_path.write_text(html_text, encoding="utf-8")
        return target_path

    temp_dir = Path(tempfile.mkdtemp(prefix="stormlog-wandb-"))
    target_path = temp_dir / file_name
    target_path.write_text(html_text, encoding="utf-8")
    target_path.chmod(0o600)
    return target_path


def _default_run_name(
    command_name: str,
    session_summary: SessionSummary | None,
) -> str:
    """Return ``command_name``, suffixed with the first 8 chars of the session id."""
    if session_summary is None:
        return command_name
    return f"{command_name}-{session_summary.session_id[:8]}"


def _normalized_optional_string(value: Any) -> str | None:
    """Return *value* stripped, or ``None`` if it is not a non-blank string."""
    if not isinstance(value, str):
        return None
    stripped = value.strip()
    return stripped or None