# Source code for stormlog.tensorflow.profiler

"""
Core TensorFlow Stormlog

Main profiling engine for capturing and analyzing GPU memory usage during
TensorFlow model training and inference.
"""

import logging
import threading
import time
import weakref
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    ParamSpec,
    TypeVar,
)

from .tf_env import configure_tensorflow_logging

configure_tensorflow_logging()

try:
    import tensorflow as tf

    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False
    tf = None

if TYPE_CHECKING:
    import tensorflow as tf

P = ParamSpec("P")
R = TypeVar("R")


@dataclass
class MemorySnapshot:
    """A single point-in-time memory reading.

    Negative ``gpu_memory_mb`` / ``cpu_memory_mb`` values are clamped to
    ``0.0`` after construction; every other field is stored as given
    (``gpu_memory_reserved_mb`` is *not* clamped).
    """

    timestamp: float
    name: str
    gpu_memory_mb: float
    cpu_memory_mb: float
    gpu_memory_reserved_mb: float
    gpu_utilization: float
    num_tensors: int
    # Free-form size breakdown keyed by label; each instance gets its own dict.
    tensor_sizes: Dict[str, int] = field(default_factory=dict)
    operation_name: Optional[str] = None
    graph_node: Optional[str] = None

    def __post_init__(self) -> None:
        """Clamp negative GPU/CPU memory readings to zero."""
        for attr_name in ("gpu_memory_mb", "cpu_memory_mb"):
            if getattr(self, attr_name) < 0:
                setattr(self, attr_name, 0.0)
@dataclass
class ProfileResult:
    """Aggregated results of a profiling session."""

    start_time: float
    end_time: float
    peak_memory_mb: float
    average_memory_mb: float
    min_memory_mb: float
    total_allocations: int
    total_deallocations: int
    snapshots: List[MemorySnapshot] = field(default_factory=list)
    function_profiles: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tensor_lifecycle: Dict[str, Any] = field(default_factory=dict)
    memory_fragmentation: float = 0.0
    efficiency_score: float = 0.0

    @property
    def duration(self) -> float:
        """Seconds elapsed between ``start_time`` and ``end_time``."""
        return self.end_time - self.start_time

    @property
    def memory_growth_rate(self) -> float:
        """Peak-minus-minimum memory divided by duration, in MB/second.

        Note: this is the *spread* of observed memory over the session, not
        net growth — it is non-negative even if memory ended lower than it
        started. Returns ``0.0`` when the duration is not positive.
        """
        elapsed = self.duration
        return 0.0 if elapsed <= 0 else (self.peak_memory_mb - self.min_memory_mb) / elapsed
