Source code for stormlog.profiler

"""Core Stormlog for PyTorch."""

import gc
import logging
import threading
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import psutil
import torch

logger = logging.getLogger(__name__)


[docs] @dataclass class MemorySnapshot: """Represents a memory snapshot at a specific point in time.""" timestamp: float allocated_memory: int reserved_memory: int max_memory_allocated: int max_memory_reserved: int active_memory: int inactive_memory: int cpu_memory: int device_id: int = 0 operation: Optional[str] = None stack_trace: Optional[str] = None
[docs] def to_dict(self) -> Dict[str, Any]: """Convert snapshot to dictionary.""" return { "timestamp": self.timestamp, "allocated_memory": self.allocated_memory, "reserved_memory": self.reserved_memory, "max_memory_allocated": self.max_memory_allocated, "max_memory_reserved": self.max_memory_reserved, "active_memory": self.active_memory, "inactive_memory": self.inactive_memory, "cpu_memory": self.cpu_memory, "device_id": self.device_id, "operation": self.operation, "stack_trace": self.stack_trace, }
[docs] @dataclass class ProfileResult: """Results from profiling a function or operation.""" function_name: str execution_time: float memory_before: MemorySnapshot memory_after: MemorySnapshot memory_peak: MemorySnapshot memory_allocated: int memory_freed: int tensors_created: int tensors_deleted: int call_count: int = 1
[docs] def memory_diff(self) -> int: """Calculate memory difference between before and after.""" return self.memory_after.allocated_memory - self.memory_before.allocated_memory
[docs] def peak_memory_usage(self) -> int: """Get peak memory usage during execution.""" return self.memory_peak.allocated_memory
[docs] def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary.""" return { "function_name": self.function_name, "execution_time": self.execution_time, "memory_before": self.memory_before.to_dict(), "memory_after": self.memory_after.to_dict(), "memory_peak": self.memory_peak.to_dict(), "memory_allocated": self.memory_allocated, "memory_freed": self.memory_freed, "memory_diff": self.memory_diff(), "peak_memory_usage": self.peak_memory_usage(), "tensors_created": self.tensors_created, "tensors_deleted": self.tensors_deleted, "call_count": self.call_count, }
[docs] class GPUMemoryProfiler: """Comprehensive GPU memory profiler for PyTorch operations.""" def __init__( self, device: Union[str, int, torch.device, None] = None, track_tensors: bool = True, track_cpu_memory: bool = True, collect_stack_traces: bool = False, ): """ Initialize the Stormlog. Args: device: GPU device to profile (None for auto-detection) track_tensors: Whether to track tensor creation/deletion track_cpu_memory: Whether to track CPU memory usage collect_stack_traces: Whether to collect stack traces for operations """ self.device = self._setup_device(device) self.track_tensors = track_tensors self.track_cpu_memory = track_cpu_memory self.collect_stack_traces = collect_stack_traces self.results: List[ProfileResult] = [] self.snapshots: List[MemorySnapshot] = [] self.function_stats: Dict[str, List[ProfileResult]] = defaultdict(list) self._monitoring = False self._monitor_thread: Optional[threading.Thread] = None self._monitor_interval = 0.1 # 100ms self._tensor_tracker = TensorTracker() if track_tensors else None # Initialize baseline measurements self._baseline_snapshot = self._take_snapshot("baseline") def _setup_device( self, device: Union[str, int, torch.device, None] ) -> torch.device: """Setup and validate the device for profiling.""" resolved_device: torch.device if device is None: if torch.cuda.is_available(): resolved_device = torch.device(f"cuda:{torch.cuda.current_device()}") else: raise RuntimeError("CUDA is not available, cannot profile GPU memory") elif isinstance(device, int): resolved_device = torch.device(f"cuda:{device}") elif isinstance(device, str): resolved_device = torch.device(device) else: resolved_device = device if resolved_device.type != "cuda": raise ValueError("Only CUDA devices are supported for GPU memory profiling") # Ensure device is available device_index = ( resolved_device.index if resolved_device.index is not None else torch.cuda.current_device() ) if device_index >= torch.cuda.device_count(): raise ValueError(f"Device {resolved_device} is not available") return torch.device(f"cuda:{device_index}") def _take_snapshot(self, operation: Optional[str] = None) -> MemorySnapshot: """Take a memory snapshot at the current moment.""" torch.cuda.synchronize(self.device) snapshot = MemorySnapshot( timestamp=time.time(), allocated_memory=torch.cuda.memory_allocated(self.device), reserved_memory=torch.cuda.memory_reserved(self.device), max_memory_allocated=torch.cuda.max_memory_allocated(self.device), max_memory_reserved=torch.cuda.max_memory_reserved(self.device), active_memory=torch.cuda.memory_stats(self.device).get( "active_bytes.all.current", 0 ), inactive_memory=torch.cuda.memory_stats(self.device).get( "inactive_split_bytes.all.current", 0 ), cpu_memory=psutil.virtual_memory().used if self.track_cpu_memory else 0, device_id=self.device.index, operation=operation, ) if self.collect_stack_traces and operation: import traceback snapshot.stack_trace = "".join(traceback.format_stack()[-5:]) return snapshot
[docs] def profile_function( self, func: Callable[..., Any], *args: Any, **kwargs: Any ) -> ProfileResult: """ Profile a single function call. Args: func: Function to profile *args: Arguments to pass to function **kwargs: Keyword arguments to pass to function Returns: ProfileResult with profiling information """ function_name = getattr(func, "__name__", str(func)) # Reset peak memory stats torch.cuda.reset_peak_memory_stats(self.device) # Take before snapshot memory_before = self._take_snapshot(f"before_{function_name}") # Track tensors if enabled if self._tensor_tracker: tensors_before = self._tensor_tracker.count_tensors() # Execute function start_time = time.time() try: _result = func(*args, **kwargs) # Ensure all operations complete torch.cuda.synchronize(self.device) except Exception as exc: # Still capture memory state even if function fails logger.debug("Profiled function raised, capturing error snapshot: %s", exc) memory_after = self._take_snapshot(f"after_{function_name}_error") memory_peak = self._take_snapshot(f"peak_{function_name}_error") profile_result = ProfileResult( function_name=function_name, execution_time=time.time() - start_time, memory_before=memory_before, memory_after=memory_after, memory_peak=memory_peak, memory_allocated=0, memory_freed=0, tensors_created=0, tensors_deleted=0, ) self.results.append(profile_result) self.function_stats[function_name].append(profile_result) raise end_time = time.time() # Take after snapshot memory_after = self._take_snapshot(f"after_{function_name}") # Get peak memory usage memory_stats = torch.cuda.memory_stats(self.device) peak_allocated = memory_stats.get( "allocated_bytes.all.peak", memory_after.allocated_memory ) memory_peak = MemorySnapshot( timestamp=end_time, allocated_memory=peak_allocated, reserved_memory=memory_stats.get( "reserved_bytes.all.peak", memory_after.reserved_memory ), max_memory_allocated=torch.cuda.max_memory_allocated(self.device), max_memory_reserved=torch.cuda.max_memory_reserved(self.device), active_memory=memory_stats.get( "active_bytes.all.peak", memory_after.active_memory ), inactive_memory=memory_after.inactive_memory, cpu_memory=memory_after.cpu_memory, device_id=self.device.index, operation=f"peak_{function_name}", ) # Track tensor changes tensors_created = 0 tensors_deleted = 0 if self._tensor_tracker: tensors_after = self._tensor_tracker.count_tensors() tensors_created = max(0, tensors_after - tensors_before) tensors_deleted = max(0, tensors_before - tensors_after) # Create profile result profile_result = ProfileResult( function_name=function_name, execution_time=end_time - start_time, memory_before=memory_before, memory_after=memory_after, memory_peak=memory_peak, memory_allocated=max( 0, memory_after.allocated_memory - memory_before.allocated_memory ), memory_freed=max( 0, memory_before.allocated_memory - memory_after.allocated_memory ), tensors_created=tensors_created, tensors_deleted=tensors_deleted, ) # Store results self.results.append(profile_result) self.function_stats[function_name].append(profile_result) return profile_result
[docs] @contextmanager def profile_context(self, name: str = "context") -> Any: """ Context manager for profiling a block of code. Args: name: Name for the profiled context Yields: ProfileResult after the context exits """ torch.cuda.reset_peak_memory_stats(self.device) memory_before = self._take_snapshot(f"before_{name}") if self._tensor_tracker: tensors_before = self._tensor_tracker.count_tensors() start_time = time.time() try: yield torch.cuda.synchronize(self.device) finally: end_time = time.time() memory_after = self._take_snapshot(f"after_{name}") # Get peak memory memory_stats = torch.cuda.memory_stats(self.device) peak_allocated = memory_stats.get( "allocated_bytes.all.peak", memory_after.allocated_memory ) memory_peak = MemorySnapshot( timestamp=end_time, allocated_memory=peak_allocated, reserved_memory=memory_stats.get( "reserved_bytes.all.peak", memory_after.reserved_memory ), max_memory_allocated=torch.cuda.max_memory_allocated(self.device), max_memory_reserved=torch.cuda.max_memory_reserved(self.device), active_memory=memory_stats.get( "active_bytes.all.peak", memory_after.active_memory ), inactive_memory=memory_after.inactive_memory, cpu_memory=memory_after.cpu_memory, device_id=self.device.index, operation=f"peak_{name}", ) # Track tensors tensors_created = 0 tensors_deleted = 0 if self._tensor_tracker: tensors_after = self._tensor_tracker.count_tensors() tensors_created = max(0, tensors_after - tensors_before) tensors_deleted = max(0, tensors_before - tensors_after) profile_result = ProfileResult( function_name=name, execution_time=end_time - start_time, memory_before=memory_before, memory_after=memory_after, memory_peak=memory_peak, memory_allocated=max( 0, memory_after.allocated_memory - memory_before.allocated_memory ), memory_freed=max( 0, memory_before.allocated_memory - memory_after.allocated_memory ), tensors_created=tensors_created, tensors_deleted=tensors_deleted, ) self.results.append(profile_result) self.function_stats[name].append(profile_result)
[docs] def start_monitoring(self, interval: float = 0.1) -> None: """ Start continuous memory monitoring. Args: interval: Monitoring interval in seconds """ if self._monitoring: return self._monitoring = True self._monitor_interval = interval self._monitor_thread = threading.Thread(target=self._monitor_memory) self._monitor_thread.daemon = True self._monitor_thread.start()
[docs] def stop_monitoring(self) -> None: """Stop continuous memory monitoring.""" self._monitoring = False if self._monitor_thread: self._monitor_thread.join()
def _monitor_memory(self) -> None: """Background thread for continuous memory monitoring.""" while self._monitoring: snapshot = self._take_snapshot("monitor") self.snapshots.append(snapshot) time.sleep(self._monitor_interval)
[docs] def get_summary(self) -> Dict[str, Any]: """Get a summary of all profiling results.""" if not self.results: return {"message": "No profiling results available"} total_functions = len(self.function_stats) total_calls = len(self.results) # Aggregate statistics total_time = sum(r.execution_time for r in self.results) total_memory_allocated = sum(r.memory_allocated for r in self.results) total_memory_freed = sum(r.memory_freed for r in self.results) peak_memory = max(r.peak_memory_usage() for r in self.results) # Function statistics function_summaries = {} for func_name, results in self.function_stats.items(): function_summaries[func_name] = { "call_count": len(results), "total_time": sum(r.execution_time for r in results), "avg_time": sum(r.execution_time for r in results) / len(results), "total_memory_allocated": sum(r.memory_allocated for r in results), "avg_memory_allocated": sum(r.memory_allocated for r in results) / len(results), "peak_memory": max(r.peak_memory_usage() for r in results), } # Current memory state current_snapshot = self._take_snapshot("current") return { "device": str(self.device), "total_functions_profiled": total_functions, "total_function_calls": total_calls, "total_execution_time": total_time, "total_memory_allocated": total_memory_allocated, "total_memory_freed": total_memory_freed, "net_memory_change": total_memory_allocated - total_memory_freed, "peak_memory_usage": peak_memory, "current_memory_usage": current_snapshot.allocated_memory, "baseline_memory_usage": self._baseline_snapshot.allocated_memory, "memory_change_from_baseline": current_snapshot.allocated_memory - self._baseline_snapshot.allocated_memory, "function_summaries": function_summaries, "monitoring_active": self._monitoring, "snapshots_collected": len(self.snapshots), }
[docs] def clear_results(self) -> None: """Clear all profiling results and reset state.""" self.results.clear() self.snapshots.clear() self.function_stats.clear() torch.cuda.reset_peak_memory_stats(self.device) self._baseline_snapshot = self._take_snapshot("new_baseline")
def __enter__(self) -> "GPUMemoryProfiler": """Support for context manager usage.""" return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Cleanup when exiting context manager.""" self.stop_monitoring()
[docs] class TensorTracker: """Tracks tensor creation and deletion for memory profiling.""" def __init__(self) -> None: self._tensor_count = 0 self._setup_hooks() def _setup_hooks(self) -> None: """Setup hooks to track tensor lifecycle.""" # Note: This is a simplified version. Full implementation would require # more sophisticated tensor tracking using PyTorch's autograd hooks pass
[docs] def count_tensors(self) -> int: """Count current number of tracked tensors.""" # Simplified implementation - count all tensors in CUDA memory gc.collect() tensor_count = 0 for obj in gc.get_objects(): if isinstance(obj, torch.Tensor) and obj.is_cuda: tensor_count += 1 return tensor_count