Statistics
numerax.stats
Statistics submodule for numerax.
make_profile_llh
make_profile_llh(llh_fn: Callable, is_nuisance: list[bool] | ndarray, get_initial_nuisance: Callable, tol: float = 1e-06, initial_value: float = 1e-09, initial_diff: float = 1000000000.0, optimizer: GradientTransformation = _DEFAULT_OPTIMIZER) -> Callable
Factory function for creating profile likelihood functions.
Overview
Profile likelihood is a statistical technique for handling nuisance parameters: parameters the model needs but that are not of primary interest. This function builds a profile likelihood function that, for fixed values of the inference parameters, maximizes the log likelihood over the nuisance parameters.
Mathematical Background
Given a likelihood function \(L(\boldsymbol{\theta}, \boldsymbol{\lambda})\) where \(\boldsymbol{\theta}\) are the parameters of interest and \(\boldsymbol{\lambda}\) are nuisance parameters, the profile likelihood is:
\[L_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}} L(\boldsymbol{\theta}, \boldsymbol{\lambda})\]
In practice, we work with the log-likelihood \(\ell(\boldsymbol{\theta}, \boldsymbol{\lambda}) = \log L(\boldsymbol{\theta}, \boldsymbol{\lambda})\):
\[\ell_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}} \ell(\boldsymbol{\theta}, \boldsymbol{\lambda})\]
By default, this function uses L-BFGS optimization to find the maximum likelihood estimates of the nuisance parameters for each fixed value of the inference parameters.
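The inner maximization over the nuisance parameters can be sketched in a few lines of JAX. This is a minimal sketch, not numerax's implementation: it swaps L-BFGS for plain Newton steps (reasonable when the log likelihood is concave in the nuisance parameters and there are only a few of them), and the helper profile_llh_newton with its (theta, lam, *args) calling convention is hypothetical.

```python
import jax
import jax.numpy as jnp


def profile_llh_newton(llh_fn, theta, lam0, *args, n_steps=20):
    """Hypothetical helper: maximize llh_fn(theta, lam, *args) over lam.

    theta is held fixed; lam0 is the starting nuisance vector. Plain Newton
    steps assume the log likelihood is concave in lam around the optimum.
    """

    def objective(lam):
        return llh_fn(theta, lam, *args)

    grad_fn = jax.grad(objective)
    hess_fn = jax.hessian(objective)

    def newton_step(_, lam):
        # lam <- lam - H^{-1} g, via a linear solve rather than an explicit inverse
        return lam - jnp.linalg.solve(hess_fn(lam), grad_fn(lam))

    lam_hat = jax.lax.fori_loop(0, n_steps, newton_step, lam0)
    return objective(lam_hat), lam_hat


def normal_llh(theta, lam, data):
    # theta = [mu] (inference), lam = [log_sigma] (nuisance)
    mu, log_sigma = theta[0], lam[0]
    z = (data - mu) * jnp.exp(-log_sigma)
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma - 0.5 * z**2)


data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])
lam0 = jnp.array([jnp.log(jnp.std(data))])
llh_p, lam_hat = profile_llh_newton(normal_llh, jnp.array([1.0]), lam0, data)
```

For this normal model the Newton iterate converges to \(\hat\sigma^2\) = mean((data - mu)**2), which makes the sketch easy to check by hand.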
Args
- llh_fn: Log likelihood function taking (params, *args) and returning scalar log likelihood value
- is_nuisance: Boolean array where True indicates nuisance parameters and False indicates inference parameters
- get_initial_nuisance: Function taking (*args) and returning initial values for nuisance parameters
- tol: Convergence tolerance for the optimization (default: 1e-6)
- initial_value: Initial objective value for convergence tracking (default: 1e-9)
- initial_diff: Initial difference for convergence tracking (default: 1e9)
- optimizer: Optax optimizer to use for maximization (default: lbfgs()). Currently tested only with the default L-BFGS optimizer
Returns
Profile likelihood function with signature:
(inference_values, *args) -> (profile_llh_value, optimal_nuisance, convergence_diff, num_iterations)
Example
Consider fitting a normal distribution where we want to infer the mean \(\mu\) but treat the standard deviation \(\sigma\), parameterized as \(\log \sigma\), as a nuisance parameter:
import jax.numpy as jnp
import numerax
# Sample data
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])
# Log likelihood for normal distribution
def normal_llh(params, data):
    mu, log_sigma = params  # Use log(sigma) for numerical stability
    sigma = jnp.exp(log_sigma)
    return jnp.sum(
        -0.5 * jnp.log(2 * jnp.pi)
        - log_sigma
        - 0.5 * ((data - mu) / sigma) ** 2
    )
# Profile over log_sigma (nuisance), infer mu
is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance
def get_initial_log_sigma(data):
    # Initialize with log of sample standard deviation
    return jnp.array([jnp.log(jnp.std(data))])
profile_llh = numerax.stats.make_profile_llh(
    normal_llh, is_nuisance, get_initial_log_sigma
)
# Evaluate the profile likelihood at a fixed mu (call again for other values)
mu_test = 1.0
llh_val, opt_log_sigma, diff, n_iter = profile_llh(
    jnp.array([mu_test]), data
)
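For the normal model in this example, the profile over \(\log \sigma\) has a closed form: for fixed \(\mu\), the maximizing \(\sigma^2\) is the mean squared deviation of the data from \(\mu\). The sketch below (the helper profile_closed_form is hypothetical, not part of numerax) evaluates it on a grid of \(\mu\) values with jax.vmap, giving a reference to compare against llh_val and opt_log_sigma.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])


def normal_llh(params, data):
    # Same log likelihood as the example above (params = [mu, log_sigma])
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma - 0.5 * ((data - mu) / sigma) ** 2)


def profile_closed_form(mu, data):
    # For fixed mu the MLE of sigma^2 is mean((data - mu)^2)
    log_sigma_hat = 0.5 * jnp.log(jnp.mean((data - mu) ** 2))
    n = data.shape[0]
    # Plugging sigma_hat back in: -n/2 * (log(2 pi) + 2 log sigma_hat + 1)
    llh = -0.5 * n * (jnp.log(2 * jnp.pi) + 2 * log_sigma_hat + 1.0)
    return llh, log_sigma_hat


mu_grid = jnp.linspace(0.8, 1.4, 7)
llh_curve, log_sigma_curve = jax.vmap(profile_closed_form, in_axes=(0, None))(mu_grid, data)
```

The curve peaks at the grid point nearest the sample mean (about 1.11), since the mean squared deviation is smallest there.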
Notes
- The profile likelihood function is JIT-compiled for performance
- Uses L-BFGS optimization by default, which is well-suited to smooth likelihood surfaces
- Returns convergence information (convergence_diff, num_iterations) for diagnostics
- Splits and recombines inference and nuisance parameters automatically according to is_nuisance
- Consider using a log-parameterization for positive parameters (e.g., \(\log \sigma\)) so the optimization is unconstrained
- The optimizer finds a local maximum, so results may be unreliable if the likelihood surface has multiple local maxima; in that case, choose get_initial_nuisance so that the initial guesses lie close to the global maximum.
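One way an is_nuisance mask can be applied to split and recombine the parameter vector is sketched below. make_split_merge is a hypothetical helper for illustration; the library's internal handling may differ.

```python
import numpy as np
import jax.numpy as jnp


def make_split_merge(is_nuisance):
    """Hypothetical helpers for applying an is_nuisance mask.

    Index arrays are computed once with NumPy, so split/merge only use
    static integer indices and stay usable inside jax.jit.
    """
    mask = np.asarray(is_nuisance, dtype=bool)
    nuisance_idx = np.flatnonzero(mask)
    inference_idx = np.flatnonzero(~mask)

    def split(params):
        # full parameter vector -> (inference values, nuisance values)
        return params[inference_idx], params[nuisance_idx]

    def merge(inference, nuisance):
        # (inference values, nuisance values) -> full parameter vector
        params = jnp.zeros(mask.size, dtype=jnp.result_type(inference, nuisance))
        return params.at[inference_idx].set(inference).at[nuisance_idx].set(nuisance)

    return split, merge


# Same layout as the example: params = [mu, log_sigma], log_sigma is nuisance
split, merge = make_split_merge([False, True])
params = merge(jnp.array([1.0]), jnp.array([-1.2]))
inference, nuisance = split(params)
```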
chi2
Chi-squared distribution functions.
This module provides a complete interface for chi-squared distribution computations, combining re-exports of JAX's standard statistical functions (pdf, cdf, etc.) with a custom high-precision percent point function (ppf).
All functions support the location-scale parameterization and are fully compatible with JAX transformations (JIT, grad, vmap).
ppf
Chi-squared percent point function (inverse CDF).
Overview
Computes the percent point function (quantile function) of the chi-squared distribution. This is the inverse of the cumulative distribution function, finding \(x\) such that \(P(X \leq x) = q\) for a chi-squared random variable \(X\) with \(\text{df}\) degrees of freedom.
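Mathematical Background
The chi-squared distribution with \(\text{df}\) degrees of freedom is a special case of the gamma distribution (shape \(\text{df}/2\), scale 2):
\[X \sim \chi^2(\text{df}) \iff X \sim \text{Gamma}\!\left(\tfrac{\text{df}}{2},\, 2\right)\]
For the location-scale family:
\[X = \text{loc} + \text{scale} \cdot Y, \qquad Y \sim \chi^2(\text{df})\]
The percent point function is computed as:
\[x = \text{loc} + 2\,\text{scale} \cdot P^{-1}\!\left(\tfrac{\text{df}}{2},\, q\right)\]
where \(P^{-1}(a, \cdot)\) is the inverse of the regularized lower incomplete gamma function \(P(a, x)\) in its second argument (gammap_inverse).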
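The gamma relationship gives a simple reference quantile: bisect on the CDF \(P(X \leq x) = P(\text{df}/2, x/2)\), evaluated with jax.scipy.special.gammainc. This is a slow, swapped-in check (bisection), not the custom high-precision method numerax uses, and chi2_ppf_bisect is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc


def chi2_ppf_bisect(q, df, loc=0.0, scale=1.0, n_iter=100):
    """Reference chi-squared quantile by bisection (hypothetical helper).

    Solves gammainc(df / 2, x / 2) = q for the standard chi2(df), then
    applies the location-scale transform x -> loc + scale * x.
    """
    q, df = jnp.broadcast_arrays(
        jnp.asarray(q, dtype=jnp.float32), jnp.asarray(df, dtype=jnp.float32)
    )
    lo = jnp.zeros_like(q)
    # Generous upper bracket: the mean plus many standard deviations
    hi = df + 20.0 * jnp.sqrt(2.0 * df) + 50.0

    def halve(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        below = gammainc(0.5 * df, 0.5 * mid) < q  # CDF(mid) < q: root lies above mid
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    lo, hi = jax.lax.fori_loop(0, n_iter, halve, (lo, hi))
    return loc + scale * 0.5 * (lo + hi)


median_df2 = chi2_ppf_bisect(0.5, 2.0)  # exact value: 2 ln 2
quantiles = chi2_ppf_bisect(jnp.array([0.1, 0.9]), 2.0)
shifted = chi2_ppf_bisect(0.5, 2.0, loc=1.0, scale=2.0)
```

For df = 2 the exact quantile is \(-2\log(1-q)\), so these values are easy to check by hand.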
Args
- q: Probability values in \([0, 1]\). Can be scalar or array.
- df: Degrees of freedom (must be positive). Can be scalar or array.
- loc: Location parameter (default: 0). Can be scalar or array.
- scale: Scale parameter (must be positive, default: 1). Can be scalar or array.
Returns
Quantiles \(x\) where \(P(X \leq x) = q\). Shape follows JAX broadcasting rules.
Example
import jax
import jax.numpy as jnp
import numerax
# Single quantile
x = numerax.stats.chi2.ppf(0.5, df=2) # Median of χ²(2)
# Multiple quantiles
q_vals = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
x_vals = numerax.stats.chi2.ppf(q_vals, df=3)
# Location-scale family
x_scaled = numerax.stats.chi2.ppf(0.5, df=2, loc=1, scale=2)
# Differentiable for optimization
grad_fn = jax.grad(numerax.stats.chi2.ppf)
sensitivity = grad_fn(0.5, 2.0) # ∂x/∂q at median
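The sensitivity in the example is \(\partial x / \partial q\), which by the inverse function theorem equals \(1 / f(x)\) at the quantile, where \(f\) is the density. For df = 2 everything is closed-form, so the sketch below (ppf_df2 and pdf_df2 are hypothetical helpers, not numerax functions) shows that autodiff and the inverse function theorem agree on the exact value 4, a reference for grad_fn(0.5, 2.0).

```python
import jax
import jax.numpy as jnp


def ppf_df2(q):
    # chi2(2) has CDF 1 - exp(-x / 2), so its quantile is -2 log(1 - q)
    return -2.0 * jnp.log1p(-q)


def pdf_df2(x):
    # chi2(2) density
    return 0.5 * jnp.exp(-0.5 * x)


q = 0.5
x_med = ppf_df2(q)  # 2 ln 2, about 1.386
dx_dq_autodiff = jax.grad(ppf_df2)(q)
dx_dq_inverse = 1.0 / pdf_df2(x_med)  # inverse function theorem: dx/dq = 1 / f(x)
```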
Notes
- Differentiable: automatic differentiation through gammap_inverse
- Broadcasting: supports JAX array broadcasting for all parameters
- Performance: compatible with JIT compilation
ncx2
Non-central chi-squared distribution functions.
This module provides the probability density function (pdf) and its logarithm (logpdf) for the non-central chi-squared distribution, built on the numerically stable scaled modified Bessel function ive. The signatures mirror scipy.stats.ncx2 (including loc and scale), and all functions are compatible with JAX transformations (JIT, grad, vmap).
logpdf
logpdf(x: ArrayLike, df: ArrayLike, nc: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ArrayLike
Non-central chi-squared log probability density function.
Overview
Computes the natural logarithm of the probability density function of the non-central chi-squared distribution with \(\text{df}\) degrees of freedom and non-centrality parameter \(\lambda\) (nc), in the location-scale family.
Mathematical Background
The non-central chi-squared distribution is the distribution of the sum of squares of \(k\) independent unit-variance normal random variables with means \(\mu_1, \ldots, \mu_k\):
\[X = \sum_{i=1}^{k} Z_i^2, \qquad Z_i \sim \mathcal{N}(\mu_i, 1),\]
with degrees of freedom \(k = \text{df}\) and non-centrality parameter
\(\lambda = \sum_{i=1}^{k} \mu_i^2\). Note that nc is \(\lambda\)
itself (matching scipy.stats.ncx2), not \(\lambda^2\).
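A quick Monte Carlo check of this definition, and of the convention that nc is \(\lambda = \sum_i \mu_i^2\) rather than its square, uses the standard moments \(E[X] = k + \lambda\) and \(\text{Var}[X] = 2(k + 2\lambda)\). The means, sample size, and seed below are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
means = jnp.array([1.0, 1.0, 0.0])  # mu_i, so k = 3
k = means.shape[0]
nc = jnp.sum(means**2)  # lambda = sum of mu_i^2 = 2 (not squared again)

# Z_i ~ N(mu_i, 1) independently; X = sum_i Z_i^2 is non-central chi-squared
z = jax.random.normal(key, (200_000, k)) + means
x = jnp.sum(z**2, axis=1)

sample_mean, sample_var = x.mean(), x.var()
expected_mean = k + nc  # E[X] = k + lambda = 5
expected_var = 2 * (k + 2 * nc)  # Var[X] = 2 (k + 2 lambda) = 14
```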
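The standard (\(\text{loc}=0\), \(\text{scale}=1\)) density is
\[f_{\text{std}}(x; k, \lambda) = \tfrac{1}{2}\, e^{-(x+\lambda)/2} \left(\tfrac{x}{\lambda}\right)^{k/4 - 1/2} I_{k/2 - 1}\!\left(\sqrt{\lambda x}\right), \qquad x > 0,\]
where \(I_{\nu}\) is the modified Bessel function of the first kind. Using \(I_{\nu}(z) = e^{z}\,\mathtt{ive}(\nu, z)\) and combining the exponentials via \(-(x+\lambda)/2 + \sqrt{\lambda x} = -\tfrac12(\sqrt{x} - \sqrt{\lambda})^2\) gives the numerically stable form actually evaluated:
\[f_{\text{std}}(x; k, \lambda) = \tfrac{1}{2} \left(\tfrac{x}{\lambda}\right)^{k/4 - 1/2} \exp\!\left(-\tfrac{1}{2}\left(\sqrt{x} - \sqrt{\lambda}\right)^2\right) \mathtt{ive}\!\left(\tfrac{k}{2} - 1, \sqrt{\lambda x}\right)\]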
The location-scale family follows the usual convention \(f(x; \text{loc}, \text{scale}) = \tfrac{1}{\text{scale}} f_{\text{std}}\!\left( \tfrac{x - \text{loc}}{\text{scale}}\right)\).
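The stable form can be checked against the standard Poisson-mixture representation \(f = \sum_j \text{Pois}(j; \lambda/2)\, f_{\chi^2_{k+2j}}\). JAX ships only the order-0 and order-1 scaled Bessel functions (i0e, i1e), so this sketch assumes df = 4, where the order is \(k/2 - 1 = 1\); both helper names are hypothetical and this is not numerax's implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, i1e, logsumexp


def ncx2_logpdf_df4(x, nc):
    """Stable-form log density for df = 4 (hypothetical helper).

    With df = 4 the Bessel order is nu = df / 2 - 1 = 1, and the scaled
    Bessel function ive(1, z) is available in JAX as i1e(z).
    """
    df = 4.0
    return (
        -jnp.log(2.0)
        + (df / 4.0 - 0.5) * jnp.log(x / nc)
        - 0.5 * (jnp.sqrt(x) - jnp.sqrt(nc)) ** 2
        + jnp.log(i1e(jnp.sqrt(nc * x)))
    )


def ncx2_logpdf_series(x, df, nc, n_terms=100):
    """Independent check: Poisson(nc / 2) mixture of central chi2(df + 2j)."""
    j = jnp.arange(n_terms, dtype=jnp.float32)
    log_pois = -0.5 * nc + j * jnp.log(0.5 * nc) - gammaln(j + 1.0)
    m = df + 2.0 * j  # degrees of freedom of each mixture component
    log_chi2 = (
        (0.5 * m - 1.0) * jnp.log(x) - 0.5 * x
        - 0.5 * m * jnp.log(2.0) - gammaln(0.5 * m)
    )
    return logsumexp(log_pois + log_chi2)


lp_stable = ncx2_logpdf_df4(4.0, 2.0)
lp_series = ncx2_logpdf_series(4.0, 4.0, 2.0)
```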
Args
- x: Quantile values. Scalar or array.
- df: Degrees of freedom (must be positive). Scalar or array.
- nc: Non-centrality parameter \(\lambda \ge 0\). Scalar or array. nc == 0 reduces to the central chi-squared distribution.
- loc: Location parameter (default: 0). Scalar or array.
- scale: Scale parameter (must be positive, default: 1). Scalar or array.
Returns
Log probability density values. Shape follows JAX broadcasting rules.
Example
import jax
import jax.numpy as jnp
import numerax
# Single value
lp = numerax.stats.ncx2.logpdf(4.0, df=3.0, nc=2.0)
# Vectorized
x_vals = jnp.array([0.5, 1.0, 2.0, 5.0])
lps = numerax.stats.ncx2.logpdf(x_vals, df=3.0, nc=2.0)
# Differentiable in x and nc
grad_fn = jax.grad(numerax.stats.ncx2.logpdf, argnums=2)
sensitivity = grad_fn(4.0, 3.0, 2.0) # d logpdf / d nc
Notes
- Differentiable: w.r.t. x, nc, loc, and scale. Not differentiable w.r.t. df: df feeds the order of the underlying ive, which has no order-derivative, so differentiating w.r.t. df raises TypeError rather than returning a silently wrong value.
- Gradient at \(\text{nc} = 0\): the gradient w.r.t. nc is correct for every \(\text{nc} > 0\), but at exactly \(\text{nc} = 0\) it returns 0 rather than the true one-sided score \(\tfrac{1}{2}\left(\tfrac{(x-\text{loc})/\text{scale}}{\text{df}} - 1\right)\). \(\text{nc} = 0\) is the boundary of the parameter domain (where the derivative is one-sided and the closed form has a removable singularity), so this measure-zero point is not special-cased. Differentiate at a small \(\text{nc} > 0\) if the score at the null is needed.
- Broadcasting: supports JAX array broadcasting for all parameters.
- Accuracy: inherits the ~1e-6 relative accuracy of ive.
- Support boundary: for \(x \le \text{loc}\) the density is taken to be zero (logpdf returns -inf), matching scipy.stats.ncx2.pdf. At the single point \(x = \text{loc}\) the mathematical limit can diverge for \(\text{df} < 2\); that measure-zero point is reported as zero density here.
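The one-sided score quoted in the gradient note follows from the standard Poisson-mixture representation of the non-central chi-squared density, where \(f_m\) denotes the central chi-squared density with \(m\) degrees of freedom. A short derivation for the standard (\(\text{loc}=0\), \(\text{scale}=1\)) case:

```latex
f_{\text{std}}(x; k, \lambda)
  = \sum_{j=0}^{\infty} \frac{e^{-\lambda/2} (\lambda/2)^j}{j!}\, f_{k+2j}(x)
\\
\left.\frac{\partial f_{\text{std}}}{\partial \lambda}\right|_{\lambda = 0^+}
  = -\tfrac{1}{2} f_k(x) + \tfrac{1}{2} f_{k+2}(x)
  \qquad \text{(only the } j = 0, 1 \text{ terms survive)}
\\
\left.\frac{\partial \log f_{\text{std}}}{\partial \lambda}\right|_{\lambda = 0^+}
  = \tfrac{1}{2}\left(\frac{f_{k+2}(x)}{f_k(x)} - 1\right)
  = \tfrac{1}{2}\left(\frac{x}{k} - 1\right)
```

The last step uses \(f_{k+2}(x)/f_k(x) = x/k\). Substituting \(k = \text{df}\) and \(x \to (x - \text{loc})/\text{scale}\) gives the expression in the note; the location-scale factor \(1/\text{scale}\) does not depend on \(\lambda\), so it drops out of the score.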
pdf
pdf(x: ArrayLike, df: ArrayLike, nc: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ArrayLike
Non-central chi-squared probability density function.
Computes the probability density function of the non-central chi-squared distribution with \(\text{df}\) degrees of freedom and non-centrality parameter \(\lambda\) (nc), in the location-scale family. This is exp(logpdf(...)); see logpdf for the full mathematical background, conventions, and differentiability notes.
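Computing the density as the exponential of the stable log form matters in float32. The sketch below assumes df = 4, so that jax.scipy.special.i1 and i1e supply the order-1 Bessel functions; it is an illustration, not numerax code. For x = 200 and nc = 50, the naive product underflows \(e^{-125}\) to 0 and overflows \(I_1(100)\) to inf, giving nan, while the combined exponent \(-\tfrac12(\sqrt{x} - \sqrt{\text{nc}})^2 = -25\) stays representable.

```python
import jax.numpy as jnp
from jax.scipy.special import i1, i1e

# df = 4, so the Bessel order is 1; explicit float32 inputs make the overflow visible
x = jnp.float32(200.0)
nc = jnp.float32(50.0)
z = jnp.sqrt(nc * x)  # = 100

# Naive density: exp(-(x + nc) / 2) underflows to 0 and I_1(100) overflows to inf
naive = 0.5 * jnp.exp(-(x + nc) / 2) * jnp.sqrt(x / nc) * i1(z)

# Stable form: combine the exponentials first, then use the scaled Bessel i1e
stable = jnp.exp(
    -jnp.log(2.0)
    + 0.5 * jnp.log(x / nc)
    - 0.5 * (jnp.sqrt(x) - jnp.sqrt(nc)) ** 2
    + jnp.log(i1e(z))
)
```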
Args
- x: Quantile values. Scalar or array.
- df: Degrees of freedom (must be positive). Scalar or array.
- nc: Non-centrality parameter \(\lambda \ge 0\). Scalar or array.
- loc: Location parameter (default: 0). Scalar or array.
- scale: Scale parameter (must be positive, default: 1). Scalar or array.
Returns
Probability density values. Shape follows JAX broadcasting rules.
Example
import jax.numpy as jnp
import numerax
# Single value
p = numerax.stats.ncx2.pdf(4.0, df=3.0, nc=2.0)
# Vectorized
x_vals = jnp.array([0.5, 1.0, 2.0, 5.0])
ps = numerax.stats.ncx2.pdf(x_vals, df=3.0, nc=2.0)
Notes
- Differentiable: w.r.t. x, nc, loc, and scale; differentiating w.r.t. df raises TypeError. The gradient w.r.t. nc returns 0 at exactly nc = 0 (a domain boundary). See logpdf for details.
- Accuracy: inherits the ~1e-6 relative accuracy of ive.