Source code for stormlog.visualizer

"""Visualization tools for GPU memory profiling data."""

from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Union

# Plotting imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Interactive plotting
import plotly.graph_objects as go
import seaborn as sns
from matplotlib.figure import Figure
from plotly.subplots import make_subplots

from .distributed_analysis import RankTimelinePoint, analyze_cross_rank_events
from .profiler import GPUMemoryProfiler, MemorySnapshot, ProfileResult
from .telemetry import TelemetryEventV2


class MemoryVisualizer:
    """Static (matplotlib) and interactive (plotly) views of memory profiling data.

    Every public plotting method accepts explicit data and, when an argument is
    ``None``, falls back to the ``results`` / ``snapshots`` held by the
    profiler given at construction time.
    """

    # Divisors used to convert raw byte / nanosecond values for display.
    _BYTES_PER_GB = 1024**3
    _NS_PER_SECOND = 1_000_000_000

    def __init__(self, profiler: Optional[GPUMemoryProfiler] = None):
        """
        Initialize the visualizer.

        Args:
            profiler: GPUMemoryProfiler instance whose ``results`` and
                ``snapshots`` are used when a plotting method is called
                without explicit data.
        """
        self.profiler = profiler
        self.style_config = {
            "figure_size": (12, 8),
            "dpi": 100,
            "color_palette": "viridis",
            "font_size": 10,
            "title_size": 14,
            "label_size": 12,
        }

        # NOTE: these calls configure matplotlib/seaborn globally, not just
        # the figures created through this instance.
        plt.style.use("default")
        sns.set_palette(self.style_config["color_palette"])

    def plot_memory_timeline(
        self,
        results: Optional[List[ProfileResult]] = None,
        snapshots: Optional[List[MemorySnapshot]] = None,
        save_path: Optional[str] = None,
        interactive: bool = True,
    ) -> Union[plt.Figure, go.Figure]:
        """
        Plot allocated and reserved memory over time.

        Snapshot points and the before/after points of each result are merged
        into a single series ordered by timestamp; the x axis is the offset
        from the earliest timestamp.

        Args:
            results: List of ProfileResults to plot (defaults to the
                profiler's results)
            snapshots: List of MemorySnapshots to plot (defaults to the
                profiler's snapshots)
            save_path: Path to save the plot
            interactive: Whether to create an interactive (plotly) plot
                instead of a static (matplotlib) one

        Returns:
            Matplotlib or Plotly figure

        Raises:
            ValueError: If there is no data to plot.
        """
        if results is None and self.profiler:
            results = self.profiler.results
        if snapshots is None and self.profiler:
            snapshots = self.profiler.snapshots

        if not results and not snapshots:
            raise ValueError("No data available for plotting")

        # Each point is (timestamp, allocated_bytes, reserved_bytes, label).
        points = [
            (
                snapshot.timestamp,
                snapshot.allocated_memory,
                snapshot.reserved_memory,
                snapshot.operation or "monitor",
            )
            for snapshot in snapshots or []
        ]
        for result in results or []:
            for prefix, state in (
                ("before", result.memory_before),
                ("after", result.memory_after),
            ):
                points.append(
                    (
                        state.timestamp,
                        state.allocated_memory,
                        state.reserved_memory,
                        f"{prefix}_{result.function_name}",
                    )
                )

        if not points:
            raise ValueError("No timestamp data available")

        # Snapshots and results are collected independently, so their
        # timestamps interleave. Without ordering, the line plot would double
        # back on itself. The sort is stable, so ties keep insertion order.
        points.sort(key=lambda point: point[0])

        start_time = points[0][0]
        relative_times = [point[0] - start_time for point in points]
        allocated_memory = [point[1] for point in points]
        reserved_memory = [point[2] for point in points]
        labels = [point[3] for point in points]

        if interactive:
            return self._create_interactive_timeline(
                relative_times, allocated_memory, reserved_memory, labels, save_path
            )
        return self._create_static_timeline(
            relative_times, allocated_memory, reserved_memory, labels, save_path
        )

    def plot_cross_rank_timeline(
        self,
        events: List[TelemetryEventV2],
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Plot a merged, aligned cross-rank device-memory timeline.

        One line is drawn per rank. When the analysis reports suspects, the
        first one is marked and annotated at its first-spike timestamp; when
        it reports a cluster onset timestamp, a dashed vertical line is drawn
        there.

        Args:
            events: Telemetry events passed to ``analyze_cross_rank_events``
            save_path: Path to save the plot

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If ``events`` is empty or the analysis yields no
                merged timeline points.
        """
        if not events:
            raise ValueError("No telemetry events available for cross-rank plotting")

        merge_result, first_cause_result = analyze_cross_rank_events(events)
        merged_points = merge_result.merged_points
        if not merged_points:
            raise ValueError("No timeline points available for cross-rank plotting")

        grouped_points: Dict[int, List[RankTimelinePoint]] = defaultdict(list)
        for point in merged_points:
            grouped_points[point.rank].append(point)

        first_aligned_timestamp = min(
            point.aligned_timestamp_ns for point in merged_points
        )

        def to_seconds(aligned_timestamp_ns):
            # Offset from the earliest aligned point, in seconds.
            return (
                aligned_timestamp_ns - first_aligned_timestamp
            ) / self._NS_PER_SECOND

        fig_obj, ax = plt.subplots(
            1,
            1,
            figsize=self.style_config["figure_size"],
            dpi=self.style_config["dpi"],
        )
        fig: Figure = fig_obj

        for rank in sorted(grouped_points):
            rank_points = sorted(
                grouped_points[rank],
                key=lambda point: (point.aligned_timestamp_ns, point.timestamp_ns),
            )
            ax.plot(
                [to_seconds(point.aligned_timestamp_ns) for point in rank_points],
                [point.device_used_bytes / self._BYTES_PER_GB for point in rank_points],
                linewidth=2,
                label=f"Rank {rank}",
            )

        top_suspect = (
            first_cause_result.suspects[0] if first_cause_result.suspects else None
        )
        if top_suspect is not None:
            spike_timestamp_ns = top_suspect.aligned_first_spike_timestamp_ns
            spike_relative_time = to_seconds(spike_timestamp_ns)
            # Mark only the first point of the suspect rank that sits exactly
            # on the spike timestamp.
            for point in grouped_points.get(top_suspect.rank, []):
                if point.aligned_timestamp_ns != spike_timestamp_ns:
                    continue
                spike_used_gb = point.device_used_bytes / self._BYTES_PER_GB
                ax.scatter(
                    [spike_relative_time],
                    [spike_used_gb],
                    color="crimson",
                    s=80,
                    zorder=5,
                    label=f"Top suspect rank {top_suspect.rank}",
                )
                ax.annotate(
                    f"Rank {top_suspect.rank} first cause",
                    xy=(spike_relative_time, spike_used_gb),
                    xytext=(10, 10),
                    textcoords="offset points",
                    fontsize=self.style_config["font_size"],
                    color="crimson",
                )
                break

        if first_cause_result.cluster_onset_timestamp_ns is not None:
            ax.axvline(
                to_seconds(first_cause_result.cluster_onset_timestamp_ns),
                color="black",
                linestyle="--",
                linewidth=1.5,
                label="Cluster onset",
            )

        title = "Cross-Rank Memory Timeline"
        if merge_result.job_id:
            title += f" ({merge_result.job_id})"
        if top_suspect is not None:
            title += f" - top suspect rank {top_suspect.rank}"

        ax.set_title(title, fontsize=self.style_config["title_size"])
        ax.set_xlabel(
            "Aligned Time (seconds)", fontsize=self.style_config["label_size"]
        )
        ax.set_ylabel(
            "Device Used Memory (GB)", fontsize=self.style_config["label_size"]
        )
        ax.grid(True, alpha=0.3)
        ax.legend()
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig

    def _create_static_timeline(
        self,
        times: List[float],
        allocated: List[int],
        reserved: List[int],
        labels: List[str],
        save_path: Optional[str],
    ) -> plt.Figure:
        """
        Create static matplotlib timeline plot.

        Args:
            times: X values (seconds relative to the first point)
            allocated: Allocated memory in bytes, one value per time
            reserved: Reserved memory in bytes, one value per time
            labels: Per-point labels; accepted for signature parity with the
                interactive variant but not drawn on the static plot
            save_path: Path to save the plot

        Returns:
            Matplotlib figure with two stacked panels sharing the x axis
        """
        fig_obj, (ax1, ax2) = plt.subplots(
            2,
            1,
            figsize=self.style_config["figure_size"],
            sharex=True,
            dpi=self.style_config["dpi"],
        )
        fig: Figure = fig_obj

        panels = (
            (ax1, allocated, "b-", "blue", "Allocated", "Allocated Memory (GB)"),
            (ax2, reserved, "r-", "red", "Reserved", "Reserved Memory (GB)"),
        )
        for ax, series, line_format, fill_color, legend_label, ylabel in panels:
            series_gb = [value / self._BYTES_PER_GB for value in series]
            ax.plot(times, series_gb, line_format, linewidth=2, label=legend_label)
            ax.fill_between(times, series_gb, alpha=0.3, color=fill_color)
            ax.set_ylabel(ylabel, fontsize=self.style_config["label_size"])
            ax.grid(True, alpha=0.3)
            ax.legend()

        ax1.set_title(
            "GPU Memory Usage Over Time", fontsize=self.style_config["title_size"]
        )
        ax2.set_xlabel("Time (seconds)", fontsize=self.style_config["label_size"])

        # Operate on this figure explicitly rather than on pyplot's notion of
        # the "current" figure.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig

    def _create_interactive_timeline(
        self,
        times: List[float],
        allocated: List[int],
        reserved: List[int],
        labels: List[str],
        save_path: Optional[str],
    ) -> go.Figure:
        """
        Create interactive plotly timeline plot.

        Args:
            times: X values (seconds relative to the first point)
            allocated: Allocated memory in bytes, one value per time
            reserved: Reserved memory in bytes, one value per time
            labels: Per-point labels shown as "Operation" in the hover text
            save_path: Path to save the plot; a path ending in ``.html`` is
                written as HTML, anything else as a 1200x800 image

        Returns:
            Plotly figure with two stacked subplots sharing the x axis
        """
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            subplot_titles=("Allocated Memory", "Reserved Memory"),
            vertical_spacing=0.1,
        )

        traces = (
            (1, allocated, "Allocated", "blue"),
            (2, reserved, "Reserved", "red"),
        )
        for row, series, kind, color in traces:
            fig.add_trace(
                go.Scatter(
                    x=times,
                    # Bytes -> GB for readability.
                    y=[value / self._BYTES_PER_GB for value in series],
                    mode="lines+markers",
                    name=f"{kind} Memory",
                    line=dict(color=color, width=2),
                    # NOTE(review): each subplot holds a single trace, so
                    # "tonexty" presumably falls back to filling down to
                    # zero - confirm against the plotly docs.
                    fill="tonexty",
                    hovertemplate=(
                        "<b>Time:</b> %{x:.2f}s<br>"
                        f"<b>{kind}:</b> "
                        "%{y:.2f} GB<br>"
                        "<b>Operation:</b> %{text}<extra></extra>"
                    ),
                    text=labels,
                ),
                row=row,
                col=1,
            )

        fig.update_layout(
            title="GPU Memory Usage Timeline",
            showlegend=True,
            height=800,
            hovermode="closest",
        )
        fig.update_xaxes(title_text="Time (seconds)", row=2, col=1)
        fig.update_yaxes(title_text="Memory (GB)", row=1, col=1)
        fig.update_yaxes(title_text="Memory (GB)", row=2, col=1)

        self._save_plotly_figure(fig, save_path, width=1200, height=800)
        return fig

    @staticmethod
    def _save_plotly_figure(
        fig: go.Figure, save_path: Optional[str], width: int, height: int
    ) -> None:
        """Write a plotly figure to ``save_path``: HTML for ``.html``, else an image."""
        if not save_path:
            return
        if save_path.endswith(".html"):
            fig.write_html(save_path)
        else:
            fig.write_image(save_path, width=width, height=height)

    def plot_function_comparison(
        self,
        results: Optional[List[ProfileResult]] = None,
        metric: str = "memory_allocated",
        save_path: Optional[str] = None,
        interactive: bool = True,
    ) -> Union[plt.Figure, go.Figure]:
        """
        Compare one metric across profiled functions as a bar chart.

        Results are grouped by ``function_name`` (in first-seen order).
        ``memory_allocated`` and ``execution_time`` show the per-function
        mean; ``peak_memory`` shows the per-function maximum. Memory metrics
        are converted from bytes to GB.

        Args:
            results: List of ProfileResults to compare (defaults to the
                profiler's results)
            metric: Metric to compare ('memory_allocated', 'execution_time',
                'peak_memory')
            save_path: Path to save the plot
            interactive: Whether to create an interactive (plotly) plot
                instead of a static (matplotlib) one

        Returns:
            Matplotlib or Plotly figure

        Raises:
            ValueError: If there are no results or ``metric`` is unknown.
        """
        if results is None and self.profiler:
            results = self.profiler.results

        if not results:
            raise ValueError("No results available for comparison")

        # metric -> (per-result extractor, reducer, divisor, y label, title)
        metric_specs = {
            "memory_allocated": (
                lambda result: result.memory_allocated,
                np.mean,
                self._BYTES_PER_GB,
                "Average Memory Allocated (GB)",
                "Average Memory Allocation by Function",
            ),
            "execution_time": (
                lambda result: result.execution_time,
                np.mean,
                1,
                "Average Execution Time (seconds)",
                "Average Execution Time by Function",
            ),
            "peak_memory": (
                lambda result: result.peak_memory_usage(),
                np.max,
                self._BYTES_PER_GB,
                "Peak Memory Usage (GB)",
                "Peak Memory Usage by Function",
            ),
        }
        if metric not in metric_specs:
            raise ValueError(f"Unknown metric: {metric}")
        extract, reduce_values, divisor, ylabel, title = metric_specs[metric]

        # Only the selected metric is aggregated; dict insertion order keeps
        # functions in the order they first appear in ``results``.
        samples: Dict[str, List[float]] = {}
        for result in results:
            samples.setdefault(result.function_name, []).append(
                float(extract(result))
            )

        functions = list(samples)
        values = [float(reduce_values(samples[name])) / divisor for name in functions]

        if interactive:
            return self._create_interactive_bar_chart(
                functions, values, ylabel, title, save_path
            )
        return self._create_static_bar_chart(
            functions, values, ylabel, title, save_path
        )

    def _create_static_bar_chart(
        self,
        labels: List[str],
        values: List[float],
        ylabel: str,
        title: str,
        save_path: Optional[str],
    ) -> plt.Figure:
        """
        Create static matplotlib bar chart.

        Args:
            labels: Category label for each bar
            values: Height of each bar
            ylabel: Y axis label
            title: Figure title
            save_path: Path to save the plot

        Returns:
            Matplotlib figure
        """
        fig_obj, ax = plt.subplots(
            figsize=self.style_config["figure_size"],
            dpi=self.style_config["dpi"],
        )
        fig: Figure = fig_obj

        bars = ax.bar(labels, values, alpha=0.8)

        ax.set_ylabel(ylabel, fontsize=self.style_config["label_size"])
        ax.set_title(title, fontsize=self.style_config["title_size"])
        ax.grid(True, alpha=0.3, axis="y")

        # Rotate x-axis labels if they're too long. ``default=0`` keeps an
        # empty chart from raising on ``max()``.
        if max((len(label) for label in labels), default=0) > 10:
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        # Print each bar's value just above it.
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height,
                f"{height:.2f}",
                ha="center",
                va="bottom",
                fontsize=self.style_config["font_size"],
            )

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig

    def _create_interactive_bar_chart(
        self,
        labels: List[str],
        values: List[float],
        ylabel: str,
        title: str,
        save_path: Optional[str],
    ) -> go.Figure:
        """
        Create interactive plotly bar chart.

        Args:
            labels: Category label for each bar
            values: Height of each bar
            ylabel: Y axis label (also used in the hover text)
            title: Figure title
            save_path: Path to save the plot; a path ending in ``.html`` is
                written as HTML, anything else as a 1000x600 image

        Returns:
            Plotly figure
        """
        bar_trace = go.Bar(
            x=labels,
            y=values,
            text=[f"{value:.2f}" for value in values],
            textposition="auto",
            hovertemplate="<b>%{x}</b><br>" + f"{ylabel}: %{{y:.2f}}<extra></extra>",
        )
        fig = go.Figure(data=[bar_trace])

        fig.update_layout(
            title=title, xaxis_title="Function", yaxis_title=ylabel, height=600
        )

        self._save_plotly_figure(fig, save_path, width=1000, height=600)
        return fig

    def plot_memory_heatmap(
        self,
        results: Optional[List[ProfileResult]] = None,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Create a heatmap showing memory usage patterns.

        Rows are function names (sorted); columns are the per-function mean
        of execution time, memory allocated, memory freed and peak memory.
        Cell colour is the value divided by its column maximum; cell text is
        the un-normalized value.

        Args:
            results: List of ProfileResults to analyze (defaults to the
                profiler's results)
            save_path: Path to save the plot

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If there are no results.
        """
        if results is None and self.profiler:
            results = self.profiler.results

        if not results:
            raise ValueError("No results available for heatmap")

        # Group once instead of re-scanning every result for every function.
        grouped: Dict[str, List[ProfileResult]] = defaultdict(list)
        for result in results:
            grouped[result.function_name].append(result)

        functions = sorted(grouped)
        metrics = ["execution_time", "memory_allocated", "memory_freed", "peak_memory"]

        data_matrix = np.zeros((len(functions), len(metrics)))
        for row, func in enumerate(functions):
            func_results = grouped[func]
            data_matrix[row, 0] = np.mean([r.execution_time for r in func_results])
            data_matrix[row, 1] = (
                np.mean([r.memory_allocated for r in func_results])
                / self._BYTES_PER_GB
            )
            data_matrix[row, 2] = (
                np.mean([r.memory_freed for r in func_results]) / self._BYTES_PER_GB
            )
            data_matrix[row, 3] = (
                np.mean([r.peak_memory_usage() for r in func_results])
                / self._BYTES_PER_GB
            )

        fig_obj, ax = plt.subplots(
            figsize=(10, max(6, len(functions) * 0.5)),
            dpi=self.style_config["dpi"],
        )
        fig: Figure = fig_obj

        # Scale each column by its own maximum so metrics with different
        # units share one colour scale. Columns whose maximum is not positive
        # are left at zero.
        column_max = data_matrix.max(axis=0)
        normalized_data = np.divide(
            data_matrix,
            column_max,
            out=np.zeros_like(data_matrix),
            where=column_max > 0,
        )

        im = ax.imshow(normalized_data, cmap="YlOrRd", aspect="auto")

        ax.set_xticks(np.arange(len(metrics)))
        ax.set_yticks(np.arange(len(functions)))
        ax.set_xticklabels(
            [
                "Execution Time",
                "Memory Allocated (GB)",
                "Memory Freed (GB)",
                "Peak Memory (GB)",
            ]
        )
        ax.set_yticklabels(functions)

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.set_ylabel("Normalized Value", rotation=-90, va="bottom")

        # Annotate every cell with its un-normalized value.
        for (row, col), value in np.ndenumerate(data_matrix):
            if metrics[col] == "execution_time":
                text = f"{value:.3f}s"
            else:
                text = f"{value:.2f}GB"
            ax.text(
                col, row, text, ha="center", va="center", color="black", fontsize=8
            )

        ax.set_title(
            "Memory Usage Heatmap by Function", fontsize=self.style_config["title_size"]
        )
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")

        return fig

    def create_dashboard(
        self,
        results: Optional[List[ProfileResult]] = None,
        snapshots: Optional[List[MemorySnapshot]] = None,
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """
        Create a comprehensive dashboard with multiple visualizations.

        The 2x2 grid holds: allocated memory over time (from snapshots),
        average allocated memory per function, a histogram of allocated
        memory, and execution time vs. peak memory (all from results). A
        panel whose data is missing is left empty; no error is raised.

        Args:
            results: List of ProfileResults (defaults to the profiler's
                results)
            snapshots: List of MemorySnapshots (defaults to the profiler's
                snapshots)
            save_path: Path to save the dashboard; a path ending in ``.html``
                is written as HTML, anything else as a 1400x800 image

        Returns:
            Plotly figure with subplots
        """
        if results is None and self.profiler:
            results = self.profiler.results
        if snapshots is None and self.profiler:
            snapshots = self.profiler.snapshots

        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=(
                "Memory Timeline",
                "Function Comparison",
                "Memory Distribution",
                "Peak Memory Usage",
            ),
            specs=[
                [{"secondary_y": True}, {"type": "bar"}],
                [{"type": "histogram"}, {"type": "scatter"}],
            ],
        )

        # Timeline plot (top left); time is relative to the first snapshot.
        if snapshots:
            first_timestamp = snapshots[0].timestamp
            times = [s.timestamp - first_timestamp for s in snapshots]
            allocated = [s.allocated_memory / self._BYTES_PER_GB for s in snapshots]
            fig.add_trace(
                go.Scatter(x=times, y=allocated, mode="lines", name="Allocated Memory"),
                row=1,
                col=1,
            )

        if results:
            # Function comparison (top right)
            func_memory: Dict[str, List[float]] = {}
            for result in results:
                func_memory.setdefault(result.function_name, []).append(
                    result.memory_allocated / self._BYTES_PER_GB
                )
            functions = list(func_memory)
            avg_memory = [np.mean(func_memory[name]) for name in functions]
            fig.add_trace(
                go.Bar(x=functions, y=avg_memory, name="Avg Memory"), row=1, col=2
            )

            # Memory distribution (bottom left)
            memory_values = [r.memory_allocated / self._BYTES_PER_GB for r in results]
            fig.add_trace(
                go.Histogram(x=memory_values, name="Memory Distribution"), row=2, col=1
            )

            # Peak memory scatter (bottom right)
            exec_times = [r.execution_time for r in results]
            peak_memory = [r.peak_memory_usage() / self._BYTES_PER_GB for r in results]
            fig.add_trace(
                go.Scatter(
                    x=exec_times,
                    y=peak_memory,
                    mode="markers",
                    name="Execution Time vs Peak Memory",
                ),
                row=2,
                col=2,
            )

        fig.update_layout(
            title_text="GPU Memory Profiling Dashboard", height=800, showlegend=True
        )

        fig.update_xaxes(title_text="Time (s)", row=1, col=1)
        fig.update_yaxes(title_text="Memory (GB)", row=1, col=1)
        fig.update_xaxes(title_text="Function", row=1, col=2)
        fig.update_yaxes(title_text="Avg Memory (GB)", row=1, col=2)
        fig.update_xaxes(title_text="Memory (GB)", row=2, col=1)
        fig.update_yaxes(title_text="Count", row=2, col=1)
        fig.update_xaxes(title_text="Execution Time (s)", row=2, col=2)
        fig.update_yaxes(title_text="Peak Memory (GB)", row=2, col=2)

        self._save_plotly_figure(fig, save_path, width=1400, height=800)
        return fig

    def export_data(
        self,
        results: Optional[List[ProfileResult]] = None,
        snapshots: Optional[List[MemorySnapshot]] = None,
        format: str = "csv",
        save_path: str = "memory_profile_data",
    ) -> str:
        """
        Export profiling data to CSV or JSON.

        For 'csv', results and snapshots are written to separate files,
        ``{save_path}_results_{timestamp}.csv`` and
        ``{save_path}_snapshots_{timestamp}.csv``; each is written only when
        the corresponding data is non-empty. For 'json', everything goes into
        the single file ``{save_path}_{timestamp}.json``.

        Args:
            results: List of ProfileResults to export (defaults to the
                profiler's results)
            snapshots: List of MemorySnapshots to export (defaults to the
                profiler's snapshots)
            format: Export format ('csv' or 'json')
            save_path: Base path for saved files; a timestamp and extension
                are appended

        Returns:
            For 'json', the path of the written file. For 'csv', the string
            ``{save_path}_{timestamp}.csv``, which identifies the export but
            is not itself written.

        Raises:
            ValueError: If ``format`` is neither 'csv' nor 'json'.
        """
        if results is None and self.profiler:
            results = self.profiler.results
        if snapshots is None and self.profiler:
            snapshots = self.profiler.snapshots

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if format == "csv":
            if results:
                results_df = pd.DataFrame(
                    [
                        {
                            "function_name": r.function_name,
                            "execution_time": r.execution_time,
                            "memory_allocated": r.memory_allocated,
                            "memory_freed": r.memory_freed,
                            "peak_memory": r.peak_memory_usage(),
                            "memory_diff": r.memory_diff(),
                            "tensors_created": r.tensors_created,
                            "tensors_deleted": r.tensors_deleted,
                        }
                        for r in results
                    ]
                )
                results_path = f"{save_path}_results_{timestamp}.csv"
                results_df.to_csv(results_path, index=False)

            if snapshots:
                snapshots_df = pd.DataFrame(
                    [
                        {
                            "timestamp": s.timestamp,
                            "operation": s.operation,
                            "allocated_memory": s.allocated_memory,
                            "reserved_memory": s.reserved_memory,
                            "active_memory": s.active_memory,
                            "inactive_memory": s.inactive_memory,
                            "device_id": s.device_id,
                        }
                        for s in snapshots
                    ]
                )
                snapshots_path = f"{save_path}_snapshots_{timestamp}.csv"
                snapshots_df.to_csv(snapshots_path, index=False)

            # NOTE(review): no file is ever written under this name. Kept
            # because existing callers may depend on the returned value;
            # consider returning the real paths in a future API revision.
            return f"{save_path}_{timestamp}.csv"

        if format == "json":
            import json

            export_data = {
                "metadata": {
                    "export_time": timestamp,
                    "num_results": len(results) if results else 0,
                    "num_snapshots": len(snapshots) if snapshots else 0,
                },
                "results": [r.to_dict() for r in results] if results else [],
                "snapshots": [s.to_dict() for s in snapshots] if snapshots else [],
            }

            json_path = f"{save_path}_{timestamp}.json"
            with open(json_path, "w", encoding="utf-8") as f:
                # default=str stringifies any value json cannot encode natively.
                json.dump(export_data, f, indent=2, default=str)

            return json_path

        raise ValueError(f"Unsupported format: {format}")

    def show(self, fig: Union[plt.Figure, go.Figure]) -> None:
        """
        Display a figure.

        Args:
            fig: A matplotlib figure (displayed via ``plt.show()``) or any
                other object exposing ``show()``, such as a plotly figure.
        """
        if isinstance(fig, plt.Figure):
            plt.show()
        else:
            fig.show()