Statistics

numerax.stats

Statistics submodule for numerax.

make_profile_llh

make_profile_llh(llh_fn: Callable, is_nuisance: list[bool] | ndarray, get_initial_nuisance: Callable, tol: float = 1e-06, initial_value: float = 1e-09, initial_diff: float = 1000000000.0, optimizer: GradientTransformation = _DEFAULT_OPTIMIZER) -> Callable

Factory function for creating profile likelihood functions.

Overview

Profile likelihood is a statistical technique for models with nuisance parameters: parameters that are not of primary interest but are required by the model. This function creates a profile likelihood function that maximizes over the nuisance parameters while holding the inference parameters fixed.

Mathematical Background

Given a likelihood function \(L(\boldsymbol{\theta}, \boldsymbol{\lambda})\) where \(\boldsymbol{\theta}\) are parameters of interest and \(\boldsymbol{\lambda}\) are nuisance parameters, the profile likelihood is:

\[L_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}} L(\boldsymbol{\theta}, \boldsymbol{\lambda})\]

In practice, we work with the log-likelihood \(\ell(\boldsymbol{\theta}, \boldsymbol{\lambda}) = \log L(\boldsymbol{\theta}, \boldsymbol{\lambda})\):

\[\ell_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}} \ell(\boldsymbol{\theta}, \boldsymbol{\lambda})\]

By default, this function uses L-BFGS optimization to find the maximum likelihood estimates of the nuisance parameters for each fixed value of the inference parameters.
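For some models the inner maximization can be done by hand, which gives a useful reference value. For a normal model with mean \(\mu\) of interest and \(\sigma\) as the nuisance parameter (the same model used in the example further down), setting \(\partial \ell / \partial \sigma = 0\) gives \(\hat{\sigma}^2(\mu) = \frac{1}{n} \sum_i (x_i - \mu)^2\). The sketch below is a library-free illustration of the definition, not numerax's implementation: it compares that closed form with a brute-force grid search over \(\log \sigma\). The function names, grid bounds, and step count are assumptions chosen for illustration.

```python
import math

# Illustrative sample (assumption: any real-valued data would do)
data = [1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4]


def normal_llh(mu, log_sigma, xs):
    """Log-likelihood of i.i.d. normal data, parameterized by log(sigma)."""
    sigma = math.exp(log_sigma)
    return sum(
        -0.5 * math.log(2 * math.pi) - log_sigma - 0.5 * ((x - mu) / sigma) ** 2
        for x in xs
    )


def profile_llh_closed_form(mu, xs):
    """Profile log-likelihood using the analytic nuisance optimum."""
    sigma2_hat = sum((x - mu) ** 2 for x in xs) / len(xs)
    log_sigma_hat = 0.5 * math.log(sigma2_hat)
    return normal_llh(mu, log_sigma_hat, xs), log_sigma_hat


def profile_llh_grid(mu, xs, lo=-3.0, hi=1.0, steps=4000):
    """Brute-force maximization over log(sigma) on a uniform grid."""
    grid = [lo + (hi - lo) * i / steps for i in range(steps + 1)]
    return max(normal_llh(mu, ls, xs) for ls in grid)


mu_fixed = 1.0
exact, log_sigma_hat = profile_llh_closed_form(mu_fixed, data)
approx = profile_llh_grid(mu_fixed, data)
print(exact, log_sigma_hat, approx)
```

The grid value can never exceed the closed-form value and approaches it as the grid is refined. An iterative optimizer such as L-BFGS plays the role of the grid search when no closed form is available.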

Args
  • llh_fn: Log likelihood function taking (params, *args) and returning scalar log likelihood value
  • is_nuisance: Boolean array where True indicates nuisance parameters and False indicates inference parameters
  • get_initial_nuisance: Function taking (*args) and returning initial values for nuisance parameters
  • tol: Convergence tolerance for the optimization (default: 1e-6)
  • initial_value: Initial objective value for convergence tracking (default: 1e-9)
  • initial_diff: Initial difference for convergence tracking (default: 1e9)
  • optimizer: Optax optimizer to use for maximization (default: lbfgs()). Currently tested only with the default L-BFGS optimizer
Returns

Profile likelihood function with signature: (inference_values, *args) -> (profile_llh_value, optimal_nuisance, convergence_diff, num_iterations)

Example

Consider fitting a normal distribution where we want to infer the mean \(\mu\) but treat the variance \(\sigma^2\) as a nuisance parameter:

import jax.numpy as jnp
import numerax

# Sample data
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])


# Log likelihood for normal distribution
def normal_llh(params, data):
    mu, log_sigma = params  # Use log(sigma) for numerical stability
    sigma = jnp.exp(log_sigma)
    return jnp.sum(
        -0.5 * jnp.log(2 * jnp.pi)
        - log_sigma
        - 0.5 * ((data - mu) / sigma) ** 2
    )


# Profile over log_sigma (nuisance), infer mu
is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance


def get_initial_log_sigma(data):
    # Initialize with log of sample standard deviation
    return jnp.array([jnp.log(jnp.std(data))])


profile_llh = numerax.stats.make_profile_llh(
    normal_llh, is_nuisance, get_initial_log_sigma
)

# Evaluate the profile likelihood at a single mu value
mu_test = 1.0
llh_val, opt_log_sigma, diff, n_iter = profile_llh(
    jnp.array([mu_test]), data
)
Notes
  • The function is JIT-compiled for performance
  • Uses L-BFGS optimization by default, which is well suited to smooth likelihood surfaces
  • Returns convergence information for diagnostics
  • Handles parameter masking automatically
  • Consider using log-parameterization for positive parameters (e.g., \(\log \sigma\)) for unconstrained optimization
  • May converge to a local rather than the global maximum if the likelihood surface has multiple local maxima; in such cases, choose initial nuisance values close to the global maximum.
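To make the role of is_nuisance concrete, the sketch below shows one way a boolean mask can interleave inference and nuisance values into a single parameter vector. The helper merge_params is hypothetical and is not part of numerax. It assumes that values fill the mask positions in order, which is consistent with the example above, where is_nuisance = [False, True] maps to params = (mu, log_sigma).

```python
def merge_params(is_nuisance, inference_values, nuisance_values):
    """Interleave inference and nuisance values according to a boolean mask.

    Hypothetical helper for illustration; assumes values fill the True and
    False positions of the mask in the order they are given.
    """
    inference_iter = iter(inference_values)
    nuisance_iter = iter(nuisance_values)
    return [
        next(nuisance_iter) if flag else next(inference_iter)
        for flag in is_nuisance
    ]


# Two-parameter case from the example: mu is inferred, log_sigma is profiled
params_example = merge_params([False, True], [1.0], [-1.2])

# Four-parameter case with two nuisance parameters in the middle
params_mixed = merge_params([False, True, True, False], [1.0, 4.0], [-1.2, 0.3])
print(params_example, params_mixed)
```

Under this convention, the likelihood function always receives the full parameter vector, while the optimizer only ever sees the nuisance entries.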

chi2

Chi-squared distribution functions.

This module provides a complete interface for chi-squared distribution computations. It re-exports JAX's standard statistical functions (pdf, cdf, etc.) and adds a custom high-precision percent point function (ppf).

All functions support location-scale parameterization and are fully compatible with JAX transformations (JIT, grad, vmap).

ppf

ppf(q: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ArrayLike

Chi-squared percent point function (inverse CDF).

Overview

Computes the percent point function (quantile function) of the chi-squared distribution. This is the inverse of the cumulative distribution function, finding \(x\) such that \(P(X \leq x) = q\) for a chi-squared random variable \(X\) with \(\text{df}\) degrees of freedom.

Mathematical Background

The chi-squared distribution with \(\text{df}\) degrees of freedom is a special case of the gamma distribution:

\[X \sim \chi^2(\text{df}) \equiv \text{Gamma}\left( \frac{\text{df}}{2}, 2\right)\]

For the location-scale family:

\[Y = \text{loc} + \text{scale} \cdot X\]

The percent point function is computed as:

\[ \text{ppf}(q, \text{df}, \text{loc}, \text{scale}) = \text{loc} + \text{scale} \cdot 2 \cdot \text{gammap\_inverse}\left(q, \frac{\text{df}}{2}\right) \]
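For \(\text{df} = 2\) the formula can be checked by hand. Assuming gammap_inverse inverts the regularized lower incomplete gamma function \(P(a, x)\) in \(x\), we have \(P(1, x) = 1 - e^{-x}\), so the quantile reduces to \(\text{loc} + \text{scale} \cdot (-2 \ln(1 - q))\). The sketch below implements only this special case with the standard library, as an independent reference value. The function names are illustrative and not part of numerax.

```python
import math


def chi2_cdf_df2(x, loc=0.0, scale=1.0):
    """CDF of loc + scale * X for X ~ chi-squared with 2 degrees of freedom."""
    z = (x - loc) / scale
    return 1.0 - math.exp(-z / 2.0) if z > 0 else 0.0


def chi2_ppf_df2(q, loc=0.0, scale=1.0):
    """Closed-form quantile for df=2, valid for q in [0, 1)."""
    return loc + scale * (-2.0 * math.log1p(-q))


median = chi2_ppf_df2(0.5)                       # equals 2 * ln(2)
shifted = chi2_ppf_df2(0.5, loc=1.0, scale=2.0)  # equals 1 + 4 * ln(2)
round_trip = chi2_cdf_df2(shifted, loc=1.0, scale=2.0)
print(median, shifted, round_trip)
```

Applying the CDF to the computed quantile recovers the input probability, which is the defining property of the percent point function.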
Args
  • q: Probability values in \([0, 1]\). Can be scalar or array.
  • df: Degrees of freedom (must be positive). Can be scalar or array.
  • loc: Location parameter (default: 0). Can be scalar or array.
  • scale: Scale parameter (must be positive, default: 1). Can be scalar or array.
Returns

Quantiles \(x\) where \(P(X \leq x) = q\). Shape follows JAX broadcasting rules.

Example
import jax
import jax.numpy as jnp
import numerax

# Single quantile
x = numerax.stats.chi2.ppf(0.5, df=2)  # Median of χ²(2)

# Multiple quantiles
q_vals = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
x_vals = numerax.stats.chi2.ppf(q_vals, df=3)

# Location-scale family
x_scaled = numerax.stats.chi2.ppf(0.5, df=2, loc=1, scale=2)

# Differentiable for optimization
grad_fn = jax.grad(numerax.stats.chi2.ppf)
sensitivity = grad_fn(0.5, 2.0)  # ∂x/∂q at median
Notes
  • Differentiable: Supports automatic differentiation through gammap_inverse
  • Broadcasting: Supports JAX array broadcasting for all parameters
  • Performance: Compatible with JIT compilation
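The derivative with respect to q has a simple reference form. By the inverse function theorem, \(\partial x / \partial q = \text{scale} / f(z)\), where \(f\) is the density of the standard chi-squared distribution and \(z\) is the standardized quantile. For \(\text{df} = 2\) and \(q = 0.5\) the exact value is 4. The sketch below checks this against a central finite difference using only the standard library. It illustrates the expected gradient and does not describe how numerax computes it; all function names are illustrative.

```python
import math


def ppf_df2(q, loc=0.0, scale=1.0):
    """Closed-form chi-squared(2) quantile, valid for q in [0, 1)."""
    return loc + scale * (-2.0 * math.log1p(-q))


def pdf_df2(x):
    """Density of chi-squared(2): 0.5 * exp(-x / 2) for x >= 0."""
    return 0.5 * math.exp(-x / 2.0) if x >= 0 else 0.0


def dppf_dq_df2(q, scale=1.0):
    """Derivative of the quantile via the inverse function theorem."""
    return scale / pdf_df2(ppf_df2(q))


q, h = 0.5, 1e-6
analytic = dppf_dq_df2(q)
numeric = (ppf_df2(q + h) - ppf_df2(q - h)) / (2 * h)
print(analytic, numeric)
```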