class TensorTracker:
    """Tracks TensorFlow tensor lifecycle and memory usage.

    Tracked tensors are held only through ``weakref.ref`` objects keyed by
    ``id(tensor)``. A ``weakref.WeakSet`` cannot be used here: adding an
    object to a WeakSet hashes it, and TF2 eager tensors raise ``TypeError``
    ("Tensor is unhashable") from ``__hash__``. Entries for tensors that have
    been garbage-collected are pruned lazily. This keeps ``creation_times``
    and ``tensor_sizes`` bounded and stops a recycled ``id()`` from
    inheriting a dead tensor's size.
    """

    def __init__(self) -> None:
        # id(tensor) -> weak reference to it; never keeps a tensor alive.
        self._tensor_refs: Dict[int, "weakref.ReferenceType[Any]"] = {}
        # Append-only log of "created" events; entries outlive their tensors.
        self.tensor_history: List[Dict[str, Any]] = []
        # id(tensor) -> time.time() when tracked; pruned once the tensor dies.
        self.creation_times: Dict[int, float] = {}
        # id(tensor) -> estimated size in bytes; pruned once the tensor dies.
        self.tensor_sizes: Dict[int, int] = {}
        self._lock = threading.Lock()

    @property
    def tensors(self) -> List[Any]:
        """Tracked tensors that are still alive, as a new list.

        Stands in for the former ``WeakSet`` attribute: ``len()`` and
        iteration work the same way. The returned list holds strong
        references, so do not keep it around longer than needed.
        """
        with self._lock:
            self._prune_dead()
            candidates = [ref() for ref in self._tensor_refs.values()]
        return [t for t in candidates if t is not None]

    def _prune_dead(self) -> None:
        """Forget tensors whose weak reference has died. Caller must hold ``_lock``."""
        dead_ids = [tid for tid, ref in self._tensor_refs.items() if ref() is None]
        for tid in dead_ids:
            del self._tensor_refs[tid]
            self.creation_times.pop(tid, None)
            self.tensor_sizes.pop(tid, None)

    @staticmethod
    def _estimate_nbytes(tensor: Any) -> int:
        """Best-effort size of ``tensor`` in bytes.

        Prefers ``tensor.shape.num_elements() * tensor.dtype.size`` (the
        ``tf.TensorShape`` / ``tf.DType`` API), which needs only metadata.
        ``tensor.numpy()`` instead copies the data to a NumPy array. Falls
        back to ``tensor.numpy().nbytes`` when the shape is not fully known,
        and to 0 for objects offering neither.
        """
        num_elements = getattr(getattr(tensor, "shape", None), "num_elements", None)
        item_size = getattr(getattr(tensor, "dtype", None), "size", None)
        if callable(num_elements) and isinstance(item_size, int):
            count = num_elements()
            if count is not None:  # None means a partially-unknown shape
                return int(count) * item_size
        if hasattr(tensor, "numpy"):
            return tensor.numpy().nbytes
        return 0

    def track_tensor(
        self, tensor: "tf.Tensor", operation_name: str = "unknown"
    ) -> None:
        """Track a new tensor.

        Registers a weak reference to ``tensor`` and its creation time. When
        its size and shape can be read, also records the size and appends a
        ``"created"`` entry to ``tensor_history``. This is a no-op when
        TensorFlow is unavailable or ``tensor`` is None. Failures are logged
        as warnings rather than raised.

        Args:
            tensor: Tensor (or tensor-like object) to track; it must support
                weak references.
            operation_name: Label stored in the history entry.
        """
        if not TF_AVAILABLE or tensor is None:
            return

        with self._lock:
            tensor_id = id(tensor)
            try:
                self._tensor_refs[tensor_id] = weakref.ref(tensor)
            except TypeError as e:
                # Objects without weak-reference support cannot be tracked.
                logging.warning(f"Could not track tensor: {e}")
                return
            now = time.time()
            self.creation_times[tensor_id] = now
            # Drop any size left behind by a dead tensor that had this id.
            self.tensor_sizes.pop(tensor_id, None)

            try:
                size_bytes = self._estimate_nbytes(tensor)
                self.tensor_sizes[tensor_id] = size_bytes
                self.tensor_history.append(
                    {
                        "tensor_id": tensor_id,
                        "operation": operation_name,
                        "timestamp": now,
                        "action": "created",
                        "size_bytes": size_bytes,
                        "shape": (
                            tensor.shape.as_list() if hasattr(tensor, "shape") else []
                        ),
                    }
                )
            except Exception as e:
                logging.warning(f"Could not track tensor: {e}")

    def get_active_tensors(self) -> Dict[str, Any]:
        """Get information about currently active tensors.

        Returns:
            Dict with ``count`` (tracked tensors still alive), ``total_size_mb``
            and ``average_size_mb`` (bytes / 1024**2). Tensors whose size
            could not be estimated count as 0 bytes. The average is 0 when
            no tracked tensor is alive.
        """
        with self._lock:
            self._prune_dead()
            active_count = len(self._tensor_refs)
            total_size = sum(self.tensor_sizes.get(tid, 0) for tid in self._tensor_refs)
        return {
            "count": active_count,
            "total_size_mb": total_size / (1024 * 1024),
            "average_size_mb": (
                (total_size / active_count / (1024 * 1024)) if active_count > 0 else 0
            ),
        }

    def get_tensor_lifecycle(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the tensor history log."""
        with self._lock:
            return self.tensor_history.copy()
class TFMemoryProfiler:
    """Main TensorFlow Stormlog class.

    Captures memory snapshots for one TensorFlow device, aggregates
    per-function / per-block profiles, and optionally tracks individual
    tensors through a ``TensorTracker``.
    """

    def __init__(
        self, device: Optional[str] = None, enable_tensor_tracking: bool = True
    ) -> None:
        """
        Initialize TensorFlow memory profiler.

        Construction also enables TensorFlow memory growth on every visible
        GPU (see ``_setup_tf_memory``), which changes global TensorFlow GPU
        configuration.

        Args:
            device: TensorFlow device name (e.g., '/GPU:0', '/CPU:0'). When
                omitted, '/GPU:0' is used if TensorFlow lists a GPU, else
                '/CPU:0'.
            enable_tensor_tracking: Whether to track individual tensors

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if not TF_AVAILABLE:
            raise ImportError(
                "TensorFlow not available. Please install TensorFlow to use this profiler."
            )

        self.device = device or self._get_default_device()
        self.enable_tensor_tracking = enable_tensor_tracking

        # Initialize components
        self.tensor_tracker: Optional[TensorTracker] = (
            TensorTracker() if enable_tensor_tracking else None
        )
        self.snapshots: List[MemorySnapshot] = []
        self.function_profiles: Dict[str, Dict[str, Any]] = {}
        self.profiling_active = False
        self.profile_thread: Optional[threading.Thread] = None
        # Guards ``snapshots`` and ``function_profiles``, which the
        # continuous-profiling thread also writes to.
        self._lock = threading.Lock()

        # Setup TensorFlow memory growth
        self._setup_tf_memory()

        logging.info(f"TensorFlow Stormlog initialized for device: {self.device}")

    def _get_default_device(self) -> str:
        """Return '/GPU:0' if TensorFlow lists a physical GPU, else '/CPU:0'.

        Any error during detection is logged at debug level and treated as
        "no GPU".
        """
        try:
            gpus = tf.config.list_physical_devices("GPU")
            if gpus:
                return "/GPU:0"
            else:
                return "/CPU:0"
        except Exception as exc:
            logging.debug("TF device detection failed: %s", exc)
            return "/CPU:0"

    def _setup_tf_memory(self) -> None:
        """Setup TensorFlow memory growth to avoid OOM errors.

        Enables memory growth on every physical GPU. Failures (presumably
        e.g. GPUs already initialized — TF raises then) are logged as
        warnings, not raised.
        """
        try:
            gpus = tf.config.list_physical_devices("GPU")
            if gpus:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                    logging.info(f"Enabled memory growth for {gpu}")
        except Exception as e:
            logging.warning(f"Could not setup TensorFlow memory growth: {e}")

    def _get_memory_info(self) -> Dict[str, float]:
        """Get current memory usage information.

        The GPU and CPU readings are taken independently. A failure in one
        (for example ``psutil`` not being installed) no longer discards the
        other. Any value that cannot be read is reported as 0.0 and a warning
        is logged.

        Returns:
            Dict with:
            - ``gpu_memory_mb``: TF's ``"current"`` allocation for
              ``self.device`` in MB (bytes / 1024**2); 0.0 for non-GPU devices.
            - ``gpu_reserved_mb``: TF's ``"peak"`` allocation in MB. Despite
              the key name this is the peak, not allocator-reserved memory.
            - ``gpu_utilization``: current / peak * 100, capped at 100. This
              is a memory ratio, not compute utilization.
            - ``cpu_memory_mb``: resident set size of this process in MB.
        """
        bytes_per_mb = 1024 * 1024

        gpu_memory_mb = 0.0
        gpu_reserved_mb = 0.0
        gpu_utilization = 0.0
        if "/GPU:" in self.device:
            try:
                gpu_details = tf.config.experimental.get_memory_info(self.device)
                gpu_memory_mb = gpu_details.get("current", 0) / bytes_per_mb
                gpu_reserved_mb = gpu_details.get("peak", 0) / bytes_per_mb
                if gpu_reserved_mb > 0:
                    gpu_utilization = min(100.0, gpu_memory_mb / gpu_reserved_mb * 100)
            except Exception as e:
                logging.warning(f"Could not get memory info: {e}")
                gpu_memory_mb = gpu_reserved_mb = gpu_utilization = 0.0

        cpu_memory_mb = 0.0
        try:
            # Imported here so that a missing psutil only affects this reading.
            import psutil

            cpu_memory_mb = psutil.Process().memory_info().rss / bytes_per_mb
        except Exception as e:
            logging.warning(f"Could not get CPU memory info: {e}")

        return {
            "gpu_memory_mb": gpu_memory_mb,
            "cpu_memory_mb": cpu_memory_mb,
            "gpu_reserved_mb": gpu_reserved_mb,
            "gpu_utilization": gpu_utilization,
        }
[docs] def capture_snapshot(self, name: str = "snapshot") -> MemorySnapshot: """Capture current memory state.""" memory_info = self._get_memory_info() # Get tensor information num_tensors = 0 tensor_sizes = {} if self.tensor_tracker: active_tensors = self.tensor_tracker.get_active_tensors() num_tensors = active_tensors["count"] tensor_sizes = {"total_mb": active_tensors["total_size_mb"]} snapshot = MemorySnapshot( timestamp=time.time(), name=name, gpu_memory_mb=memory_info["gpu_memory_mb"], cpu_memory_mb=memory_info["cpu_memory_mb"], gpu_memory_reserved_mb=memory_info["gpu_reserved_mb"], gpu_utilization=memory_info["gpu_utilization"], num_tensors=num_tensors, tensor_sizes=tensor_sizes, ) with self._lock: self.snapshots.append(snapshot) return snapshot
[docs] def profile_function(self, func: Callable[P, R]) -> Callable[P, R]: """Decorator to profile function memory usage.""" @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: func_name = func.__name__ # Capture before before_snapshot = self.capture_snapshot(f"{func_name}_before") start_time = time.time() try: # Execute function result = func(*args, **kwargs) # Capture after end_time = time.time() after_snapshot = self.capture_snapshot(f"{func_name}_after") # Calculate metrics duration = end_time - start_time memory_used = ( after_snapshot.gpu_memory_mb - before_snapshot.gpu_memory_mb ) peak_memory = max( before_snapshot.gpu_memory_mb, after_snapshot.gpu_memory_mb ) # Store function profile with self._lock: if func_name not in self.function_profiles: self.function_profiles[func_name] = { "calls": 0, "total_duration": 0.0, "total_memory_used": 0.0, "peak_memory": 0.0, "snapshots": [], } profile = self.function_profiles[func_name] profile["calls"] += 1 profile["total_duration"] += duration profile["total_memory_used"] += memory_used profile["peak_memory"] = max(profile["peak_memory"], peak_memory) profile["snapshots"].extend([before_snapshot, after_snapshot]) return result except Exception as e: # Capture error state _error_snapshot = self.capture_snapshot(f"{func_name}_error") logging.error(f"Error in profiled function {func_name}: {e}") raise return wrapper
[docs] @contextmanager def profile_context(self, name: str = "context") -> Iterator[None]: """Context manager for profiling code blocks.""" before_snapshot = self.capture_snapshot(f"{name}_start") start_time = time.time() try: yield finally: end_time = time.time() after_snapshot = self.capture_snapshot(f"{name}_end") # Store context profile duration = end_time - start_time memory_used = after_snapshot.gpu_memory_mb - before_snapshot.gpu_memory_mb with self._lock: if name not in self.function_profiles: self.function_profiles[name] = { "calls": 0, "total_duration": 0.0, "total_memory_used": 0.0, "peak_memory": 0.0, "snapshots": [], } profile = self.function_profiles[name] profile["calls"] += 1 profile["total_duration"] += duration profile["total_memory_used"] += memory_used profile["peak_memory"] = max( profile["peak_memory"], max(before_snapshot.gpu_memory_mb, after_snapshot.gpu_memory_mb), ) profile["snapshots"].extend([before_snapshot, after_snapshot])
[docs] def start_continuous_profiling(self, interval: float = 1.0) -> None: """Start continuous memory profiling.""" self.profiling_active = True def profile_loop() -> None: while self.profiling_active: self.capture_snapshot("continuous") time.sleep(interval) self.profile_thread = threading.Thread(target=profile_loop, daemon=True) self.profile_thread.start() logging.info("Started continuous profiling")
[docs] def stop_continuous_profiling(self) -> None: """Stop continuous memory profiling.""" self.profiling_active = False if self.profile_thread: self.profile_thread.join(timeout=5.0) self.profile_thread = None logging.info("Stopped continuous profiling")
[docs] def get_results(self) -> ProfileResult: """Get comprehensive profiling results.""" with self._lock: if not self.snapshots: # Return empty results return ProfileResult( start_time=time.time(), end_time=time.time(), peak_memory_mb=0.0, average_memory_mb=0.0, min_memory_mb=0.0, total_allocations=0, total_deallocations=0, snapshots=[], function_profiles={}, ) # Calculate metrics from snapshots gpu_memories = [s.gpu_memory_mb for s in self.snapshots] peak_memory = max(gpu_memories) average_memory = sum(gpu_memories) / len(gpu_memories) min_memory = min(gpu_memories) # Estimate allocations/deallocations from memory changes total_allocations = sum( 1 for i in range(1, len(gpu_memories)) if gpu_memories[i] > gpu_memories[i - 1] ) total_deallocations = sum( 1 for i in range(1, len(gpu_memories)) if gpu_memories[i] < gpu_memories[i - 1] ) # Get tensor lifecycle if available tensor_lifecycle = {} if self.tensor_tracker: tensor_lifecycle = { "history": self.tensor_tracker.get_tensor_lifecycle(), "active": self.tensor_tracker.get_active_tensors(), } return ProfileResult( start_time=self.snapshots[0].timestamp, end_time=self.snapshots[-1].timestamp, peak_memory_mb=peak_memory, average_memory_mb=average_memory, min_memory_mb=min_memory, total_allocations=total_allocations, total_deallocations=total_deallocations, snapshots=self.snapshots.copy(), function_profiles=self.function_profiles.copy(), tensor_lifecycle=tensor_lifecycle, )
[docs] def reset(self) -> None: """Reset profiler state.""" with self._lock: self.snapshots.clear() self.function_profiles.clear() if self.tensor_tracker: self.tensor_tracker.tensor_history.clear() logging.info("Profiler state reset")
def __enter__(self) -> "TFMemoryProfiler": """Context manager entry.""" self.capture_snapshot("context_start") return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Context manager exit.""" self.capture_snapshot("context_end") if self.profiling_active: self.stop_continuous_profiling()