Source code for stormlog.tui.workloads
"""Sample workload helpers and output formatters for the Textual TUI."""
from __future__ import annotations
from typing import Any, cast
from stormlog.utils import format_bytes
[docs]
def run_pytorch_sample_workload(profiler_cls: Any, torch_module: Any) -> dict[str, Any]:
profiler = profiler_cls()
def workload() -> Any:
x = torch_module.randn((3072, 3072), device="cuda")
y = torch_module.matmul(x, x)
return y.sum()
profiler.profile_function(workload)
return cast(dict[str, Any], profiler.get_summary())
[docs]
def run_tensorflow_sample_workload(profiler_cls: Any, tf_module: Any) -> Any:
profiler = profiler_cls()
with profiler.profile_context("tf_sample"):
tensor = tf_module.random.normal((2048, 2048))
product = tf_module.matmul(tensor, tensor)
tf_module.reduce_sum(product)
return profiler.get_results()
[docs]
def run_cpu_sample_workload(profiler_cls: Any) -> dict[str, Any]:
profiler = profiler_cls()
def workload() -> int:
data = [i for i in range(500000)]
return sum(data)
profiler.profile_function(workload)
return cast(dict[str, Any], profiler.get_summary())
[docs]
def format_pytorch_summary(summary: dict[str, Any]) -> str:
peak = summary.get("peak_memory_usage", 0)
delta = summary.get("memory_change_from_baseline", 0)
delta_sign = "-" if delta < 0 else ""
calls = summary.get("total_function_calls", "N/A")
lines = [
f"Functions profiled: {summary.get('total_functions_profiled', 'N/A')}",
f"Total calls: {calls}",
f"Peak memory: {format_bytes(peak)}",
f"Δ from baseline: {delta_sign}{format_bytes(abs(delta))}",
]
return "\n".join(lines)
[docs]
def format_tensorflow_results(results: Any) -> str:
duration = getattr(results, "duration", 0.0)
peak_memory_mb = getattr(results, "peak_memory_mb", 0.0)
average_memory_mb = getattr(results, "average_memory_mb", 0.0)
snapshots = getattr(results, "snapshots", [])
duration = 0.0 if duration is None else duration
peak_memory_mb = 0.0 if peak_memory_mb is None else peak_memory_mb
average_memory_mb = 0.0 if average_memory_mb is None else average_memory_mb
snapshots = [] if snapshots is None else snapshots
lines = [
f"Duration: {duration:.2f}s",
f"Peak memory: {peak_memory_mb:.2f} MB",
f"Average memory: {average_memory_mb:.2f} MB",
f"Snapshots: {len(snapshots)}",
]
return "\n".join(lines)
[docs]
def format_cpu_summary(summary: dict[str, Any]) -> str:
delta = summary.get("memory_change_from_baseline", 0)
delta_sign = "-" if delta < 0 else ""
lines = [
f"Snapshots collected: {summary.get('snapshots_collected', 0)}",
f"Peak RSS: {format_bytes(summary.get('peak_memory_usage', 0))}",
f"Δ from baseline: {delta_sign}{format_bytes(abs(delta))}",
]
return "\n".join(lines)