Utilities¶
numerax.utils ¶
Utility functions for the numerax package.
This module provides development utilities for creating JAX-compatible functions and tools for working with PyTree structures, including parameter counting for machine learning models.
count_params ¶
Count the total number of parameters in a PyTree structure.
Overview¶
This function counts parameters in PyTree-based models by filtering for array-like objects and summing their sizes. It is particularly useful for neural network models built with JAX frameworks like Equinox.
The function traverses the PyTree structure, applies a filter to identify parameter arrays, and computes the total parameter count.
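The traversal described above can be approximated in a few lines with `jax.tree_util`. This is a minimal sketch under stated assumptions, not numerax's actual implementation: `count_params_sketch` and `is_array_like` are hypothetical names, and the attribute-based check stands in for `equinox.is_array` so that the example needs only JAX.

```python
import jax
import jax.numpy as jnp


def is_array_like(x):
    """Hypothetical stand-in for equinox.is_array: accept objects with a shape and size."""
    return hasattr(x, "shape") and hasattr(x, "size")


def count_params_sketch(pytree, filter=None):
    """Sum the sizes of all leaves accepted by `filter` (illustrative only)."""
    if filter is None:
        filter = is_array_like
    # Flatten the PyTree into its leaves, then keep only those the filter accepts
    leaves = jax.tree_util.tree_leaves(pytree)
    return sum(int(leaf.size) for leaf in leaves if filter(leaf))


model = {"weights": jnp.ones((10, 5)), "bias": jnp.zeros(5)}
total = count_params_sketch(model)
matrices_only = count_params_sketch(
    model, filter=lambda x: hasattr(x, "ndim") and x.ndim > 1
)
```

Here `total` covers both arrays (10 * 5 + 5 elements), while `matrices_only` skips the one-dimensional bias.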
Args¶
- pytree: The PyTree structure to count parameters in (e.g., a model, dict of arrays, or nested structure)
- filter: Optional filter function to identify parameters. If None, uses equinox.is_array as the default filter. Custom filters should accept a single argument and return True for objects that should be counted
- verbose: If True, prints the parameter count in scientific notation. If False, returns the count without printing
Returns¶
The total number of parameters as an integer
Requirements¶
- equinox: Install with pip install numerax[sciml] or pip install equinox
Example¶
import jax.numpy as jnp
from numerax.utils import count_params
# Simple dict-based model
model = {"weights": jnp.ones((10, 5)), "bias": jnp.zeros(5)}
count = count_params(model)
# Prints: Number of parameters: 5.5e+01
# Returns: 55
# With custom filter
count = count_params(
    model,
    filter=lambda x: hasattr(x, "ndim") and x.ndim > 1,
    verbose=False,
)
# Returns: 50 (only the weights matrix)
# With Equinox model
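import equinox as eqx
import jax
class MLP(eqx.Module):
    layers: list
    def __init__(self, key):
        # Split the key so each layer gets independent random initialization
        key1, key2 = jax.random.split(key)
        self.layers = [
            eqx.nn.Linear(10, 64, key=key1),
            eqx.nn.Linear(64, 1, key=key2),
        ]
model = MLP(jax.random.PRNGKey(0))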
count = count_params(model)
# Counts all trainable parameters in the MLP
Notes¶
- The default filter (equinox.is_array) correctly identifies parameter arrays in Equinox modules and standard JAX PyTrees
- For custom filtering logic, provide a function that returns True for leaves that should be counted as parameters
- The function handles nested PyTree structures automatically
preserve_metadata ¶
Wrapper that ensures a decorator preserves function metadata for documentation tools.
Overview¶
When wrapping functions with decorators, metadata like __name__,
__doc__, and __module__ can be lost if the decorator doesn't
explicitly preserve them. This wrapper uses functools.wraps to
ensure metadata is maintained, which is important for documentation
generators like pdoc and for debugging.
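One way such a wrapper could be built is sketched below. This is an assumption about the general technique, not the numerax source: `preserve_metadata_sketch` and `plain_decorator` are hypothetical names, and the only library call used is `functools.wraps` from the standard library.

```python
import functools


def preserve_metadata_sketch(decorator):
    """Return a decorator that applies `decorator`, then copies metadata back."""

    def metadata_preserving_decorator(func):
        wrapped = decorator(func)
        # functools.wraps(func) copies __name__, __doc__, __module__, and
        # related attributes from func onto the wrapped callable
        return functools.wraps(func)(wrapped)

    return metadata_preserving_decorator


def plain_decorator(func):
    """A decorator that does not preserve metadata on its own."""

    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@preserve_metadata_sketch(plain_decorator)
def identity(x):
    """Return x unchanged."""
    return x
```

With this sketch, `identity.__name__` and `identity.__doc__` refer to the original function rather than to the inner `wrapper`.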
Args¶
- decorator: The decorator function to wrap
Returns¶
A new decorator that preserves metadata
Example¶
from numerax.utils import preserve_metadata
# A simple decorator that doesn't preserve metadata
def my_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper
# Without preserve_metadata, __doc__ and __name__ are lost
@my_decorator
def func1(x):
    """This docstring will be lost."""
    return x
# With preserve_metadata, they are preserved
@preserve_metadata(my_decorator)
def func2(x):
    """This docstring will be preserved."""
    return x
tree_summary ¶
Pretty-print PyTree structure with shapes and parameter counts.
Overview¶
This function displays a hierarchical view of a PyTree structure
(e.g., neural network models) showing the organization, array shapes,
data types, and parameter counts at each level. The output is
similar to Keras' model.summary() or torchinfo's summaries. This is
compatible with PyTree-based models from frameworks like Equinox.
Args¶
- pytree: The PyTree structure to summarize (e.g., a model, dict of arrays, or nested structure)
- is_leaf: Optional function to identify leaf nodes. If None, uses equinox.is_array as the default. Leaf nodes are displayed with shape, dtype, and parameter count details. Custom functions should accept a single argument and return True for leaves
- max_depth: Maximum nesting depth to display. Nodes deeper than this level will not be shown. Defaults to 3
- verbose: If True, prints the formatted summary. If False, only returns the total parameter count silently
- hide_empty: If True, skips nodes with zero parameters. Defaults to True to avoid clutter from primitive attributes (integers, strings, functions) in neural network modules that don't contribute to parameter counts
Returns¶
The total number of parameters as an integer
Requirements¶
- equinox: Install with pip install numerax[sciml] or pip install equinox (required when using the default is_leaf)
Example¶
import jax.numpy as jnp
from numerax.utils import tree_summary
# Nested dict-based model
model = {
    "encoder": {
        "weights": jnp.ones((10, 20)),
        "bias": jnp.zeros(20),
    },
    "decoder": {
        "weights": jnp.ones((20, 5)),
        "bias": jnp.zeros(5),
    },
}
count = tree_summary(model)
# Prints formatted table showing structure
# Returns: 325
# With custom is_leaf function
tree_summary(model, is_leaf=lambda x: hasattr(x, "shape"))
# Limit depth
tree_summary(model, max_depth=2)
# Silent mode
count = tree_summary(model, verbose=False)
# Returns: 325 without printing
Output Format¶
======================================================================
PyTree Summary
======================================================================
Name                          Shape           Dtype             Params
----------------------------------------------------------------------
root                                                               325
  encoder                                                          220
    - weights                 [10,20]         float32              200
    - bias                    [20]            float32               20
  decoder                                                          105
    - weights                 [20,5]          float32              100
    - bias                    [5]             float32                5
======================================================================
Total params: 325
======================================================================
Notes¶
- Container nodes (dicts, lists, modules) show total parameter counts for their entire subtree
- Leaf nodes (arrays) show shape, dtype, and individual param count
- Indentation shows nesting depth in the PyTree structure
- Works with Equinox modules, nested dicts, lists, tuples, and custom PyTree nodes
- Use custom is_leaf functions to control what counts as a leaf node (useful for custom PyTree registrations)
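The per-node totals described in these notes can be reproduced by hand for a dict-based model. The sketch below is illustrative only and does not mirror the numerax implementation: `subtree_param_count` and `print_tree_sketch` are hypothetical helpers, they recurse into dicts only, and the printed layout is a simplification of the real table.

```python
import jax
import jax.numpy as jnp


def subtree_param_count(node):
    """Total number of array elements under `node` (illustrative only)."""
    leaves = jax.tree_util.tree_leaves(node)
    return sum(int(leaf.size) for leaf in leaves if hasattr(leaf, "size"))


def print_tree_sketch(name, node, depth=0):
    """Print one indented line per node; assumes nested dicts of arrays."""
    indent = "  " * depth
    if isinstance(node, dict):
        # Container node: show the total for the whole subtree
        print(f"{indent}{name}  {subtree_param_count(node)}")
        for key, child in node.items():
            print_tree_sketch(key, child, depth + 1)
    else:
        # Leaf node: show shape, dtype, and element count
        print(f"{indent}- {name}  {list(node.shape)}  {node.dtype}  {node.size}")


model = {
    "encoder": {"weights": jnp.ones((10, 20)), "bias": jnp.zeros(20)},
    "decoder": {"weights": jnp.ones((20, 5)), "bias": jnp.zeros(5)},
}
encoder_total = subtree_param_count(model["encoder"])
decoder_total = subtree_param_count(model["decoder"])
root_total = subtree_param_count(model)
print_tree_sketch("root", model)
```

The container totals are simply sums over the leaves below each node, which is why the root count equals the encoder and decoder counts added together.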