Statistics

numerax.stats

Statistics submodule for numerax.

make_profile_llh

make_profile_llh(llh_fn: Callable, is_nuisance: list[bool] | ndarray, get_initial_nuisance: Callable, tol: float = 1e-06, initial_value: float = 1e-09, initial_diff: float = 1000000000.0, optimizer: GradientTransformation = _DEFAULT_OPTIMIZER) -> Callable

Factory function for creating profile likelihood functions.

Overview

Profile likelihood is a statistical technique for models with nuisance parameters: parameters that are not of primary interest but are necessary for the model. This function creates a profile likelihood function that maximizes over the nuisance parameters while holding the inference parameters fixed.

Mathematical Background

Given a likelihood function \(L(\boldsymbol{\theta}, \boldsymbol{\lambda})\) where \(\boldsymbol{\theta}\) are parameters of interest and \(\boldsymbol{\lambda}\) are nuisance parameters, the profile likelihood is:

\[L_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}} L(\boldsymbol{\theta}, \boldsymbol{\lambda})\]

In practice, we work with the log-likelihood \(\ell(\boldsymbol{\theta}, \boldsymbol{\lambda}) = \log L(\boldsymbol{\theta}, \boldsymbol{\lambda})\):

\[\ell_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}} \ell(\boldsymbol{\theta}, \boldsymbol{\lambda})\]

This function uses L-BFGS optimization to find the maximum likelihood estimates of nuisance parameters for each fixed value of inference parameters.
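For the normal model used in the example further down, the inner maximization has a closed form, which makes a convenient sanity check on the definition. The following is a minimal sketch in plain JAX, independent of numerax; the helper names `profile_llh_closed_form` and `profile_llh_grid` are illustrative only. It compares the closed-form profile log-likelihood with a brute-force maximum over a grid of \(\sigma\) values.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])


def normal_llh(mu, sigma, data):
    # Full log-likelihood l(mu, sigma) of i.i.d. normal data
    return jnp.sum(
        -0.5 * jnp.log(2 * jnp.pi)
        - jnp.log(sigma)
        - 0.5 * ((data - mu) / sigma) ** 2
    )


def profile_llh_closed_form(mu, data):
    # For fixed mu, the maximizing variance is the mean squared deviation
    n = data.shape[0]
    sigma2_hat = jnp.mean((data - mu) ** 2)
    return -0.5 * n * (jnp.log(2 * jnp.pi * sigma2_hat) + 1.0)


def profile_llh_grid(mu, data, sigmas):
    # Brute force: evaluate l(mu, sigma) on a grid and keep the maximum
    values = jax.vmap(lambda s: normal_llh(mu, s, data))(sigmas)
    return jnp.max(values)


mu = 1.0
sigmas = jnp.linspace(0.05, 2.0, 4000)
closed = profile_llh_closed_form(mu, data)
brute = profile_llh_grid(mu, data, sigmas)
```

The two values should agree up to the grid resolution; a numerical optimizer is needed only when no such closed form exists.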

Args
  • llh_fn: Log likelihood function taking (params, *args) and returning scalar log likelihood value
  • is_nuisance: Boolean array where True indicates nuisance parameters and False indicates inference parameters
  • get_initial_nuisance: Function taking (*args) and returning initial values for nuisance parameters
  • tol: Convergence tolerance for the optimization (default: 1e-6)
  • initial_value: Initial objective value for convergence tracking (default: 1e-9)
  • initial_diff: Initial difference for convergence tracking (default: 1e9)
  • optimizer: Optax optimizer to use for maximization (default: lbfgs()). Currently tested only with the default L-BFGS optimizer
Returns

Profile likelihood function with signature: (inference_values, *args) -> (profile_llh_value, optimal_nuisance, convergence_diff, num_iterations)

Example

Consider fitting a normal distribution where we want to infer the mean \(\mu\) but treat the variance \(\sigma^2\) as a nuisance parameter:

import jax.numpy as jnp
import numerax

# Sample data
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])


# Log likelihood for normal distribution
def normal_llh(params, data):
    mu, log_sigma = params  # Use log(sigma) for numerical stability
    sigma = jnp.exp(log_sigma)
    return jnp.sum(
        -0.5 * jnp.log(2 * jnp.pi)
        - log_sigma
        - 0.5 * ((data - mu) / sigma) ** 2
    )


# Profile over log_sigma (nuisance), infer mu
is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance


def get_initial_log_sigma(data):
    # Initialize with log of sample standard deviation
    return jnp.array([jnp.log(jnp.std(data))])


profile_llh = numerax.stats.make_profile_llh(
    normal_llh, is_nuisance, get_initial_log_sigma
)

# Evaluate the profile likelihood at a fixed mu value
mu_test = 1.0
llh_val, opt_log_sigma, diff, n_iter = profile_llh(
    jnp.array([mu_test]), data
)
Notes
  • The function is JIT-compiled for performance
  • Uses L-BFGS optimization which is well-suited for smooth likelihood surfaces
  • Returns convergence information for diagnostics
  • Handles parameter masking automatically
  • Consider using log-parameterization for positive parameters (e.g., \(\log \sigma\)) for unconstrained optimization
  • L-BFGS is a local optimizer: if the likelihood surface has multiple local maxima, the result depends on the starting point. In such cases, make sure get_initial_nuisance returns values close to the global maximum.
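To make the roles of the mask and the initial nuisance values concrete, here is a conceptual sketch of a single profile evaluation: merge the fixed inference values and the current nuisance values into one parameter vector, then maximize over the nuisance entries only. It is not numerax's implementation. Plain gradient ascent with a fixed step stands in for L-BFGS, and `merge_params`, `profile_once`, `step`, and `n_steps` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def normal_llh(params, data):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return jnp.sum(
        -0.5 * jnp.log(2 * jnp.pi)
        - log_sigma
        - 0.5 * ((data - mu) / sigma) ** 2
    )


def merge_params(inference, nuisance, is_nuisance):
    # Scatter both groups into one vector. The mask is a static NumPy
    # array, so the index arrays are known at trace time.
    mask = np.asarray(is_nuisance, dtype=bool)
    full = jnp.zeros(mask.size)
    full = full.at[np.flatnonzero(~mask)].set(inference)
    full = full.at[np.flatnonzero(mask)].set(nuisance)
    return full


def profile_once(llh_fn, is_nuisance, inference, nuisance0, data,
                 step=0.05, n_steps=200):
    # Objective as a function of the nuisance values only
    def objective(nuisance):
        return llh_fn(merge_params(inference, nuisance, is_nuisance), data)

    grad_fn = jax.jit(jax.grad(objective))
    nuisance = nuisance0
    for _ in range(n_steps):
        nuisance = nuisance + step * grad_fn(nuisance)  # ascent step
    return objective(nuisance), nuisance


data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])
nuisance0 = jnp.array([jnp.log(jnp.std(data))])  # starting point
value, log_sigma_hat = profile_once(
    normal_llh, [False, True], jnp.array([1.0]), nuisance0, data
)
```

The fixed step size works here because the objective is well conditioned near the optimum; a quasi-Newton method such as L-BFGS removes the need to tune it.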

chi2

Chi-squared distribution functions.

This module provides a complete interface for chi-squared distribution computations. It combines JAX's standard statistical functions (pdf, cdf, etc.), which are re-exported here, with a custom high-precision percent point function (ppf).

All functions support location-scale parameterization and are fully compatible with JAX transformations (JIT, grad, vmap).

ppf

ppf(q: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ArrayLike

Chi-squared percent point function (inverse CDF).

Overview

Computes the percent point function (quantile function) of the chi-squared distribution. This is the inverse of the cumulative distribution function, finding \(x\) such that \(P(X \leq x) = q\) for a chi-squared random variable \(X\) with \(\text{df}\) degrees of freedom.

Mathematical Background

The chi-squared distribution with \(\text{df}\) degrees of freedom is a special case of the gamma distribution, with shape \(\text{df}/2\) and scale 2:

\[X \sim \chi^2(\text{df}) \equiv \text{Gamma}\left( \frac{\text{df}}{2}, 2\right)\]

For the location-scale family:

\[Y = \text{loc} + \text{scale} \cdot X\]

The percent point function is computed as:

\[ \text{ppf}(q, \text{df}, \text{loc}, \text{scale}) = \text{loc} + \text{scale} \cdot 2 \cdot \text{gammap\_inverse}\left(q, \frac{\text{df}}{2}\right) \]
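For \(\text{df} = 2\) the chi-squared distribution is exponential with mean 2, so the quantile has the closed form \(-2\log(1 - q)\). The sketch below uses that special case to illustrate the location-scale formula and checks it with a round trip through the CDF, written with `jax.scipy.special.gammainc` via the gamma relationship. It does not call numerax; `chi2_cdf` and `ppf_df2` are illustrative helpers.

```python
import jax.numpy as jnp
from jax.scipy.special import gammainc


def chi2_cdf(x, df, loc=0.0, scale=1.0):
    # CDF of loc + scale * X, with X ~ chi2(df) = Gamma(shape=df/2, scale=2)
    z = (x - loc) / scale
    return gammainc(df / 2.0, z / 2.0)


def ppf_df2(q, loc=0.0, scale=1.0):
    # Closed-form quantile, valid for df = 2 only
    return loc + scale * (-2.0 * jnp.log1p(-q))


q = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
x = ppf_df2(q, loc=1.0, scale=2.0)
q_back = chi2_cdf(x, 2.0, loc=1.0, scale=2.0)  # should reproduce q
```

For general \(\text{df}\) no closed form exists, which is why an inverse of the regularized incomplete gamma function is needed.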
Args
  • q: Probability values in \([0, 1]\). Can be scalar or array.
  • df: Degrees of freedom (must be positive). Can be scalar or array.
  • loc: Location parameter (default: 0). Can be scalar or array.
  • scale: Scale parameter (must be positive, default: 1). Can be scalar or array.
Returns

Quantiles \(x\) where \(P(X \leq x) = q\). Shape follows JAX broadcasting rules.

Example
import jax
import jax.numpy as jnp
import numerax

# Single quantile
x = numerax.stats.chi2.ppf(0.5, df=2)  # Median of χ²(2)

# Multiple quantiles
q_vals = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
x_vals = numerax.stats.chi2.ppf(q_vals, df=3)

# Location-scale family
x_scaled = numerax.stats.chi2.ppf(0.5, df=2, loc=1, scale=2)

# Differentiable for optimization
grad_fn = jax.grad(numerax.stats.chi2.ppf)
sensitivity = grad_fn(0.5, 2.0)  # ∂x/∂q at median
Notes
  • Differentiable: Automatic differentiation through gammap_inverse
  • Broadcasting: Supports JAX array broadcasting for all parameters
  • Performance: Compatible with JIT compilation

ncx2

Non-central chi-squared distribution functions.

This module provides the probability density function (pdf) and its logarithm (logpdf) for the non-central chi-squared distribution, built on the numerically stable scaled modified Bessel function ive. The signatures mirror scipy.stats.ncx2 (including loc and scale), and all functions are compatible with JAX transformations (JIT, grad, vmap).

logpdf

logpdf(x: ArrayLike, df: ArrayLike, nc: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ArrayLike

Non-central chi-squared log probability density function.

Overview

Computes the natural logarithm of the probability density function of the non-central chi-squared distribution with \(\text{df}\) degrees of freedom and non-centrality parameter \(\lambda\) (nc), in the location-scale family.

Mathematical Background

The non-central chi-squared distribution is the distribution of the sum of squares of \(k\) independent unit-variance normal random variables with means \(\mu_1, \ldots, \mu_k\):

\[X = \sum_{i=1}^{k} (Z_i + \mu_i)^2, \qquad Z_i \sim N(0, 1),\]

with degrees of freedom \(k = \text{df}\) and non-centrality parameter \(\lambda = \sum_{i=1}^{k} \mu_i^2\). Note that nc is \(\lambda\) itself (matching scipy.stats.ncx2), not \(\lambda^2\).
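The construction can be checked by simulation. The sketch below draws the sum of squares directly with `jax.random` and compares the sample mean with the known mean \(k + \lambda\) of the non-central chi-squared distribution. The component means `mu`, the seed, and the sample size are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

mu = jnp.array([1.0, 1.0, 0.0])  # k = 3 component means
k = mu.shape[0]
nc = jnp.sum(mu**2)              # lambda: the sum of squared means

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (200_000, k))  # Z_i ~ N(0, 1)
x = jnp.sum((z + mu) ** 2, axis=1)        # samples of X

sample_mean = jnp.mean(x)
expected_mean = k + nc  # E[X] = k + lambda
```

Note that `nc` is computed as the sum of squared means, in line with the convention that nc is \(\lambda\) itself.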

The standard (\(\text{loc}=0\), \(\text{scale}=1\)) density is

\[f(x; k, \lambda) = \tfrac12 e^{-(x + \lambda)/2} \left(\tfrac{x}{\lambda}\right)^{\nu/2} I_{\nu}\!\left(\sqrt{\lambda x}\right), \qquad \nu = \tfrac{k}{2} - 1,\]

where \(I_{\nu}\) is the modified Bessel function of the first kind. Using \(I_{\nu}(z) = e^{z}\,\mathtt{ive}(\nu, z)\) and combining the exponentials via \(-(x+\lambda)/2 + \sqrt{\lambda x} = -\tfrac12(\sqrt{x} - \sqrt{\lambda})^2\) gives the numerically stable form actually evaluated:

\[f = \tfrac12\, e^{-\frac12(\sqrt{x} - \sqrt{\lambda})^2} \left(\tfrac{x}{\lambda}\right)^{\nu/2} \mathtt{ive}\!\left(\nu, \sqrt{\lambda x}\right).\]

The location-scale family follows the usual convention \(f(x; \text{loc}, \text{scale}) = \tfrac{1}{\text{scale}} f_{\text{std}}\!\left( \tfrac{x - \text{loc}}{\text{scale}}\right)\).
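The benefit of the scaled form is easiest to see in a special case. For \(\text{df} = 2\) the order is \(\nu = 0\), so the density needs only \(I_0\), which JAX exposes as `jax.scipy.special.i0` together with its scaled variant `i0e`. The sketch below is a \(\text{df} = 2\) illustration, not numerax's general-order implementation, and its function names are hypothetical. It evaluates the log-density both ways: the two agree for moderate arguments, while the unscaled form overflows for large ones.

```python
import jax.numpy as jnp
from jax.scipy.special import i0, i0e


def logpdf_df2_naive(x, nc):
    # log f = log(1/2) - (x + nc)/2 + log I_0(sqrt(nc * x)); (x/nc)^0 = 1
    z = jnp.sqrt(nc * x)
    return jnp.log(0.5) - 0.5 * (x + nc) + jnp.log(i0(z))


def logpdf_df2_stable(x, nc, loc=0.0, scale=1.0):
    # Same density with I_0(z) = exp(z) * i0e(z) folded into the exponent,
    # plus the location-scale change of variables.
    y = (x - loc) / scale
    z = jnp.sqrt(nc * y)
    core = -0.5 * (jnp.sqrt(y) - jnp.sqrt(nc)) ** 2 + jnp.log(i0e(z))
    return jnp.log(0.5) + core - jnp.log(scale)


small_naive = logpdf_df2_naive(4.0, 2.0)
small_stable = logpdf_df2_stable(4.0, 2.0)

big_naive = logpdf_df2_naive(1.0e4, 1.0e4)    # I_0(1e4) overflows
big_stable = logpdf_df2_stable(1.0e4, 1.0e4)  # stays finite
```

The subtraction of \(\log(\text{scale})\) in the stable version is the log-space counterpart of the \(1/\text{scale}\) factor in the location-scale convention.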

Args
  • x: Quantile values. Scalar or array.
  • df: Degrees of freedom (must be positive). Scalar or array.
  • nc: Non-centrality parameter \(\lambda \ge 0\). Scalar or array. nc == 0 reduces to the central chi-squared distribution.
  • loc: Location parameter (default: 0). Scalar or array.
  • scale: Scale parameter (must be positive, default: 1). Scalar or array.
Returns

Log probability density values. Shape follows JAX broadcasting rules.

Example
import jax
import jax.numpy as jnp
import numerax

# Single value
lp = numerax.stats.ncx2.logpdf(4.0, df=3.0, nc=2.0)

# Vectorized
x_vals = jnp.array([0.5, 1.0, 2.0, 5.0])
lps = numerax.stats.ncx2.logpdf(x_vals, df=3.0, nc=2.0)

# Differentiable; here with respect to nc (argnums=2)
grad_fn = jax.grad(numerax.stats.ncx2.logpdf, argnums=2)
sensitivity = grad_fn(4.0, 3.0, 2.0)  # d logpdf / d nc
Notes
  • Differentiable: w.r.t. x, nc, loc, and scale. Not differentiable w.r.t. df -- df feeds the order of the underlying ive, which has no order-derivative, so differentiating w.r.t. df raises TypeError rather than returning a silently-wrong value.
  • Gradient at \(\text{nc} = 0\): the gradient w.r.t. nc is correct for every \(\text{nc} > 0\), but at exactly \(\text{nc} = 0\) it returns 0 rather than the true one-sided score \(\tfrac12\!\left(\tfrac{(x-\text{loc})/\text{scale}}{\text{df}} - 1\right)\). \(\text{nc} = 0\) is the boundary of the parameter domain (where the derivative is one-sided and the closed form has a removable singularity), so this measure-zero point is not special-cased. Differentiate at a small \(\text{nc} > 0\) if the score at the null is needed.
  • Broadcasting: Supports JAX array broadcasting for all parameters.
  • Accuracy: inherits the ~1e-6 relative accuracy of ive.
  • Support boundary: for \(x \le \text{loc}\) the density is taken to be zero (logpdf returns -inf), matching scipy.stats.ncx2.pdf. At the single point \(x = \text{loc}\) the mathematical limit can diverge for \(\text{df} < 2\); that measure-zero point is reported as zero density here.
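The one-sided score at \(\text{nc} = 0\) mentioned in the notes follows from the standard Poisson-mixture representation of the non-central density. A short derivation for the standard case (\(\text{loc}=0\), \(\text{scale}=1\)), writing \(f_k\) for the central chi-squared density with \(k\) degrees of freedom:

\[f(x; k, \lambda) = \sum_{j=0}^{\infty} e^{-\lambda/2} \frac{(\lambda/2)^j}{j!}\, f_{k+2j}(x) \quad\Longrightarrow\quad \left.\frac{\partial f}{\partial \lambda}\right|_{\lambda=0} = \tfrac12\left(f_{k+2}(x) - f_k(x)\right)\]

Only the \(j = 0\) and \(j = 1\) terms have a nonzero derivative at \(\lambda = 0\). Dividing by \(f(x; k, 0) = f_k(x)\) and using \(f_{k+2}(x) / f_k(x) = x / k\):

\[\left.\frac{\partial \log f}{\partial \lambda}\right|_{\lambda=0} = \tfrac12\left(\frac{x}{k} - 1\right)\]

In the location-scale family, \(x\) is replaced by \((x - \text{loc})/\text{scale}\).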

pdf

pdf(x: ArrayLike, df: ArrayLike, nc: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ArrayLike

Non-central chi-squared probability density function.

Computes the probability density function of the non-central chi-squared distribution with \(\text{df}\) degrees of freedom and non-centrality parameter \(\lambda\) (nc), in the location-scale family. This is exp(logpdf(...)); see logpdf for the full mathematical background, conventions, and differentiability notes.

Args
  • x: Quantile values. Scalar or array.
  • df: Degrees of freedom (must be positive). Scalar or array.
  • nc: Non-centrality parameter \(\lambda \ge 0\). Scalar or array.
  • loc: Location parameter (default: 0). Scalar or array.
  • scale: Scale parameter (must be positive, default: 1). Scalar or array.
Returns

Probability density values. Shape follows JAX broadcasting rules.

Example
import jax.numpy as jnp
import numerax

# Single value
p = numerax.stats.ncx2.pdf(4.0, df=3.0, nc=2.0)

# Vectorized
x_vals = jnp.array([0.5, 1.0, 2.0, 5.0])
ps = numerax.stats.ncx2.pdf(x_vals, df=3.0, nc=2.0)
Notes
  • Differentiable: w.r.t. x, nc, loc, and scale; differentiating w.r.t. df raises TypeError. The gradient w.r.t. nc returns 0 at exactly nc = 0 (a domain boundary). See logpdf for details.
  • Accuracy: inherits the ~1e-6 relative accuracy of ive.