
Utilities

numerax.utils

Utility functions for the numerax package.

This module provides development utilities for writing JAX-compatible functions (such as keeping function metadata intact when applying decorators) and tools for inspecting PyTree structures, including parameter counting and structure summaries for machine learning models.

count_params

count_params(pytree, filter=None, verbose=True)

Count the total number of parameters in a PyTree structure.

Overview

This function counts parameters in PyTree-based models by filtering for array-like objects and summing their sizes. It is particularly useful for neural network models built with JAX frameworks like Equinox.

The function traverses the PyTree structure, applies a filter to identify parameter arrays, and computes the total parameter count.
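As a rough illustration of that traverse, filter, and sum pipeline, here is a minimal sketch built only on jax.tree_util. It is not the numerax source. count_params_sketch and is_array_like are hypothetical names, and is_array_like only approximates equinox.is_array by accepting JAX and NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array_like(x):
    # Hypothetical stand-in for equinox.is_array: treat JAX and NumPy arrays
    # as parameters and every other leaf (ints, strings, ...) as non-parameters.
    return isinstance(x, (jax.Array, np.ndarray))


def count_params_sketch(pytree, filter=None):
    """Hypothetical sketch of the counting logic, not numerax's implementation."""
    filter = filter or is_array_like
    # Flatten the PyTree; dicts, lists, tuples and registered nodes are
    # traversed automatically, so nesting needs no special handling.
    leaves = jax.tree_util.tree_leaves(pytree)
    # Keep the leaves the filter accepts and sum their element counts.
    return sum(int(leaf.size) for leaf in leaves if filter(leaf))


model = {"weights": jnp.ones((10, 5)), "bias": jnp.zeros(5), "name": "mlp"}
print(count_params_sketch(model))  # 55 (the string leaf is filtered out)
```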

Args
  • pytree: The PyTree structure to count parameters in (e.g., a model, dict of arrays, or nested structure)
  • filter: Optional filter function to identify parameters. If None, uses equinox.is_array as the default filter. Custom filters should accept a single argument and return True for objects that should be counted
  • verbose: If True, prints the parameter count in scientific notation. If False, returns the count without printing
Returns

The total number of parameters as an integer

Requirements
  • equinox: Install with pip install numerax[sciml] or pip install equinox
Example
import jax.numpy as jnp
from numerax.utils import count_params

# Simple dict-based model
model = {"weights": jnp.ones((10, 5)), "bias": jnp.zeros(5)}
count = count_params(model)
# Prints: Number of parameters: 5.5e+01
# Returns: 55

# With custom filter
count = count_params(
    model,
    filter=lambda x: hasattr(x, "ndim") and x.ndim > 1,
    verbose=False,
)
# Returns: 50 (only the weights matrix)

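# With Equinox model
import equinox as eqx
import jax


class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        # Split the key so the two layers get independent initializations
        key1, key2 = jax.random.split(key)
        self.layers = [
            eqx.nn.Linear(10, 64, key=key1),
            eqx.nn.Linear(64, 1, key=key2),
        ]


model = MLP(jax.random.PRNGKey(0))
count = count_params(model)
# Counts every weight and bias array in the MLP
# Returns: 769 (10*64 + 64 + 64*1 + 1)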
Notes
  • The default filter (equinox.is_array) correctly identifies parameter arrays in Equinox modules and standard JAX PyTrees
  • For custom filtering logic, provide a function that returns True for leaves that should be counted as parameters
  • The function handles nested PyTree structures automatically
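As an example of custom filtering logic, the predicate below counts only floating-point arrays, which skips integer buffers such as a step counter. The snippet sums leaf sizes with jax.tree_util directly so it runs without numerax; with numerax installed, the same predicate would be passed as count_params(model, filter=is_float_array). The is_float_array name and the "step" entry are hypothetical.

```python
import jax
import jax.numpy as jnp


def is_float_array(x):
    # Hypothetical filter: accept only floating-point JAX arrays, so integer
    # buffers (counters, index tables) are not counted as parameters.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating)


model = {
    "weights": jnp.ones((10, 5)),
    "bias": jnp.zeros(5),
    "step": jnp.array(0, dtype=jnp.int32),  # hypothetical non-trainable counter
}

# Same filter-then-sum step that count_params performs over the PyTree leaves.
n = sum(leaf.size for leaf in jax.tree_util.tree_leaves(model) if is_float_array(leaf))
print(n)  # 55 (the int32 "step" array is excluded)
```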

preserve_metadata

preserve_metadata(decorator)

Wrapper that ensures a decorator preserves function metadata for documentation tools.

Overview

When wrapping functions with decorators, metadata like __name__, __doc__, and __module__ can be lost if the decorator doesn't explicitly preserve them. This wrapper uses functools.wraps to ensure metadata is maintained, which is important for documentation generators like pdoc and for debugging.
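A wrapper of this kind can be built from functools.wraps: apply the original decorator, then copy the decorated function's metadata onto whatever it returned. The sketch below is an assumption about the approach, not the numerax source; preserve_metadata_sketch is a hypothetical name.

```python
import functools


def preserve_metadata_sketch(decorator):
    """Hypothetical sketch: make `decorator` keep the metadata of what it wraps."""

    @functools.wraps(decorator)
    def metadata_preserving_decorator(func):
        wrapped = decorator(func)
        # Copy __name__, __doc__, __module__, __qualname__ (and __dict__) from
        # func onto the result, and record func as wrapped.__wrapped__.
        return functools.wraps(func)(wrapped)

    return metadata_preserving_decorator


def my_decorator(func):
    # A decorator that does not preserve metadata on its own
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@preserve_metadata_sketch(my_decorator)
def func2(x):
    """This docstring will be preserved."""
    return x


print(func2.__name__, func2.__doc__)  # func2 This docstring will be preserved.
```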

Args
  • decorator: The decorator function to wrap
Returns

A new decorator that preserves metadata

Example
from numerax.utils import preserve_metadata

# A simple decorator that doesn't preserve metadata
def my_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

# Without preserve_metadata, __name__ and __doc__ are lost
@my_decorator
def func1(x):
    """This docstring will be lost."""
    return x

func1.__name__  # 'wrapper'
func1.__doc__   # None

# With preserve_metadata, they are preserved
@preserve_metadata(my_decorator)
def func2(x):
    """This docstring will be preserved."""
    return x

func2.__name__  # 'func2'
func2.__doc__   # 'This docstring will be preserved.'

tree_summary

tree_summary(pytree, is_leaf=None, max_depth=3, verbose=True, hide_empty=True)

Pretty-print PyTree structure with shapes and parameter counts.

Overview

This function displays a hierarchical view of a PyTree structure (e.g., a neural network model), showing its organization, array shapes, data types, and parameter counts at each level. The output is similar to Keras' model.summary() or torchinfo's summaries. It works with PyTree-based models from frameworks such as Equinox.

Args
  • pytree: The PyTree structure to summarize (e.g., a model, dict of arrays, or nested structure)
  • is_leaf: Optional function to identify leaf nodes. If None, uses equinox.is_array as the default. Leaf nodes are displayed with shape, dtype, and parameter count details. Custom functions should accept a single argument and return True for leaves
  • max_depth: Maximum nesting depth to display. Nodes deeper than this level will not be shown. Defaults to 3
  • verbose: If True, prints the formatted summary. If False, returns the total parameter count without printing
  • hide_empty: If True, skips nodes with zero parameters. Defaults to True to avoid clutter from primitive attributes (integers, strings, functions) in neural network modules that don't contribute to parameter counts
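To make the interaction of max_depth and hide_empty concrete, here is a minimal recursive sketch that collects (depth, name, params) rows. It is hypothetical (summarize is not part of numerax), handles only dicts, lists, tuples, and jax.Array leaves rather than arbitrary PyTree nodes, and assumes the root counts as depth 0.

```python
import jax
import jax.numpy as jnp


def summarize(node, name="root", depth=0, max_depth=3, hide_empty=True, rows=None):
    """Hypothetical sketch: collect (depth, name, params) rows for a nested tree."""
    if rows is None:
        rows = []
    if depth > max_depth:
        return rows  # deeper than max_depth: not displayed
    if isinstance(node, jax.Array):
        params = int(node.size)
    else:
        # Containers report the total of their entire subtree.
        params = sum(
            int(leaf.size)
            for leaf in jax.tree_util.tree_leaves(node)
            if isinstance(leaf, jax.Array)
        )
    if hide_empty and params == 0:
        return rows  # e.g. strings or ints stored on a module
    rows.append((depth, name, params))
    if isinstance(node, dict):
        children = node.items()
    elif isinstance(node, (list, tuple)):
        children = ((str(i), child) for i, child in enumerate(node))
    else:
        children = ()
    for child_name, child in children:
        summarize(child, child_name, depth + 1, max_depth, hide_empty, rows)
    return rows


model = {
    "encoder": {"weights": jnp.ones((10, 20)), "bias": jnp.zeros(20)},
    "decoder": {"weights": jnp.ones((20, 5)), "bias": jnp.zeros(5)},
    "activation": "relu",  # primitive attribute with zero parameters
}

for depth, name, params in summarize(model, max_depth=1):
    print("  " * depth + name, params)
# root 325
#   encoder 220
#   decoder 105
```

Passing hide_empty=False would add an "activation" row with 0 parameters; raising max_depth to 2 would add the weight and bias rows.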
Returns

The total number of parameters as an integer

Requirements
  • equinox: Install with pip install numerax[sciml] or pip install equinox (required when using default is_leaf)
Example
import jax.numpy as jnp
from numerax.utils import tree_summary

# Nested dict-based model
model = {
    "encoder": {
        "weights": jnp.ones((10, 20)),
        "bias": jnp.zeros(20),
    },
    "decoder": {
        "weights": jnp.ones((20, 5)),
        "bias": jnp.zeros(5),
    },
}

count = tree_summary(model)
# Prints formatted table showing structure
# Returns: 325

# With custom is_leaf function
tree_summary(model, is_leaf=lambda x: hasattr(x, "shape"))

# Limit depth
tree_summary(model, max_depth=2)

# Silent mode
count = tree_summary(model, verbose=False)
# Returns: 325 without printing
Output Format
======================================================================
PyTree Summary
======================================================================
Name                  Shape           Dtype             Params
----------------------------------------------------------------------
root                                                       325
  encoder                                                  220
    - weights         [10,20]         float32              200
    - bias            [20]            float32               20
  decoder                                                  105
    - weights         [20,5]          float32              100
    - bias            [5]             float32                5
======================================================================
Total params: 325
======================================================================
Notes
  • Container nodes (dicts, lists, modules) show total parameter counts for their entire subtree
  • Leaf nodes (arrays) show shape, dtype, and individual param count
  • Indentation shows nesting depth in the PyTree structure
  • Works with Equinox modules, nested dicts, lists, tuples, and custom PyTree nodes
  • Use custom is_leaf functions to control what counts as a leaf node (useful for custom PyTree registrations)
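The leaf rows in the table above (shape, dtype, param count, and indentation by depth) can be reproduced with JAX's path-aware flattening. This is a sketch assuming jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr are available (recent JAX releases); it does not claim to be how tree_summary is implemented. JAX flattens dict entries in sorted key order, so "decoder" is listed before "encoder" here.

```python
import jax
import jax.numpy as jnp

model = {
    "encoder": {"weights": jnp.ones((10, 20)), "bias": jnp.zeros(20)},
    "decoder": {"weights": jnp.ones((20, 5)), "bias": jnp.zeros(5)},
}

# Each (path, leaf) pair gives a leaf row; len(path) is its nesting depth.
path_leaf_pairs, _ = jax.tree_util.tree_flatten_with_path(model)
leaf_rows = [
    (jax.tree_util.keystr(path), len(path), tuple(leaf.shape), str(leaf.dtype), int(leaf.size))
    for path, leaf in path_leaf_pairs
]

for key, depth, shape, dtype, params in leaf_rows:
    # Indent by depth, mirroring the layout of the summary table
    print("  " * depth + key, shape, dtype, params)
```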