Source code for stormlog.jax.pprof_parser

import gzip
from pathlib import Path
from typing import Any, Dict, List

# profile_pb2 is imported relative to this package. Judging by the command in
# the error message below, it is presumably the module protoc generates from
# pprof's profile.proto -- confirm against the package layout.
# If the import fails, re-raise at import time with instructions for producing
# it; `from None` suppresses the original ImportError context so only these
# instructions appear in the traceback.
try:
    from . import profile_pb2
except ImportError:
    raise ImportError(
        "Could not import profile_pb2. Please run: \n"
        "curl -sO https://raw.githubusercontent.com/google/pprof/master/proto/profile.proto && "
        "python -m grpc_tools.protoc -I. --python_out=. profile.proto"
    ) from None


def parse_jax_memory_profile(file_path: str) -> Dict[str, Any]:
    """Parse a JAX .prof (gzipped pprof protobuf) using the official protobuf schema.

    Args:
        file_path: Path to the gzip-compressed pprof profile.

    Returns:
        A dict with a single key ``"samples"`` holding a list with one entry
        per profile sample. Each entry is a dict with:

        * ``"stack"``: function names ordered root -> leaf. Names that cannot
          be resolved are reported as ``"<unknown>"``.
        * ``"values"``: the sample's values copied into a plain list.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        PermissionError: If ``file_path`` cannot be read.

    Other errors raised while decompressing or decoding the file are not
    handled here and propagate unchanged.
    """
    path = Path(file_path)
    try:
        with gzip.open(path, "rb") as f:
            data = f.read()
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"JAX memory profile not found: {path}") from exc
    except PermissionError as exc:
        raise PermissionError(f"JAX memory profile is not readable: {path}") from exc

    profile = profile_pb2.Profile()  # type: ignore
    profile.ParseFromString(data)

    string_table = profile.string_table
    table_size = len(string_table)

    # Map function_id -> function name.
    # The name field is an index into the string table. Bounds-check it so a
    # malformed profile yields "<unknown>" (the same placeholder used for
    # other unresolved ids below) rather than raising IndexError or, for a
    # negative index, silently picking an entry from the end of the table.
    functions: Dict[int, str] = {
        func.id: (
            string_table[func.name]
            if 0 <= func.name < table_size
            else "<unknown>"
        )
        for func in profile.function
    }

    # Map location_id -> [function_names], keeping the order of the
    # location's line entries.
    locations: Dict[int, List[str]] = {
        loc.id: [functions.get(line.function_id, "<unknown>") for line in loc.line]
        for loc in profile.location
    }

    # Flatten samples
    samples = []
    for sample in profile.sample:
        stack: List[str] = []
        for loc_id in sample.location_id:
            # The line entries in a location are innermost-first too, so a
            # single reversal of the concatenated list is enough.
            stack.extend(locations.get(loc_id, ["<unknown>"]))
        # pprof puts innermost call first, so reverse to get root->leaf stack
        stack.reverse()
        samples.append({"stack": stack, "values": list(sample.value)})

    return {"samples": samples}