stormlog.jax.utils
Utility functions for JAX memory profiling.
This module provides helper functions for JAX device discovery, memory formatting, system information, and environment validation.
Functions
detect_jax_backend() | Return the active JAX backend name.
format_memory(bytes_value) | Format memory size in human-readable format.
get_backend_info() | Return backend diagnostics for JAX.
get_device_info(device_index=0) | Return device kind, platform, and live memory statistics.
get_system_info() | Return full system and JAX environment report.
jax_is_available() | Return True when JAX is importable.
Validate JAX environment for memory profiling.
- stormlog.jax.utils.jax_is_available()
Return True when JAX is importable.
- Return type:
bool
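An importability check of this kind can be sketched with the standard library alone. The helper name `jax_importable` below is hypothetical and is not part of stormlog; the sketch only locates the package and does not prove that importing it will succeed.

```python
import importlib.util


def jax_importable() -> bool:
    """Return True when the `jax` package can be located by this interpreter.

    Hypothetical helper for illustration; not the stormlog implementation.
    """
    try:
        # find_spec locates the package without executing its import code.
        return importlib.util.find_spec("jax") is not None
    except (ImportError, ValueError):
        # Raised in unusual cases, e.g. a broken entry in sys.modules.
        return False
```

A caller would typically use the result to decide whether to skip JAX-specific profiling rather than raise.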
- stormlog.jax.utils.detect_jax_backend()
Return the active JAX backend name.
Returns one of ‘gpu’, ‘tpu’, or ‘cpu’. Returns ‘cpu’ as a fallback if JAX is not installed or backend detection fails.
- Return type:
str
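The fallback behaviour described above can be sketched as follows. This is an assumption about how such detection might look, built on `jax.default_backend()`; the function name `detect_backend_sketch` and the mapping of `cuda`/`rocm` to `gpu` are illustrative choices, not stormlog's code.

```python
def detect_backend_sketch() -> str:
    """Return 'gpu', 'tpu', or 'cpu', falling back to 'cpu' on any failure.

    Hypothetical sketch; the real detect_jax_backend() may differ.
    """
    try:
        import jax  # imported lazily so the function works without JAX

        platform = jax.default_backend().lower()
    except Exception:
        # JAX missing or backend initialisation failed: use the fallback.
        return "cpu"

    # Assumption: vendor-specific platform names are folded into 'gpu'.
    if platform in ("gpu", "cuda", "rocm"):
        return "gpu"
    if platform == "tpu":
        return "tpu"
    return "cpu"
```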
- stormlog.jax.utils.get_device_info(device_index=0)
Return device kind, platform, and live memory statistics.
- Parameters:
device_index (int) – Index into jax.local_devices() (default 0).
- Returns:
Dictionary with keys kind, platform, device_id, process_index, memory_stats (raw dict from device.memory_stats()), and client.
- Return type:
Dict[str, Any]
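As a rough illustration of where the documented keys could come from, the sketch below reads them from a JAX device object. The helper `describe_device` is hypothetical, and replacing a missing `memory_stats()` result with an empty dict is an assumption made here, since some backends return `None`.

```python
from typing import Any, Dict


def describe_device(device: Any) -> Dict[str, Any]:
    """Collect the documented keys from a JAX device-like object.

    Hypothetical sketch; not the stormlog implementation.
    """
    try:
        # memory_stats() may return None on backends that report nothing.
        stats = device.memory_stats()
    except Exception:
        stats = None
    return {
        "kind": device.device_kind,
        "platform": device.platform,
        "device_id": device.id,
        "process_index": device.process_index,
        "memory_stats": stats or {},  # assumption: normalise None to {}
        "client": getattr(device, "client", None),
    }
```

With JAX installed, this could be called as `describe_device(jax.local_devices()[0])`, which mirrors the default `device_index` of 0.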
- stormlog.jax.utils.get_backend_info()
Return backend diagnostics for JAX.
Returns a dictionary with the JAX runtime backend classification and platform details.
- Return type:
Dict[str, Any]
- stormlog.jax.utils.get_system_info()
Return full system and JAX environment report.
Includes JAX version, device list, platform, Python version, CPU count, and system memory statistics.
- Return type:
Dict[str, Any]
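A report along these lines can be approximated with the standard library plus an optional JAX import. The function name and dictionary keys below are assumptions for illustration only, and system memory statistics are left out because they would need a third-party package.

```python
import os
import platform
from importlib import metadata
from typing import Any, Dict


def system_report_sketch() -> Dict[str, Any]:
    """Gather Python, platform, CPU, and JAX details into one dictionary.

    Hypothetical sketch; key names are not taken from stormlog.
    """
    try:
        jax_version = metadata.version("jax")
    except metadata.PackageNotFoundError:
        jax_version = None  # JAX is not installed

    report: Dict[str, Any] = {
        "python_version": platform.python_version(),
        "platform": platform.platform(),
        "cpu_count": os.cpu_count(),
        "jax_version": jax_version,
        "devices": [],
    }
    try:
        import jax

        # String form of each device, e.g. for logging.
        report["devices"] = [str(d) for d in jax.devices()]
    except Exception:
        pass  # leave the device list empty when JAX is unusable
    return report
```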
- stormlog.jax.utils.format_memory(bytes_value)
Format memory size in human-readable format.
Delegates to stormlog.utils.format_bytes() when available; otherwise uses a standalone implementation.
- Parameters:
bytes_value (int | float | None)
- Return type:
str
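A standalone fallback formatter might look like the sketch below. The name `format_memory_sketch`, the 1024-based units, the two-decimal output, and the "N/A" result for `None` are all assumptions; the actual output format of format_memory() is not specified here.

```python
from typing import Optional, Union


def format_memory_sketch(bytes_value: Optional[Union[int, float]]) -> str:
    """Format a byte count using 1024-based units (hypothetical sketch)."""
    if bytes_value is None:
        return "N/A"  # assumption: placeholder for a missing value

    value = float(bytes_value)
    sign = "-" if value < 0 else ""
    value = abs(value)

    units = ("B", "KB", "MB", "GB", "TB")
    idx = 0
    # Step up one unit at a time until the value fits or units run out.
    while value >= 1024.0 and idx < len(units) - 1:
        value /= 1024.0
        idx += 1
    return f"{sign}{value:.2f} {units[idx]}"
```

Under these assumptions, `format_memory_sketch(1536)` yields `"1.50 KB"`.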