stats
Statistics functions and distributions for SCRIBE.
- scribe.stats.compute_histogram_percentiles(samples, percentiles=[5, 25, 50, 75, 95], normalize=True, sample_axis=0)[source]
Compute percentiles of histogram frequencies across multiple samples.
- Parameters:
samples (array-like) – Array of shape (n_samples, n_points) by default, or (n_points, n_samples) if sample_axis=1
percentiles (list-like, optional) – List of percentiles to compute (default: [5, 25, 50, 75, 95])
normalize (bool, optional) – Whether to normalize histograms (default: True)
sample_axis (int, optional) – Axis containing samples (default: 0)
- Returns:
bin_edges (array) – Array of bin edges (integers from min to max value + 1)
hist_percentiles (array) – Array of shape (len(percentiles), len(bin_edges)-1) containing the percentiles of histogram frequencies for each bin
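The shape contract above can be illustrated with a plain NumPy sketch (not SCRIBE's implementation, which operates on JAX arrays): every sample is binned on shared integer edges from the global minimum to the maximum value + 1, frequencies are optionally normalized per sample, and percentiles are taken across samples for each bin.

```python
import numpy as np

def histogram_percentiles(samples, percentiles=(5, 25, 50, 75, 95), normalize=True):
    """Illustrative NumPy version of the computation described above."""
    samples = np.asarray(samples)
    # Shared integer bin edges from the global min to the max value + 1
    bin_edges = np.arange(samples.min(), samples.max() + 2)
    # One histogram per sample, all on the same bins
    hists = np.stack([np.histogram(s, bins=bin_edges)[0] for s in samples])
    if normalize:
        hists = hists / hists.sum(axis=1, keepdims=True)
    # Percentiles of the per-bin frequencies, taken across samples
    return bin_edges, np.percentile(hists, percentiles, axis=0)
```

The returned array has shape (len(percentiles), len(bin_edges) - 1), matching the documented contract.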
- scribe.stats.compute_histogram_credible_regions(samples, credible_regions=[95, 68, 50], normalize=True, sample_axis=0, batch_size=1000, max_bin=None)[source]
Compute credible regions of histogram frequencies across multiple samples.
- Parameters:
samples (array-like) – Array of shape (n_samples, n_points) by default, or (n_points, n_samples) if sample_axis=1
credible_regions (list-like, optional) – List of credible region percentages to compute (default: [95, 68, 50]). For example, 95 computes the 2.5th and 97.5th percentiles
normalize (bool, optional) – Whether to normalize histograms (default: True)
sample_axis (int, optional) – Axis containing samples (default: 0)
batch_size (int, optional) – Number of samples to process in each batch (default: 1000)
max_bin (int, optional) – Maximum number of bins to process (default: None)
- Returns:
Dictionary containing:
‘bin_edges’: array of bin edges
‘regions’: nested dictionary keyed by credible region percentage, where each value is a dictionary containing:
’lower’: lower bound of the credible region
’upper’: upper bound of the credible region
’median’: median (50th percentile)
- Return type:
dict
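The mapping from a credible-region percentage to a percentile pair, and the nested shape of the ‘regions’ dictionary, can be sketched as follows (an illustrative NumPy stand-in, not the SCRIBE internals; `credible_region_bounds` and `freqs` are hypothetical names):

```python
import numpy as np

def credible_region_bounds(freqs, credible_regions=(95, 68, 50)):
    """Illustrative: per-bin credible bounds from (n_samples, n_bins) frequencies."""
    regions = {}
    for cr in credible_regions:
        lo = (100 - cr) / 2  # e.g. a 95% region uses the 2.5th percentile
        regions[cr] = {
            "lower": np.percentile(freqs, lo, axis=0),
            "upper": np.percentile(freqs, 100 - lo, axis=0),
            "median": np.percentile(freqs, 50, axis=0),
        }
    return regions
```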
- scribe.stats.compute_ecdf_percentiles(samples, percentiles=[5, 25, 50, 75, 95], sample_axis=0)[source]
Compute percentiles of ECDF values across multiple samples of integers.
- Parameters:
samples (array-like) – Array of shape (n_samples, n_points) by default, or (n_points, n_samples) if sample_axis=1, containing raw data samples of positive integers
percentiles (list-like, optional) – List of percentiles to compute (default: [5, 25, 50, 75, 95])
sample_axis (int, optional) – Axis containing samples (default: 0)
- Returns:
bin_edges (array) – Array of integer points at which ECDFs were evaluated (from min to max)
ecdf_percentiles (array) – Array of shape (len(percentiles), len(bin_edges)) containing the percentiles of ECDF values at each integer point
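A minimal NumPy sketch of this computation (again not the SCRIBE implementation): each sample's ECDF is evaluated at every integer point between the global min and max, then percentiles are taken across samples at each point.

```python
import numpy as np

def ecdf_percentiles(samples, percentiles=(5, 25, 50, 75, 95)):
    """Illustrative: percentiles of per-sample ECDFs at each integer point."""
    samples = np.asarray(samples)
    points = np.arange(samples.min(), samples.max() + 1)
    # ECDF of each sample row at every integer point, via broadcasting:
    # (n_samples, 1, n_points) <= (1, n_eval, 1) -> mean over the data axis
    ecdfs = (samples[:, None, :] <= points[None, :, None]).mean(axis=-1)
    return points, np.percentile(ecdfs, percentiles, axis=0)
```

Note that every ECDF reaches exactly 1 at the largest evaluation point, so the output has shape (len(percentiles), len(points)) with values in [0, 1].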
- scribe.stats.compute_ecdf_credible_regions(samples, credible_regions=[95, 68, 50], sample_axis=0, batch_size=1000, max_bin=None)[source]
Compute credible regions of ECDF values across multiple samples.
- Parameters:
samples (array-like) – Array of shape (n_samples, n_points) by default, or (n_points, n_samples) if sample_axis=1, containing raw data samples
credible_regions (list-like, optional) – List of credible region percentages to compute (default: [95, 68, 50]). For example, 95 computes the 2.5th and 97.5th percentiles
sample_axis (int, optional) – Axis containing samples (default: 0)
batch_size (int, optional) – Number of samples to process in each batch (default: 1000)
max_bin (int, optional) – Maximum value to include in ECDF evaluation (default: None)
- Returns:
Dictionary containing:
’bin_edges’: array of points at which ECDFs were evaluated
’regions’: nested dictionary keyed by credible region percentage, where each value is a dictionary containing:
’lower’: lower bound of the credible region
’upper’: upper bound of the credible region
’median’: median (50th percentile)
- Return type:
dict
- scribe.stats.sample_dirichlet_from_parameters(parameter_samples, n_samples_dirichlet=1, rng_key=None)[source]
Samples from a Dirichlet distribution given an array of parameter samples.
- Parameters:
parameter_samples (array-like) – Array of shape (n_samples, n_variables) containing parameter samples to use as concentration parameters for Dirichlet distributions
n_samples_dirichlet (int, optional) – Number of samples to draw from each Dirichlet distribution (default: 1)
rng_key (random.PRNGKey, optional) – JAX random number generator key. Defaults to random.PRNGKey(42) if None
- Returns:
- If n_samples_dirichlet=1:
Array of shape (n_samples, n_variables)
- If n_samples_dirichlet>1:
Array of shape (n_samples, n_variables, n_samples_dirichlet)
- Return type:
jnp.ndarray
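The n_samples_dirichlet=1 case can be mimicked with NumPy (a sketch of the documented behavior, not SCRIBE's JAX code): each row of the parameter array acts as the concentration vector for one Dirichlet draw.

```python
import numpy as np

rng = np.random.default_rng(0)
# 10 parameter draws, each a concentration vector for 3 variables
parameter_samples = rng.uniform(0.5, 5.0, size=(10, 3))
# One Dirichlet draw per parameter row (the n_samples_dirichlet=1 case),
# giving an output of shape (n_samples, n_variables)
draws = np.stack([rng.dirichlet(alpha) for alpha in parameter_samples])
```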
- scribe.stats.fit_dirichlet_mle(samples, max_iter=1000, tol=1e-07, sample_axis=0)[source]
Fit a Dirichlet distribution to samples using Maximum Likelihood Estimation.
This implementation uses Newton’s method to find the concentration parameters that maximize the likelihood of the observed samples. The algorithm iteratively updates the concentration parameters using gradient and Hessian information until convergence.
- Parameters:
samples (array-like) – Array of shape (n_samples, n_variables) by default, or (n_variables, n_samples) if sample_axis=1, containing Dirichlet samples. Each row/column should sum to 1.
max_iter (int, optional) – Maximum number of iterations for optimization (default: 1000)
tol (float, optional) – Tolerance for convergence in parameter updates (default: 1e-7)
sample_axis (int, optional) – Axis containing samples (default: 0)
- Returns:
Array of concentration parameters for the fitted Dirichlet distribution. Shape is (n_variables,).
- Return type:
jnp.ndarray
- scribe.stats.fit_dirichlet_minka(samples, max_iter=1000, tol=1e-07, sample_axis=0)[source]
Fit a Dirichlet distribution to data using Minka’s fixed-point iteration.
- This function uses the relation:
ψ(α_j) - ψ(α₀) = ⟨ln x_j⟩ (with α₀ = ∑ₖ αₖ)
- so that the fixed point update is:
α_j ← ψ⁻¹( ψ(α₀) + ⟨ln x_j⟩ )
This method is generally more stable and faster than moment matching or maximum likelihood estimation via gradient descent.
- Parameters:
samples (array-like) – Data array with shape (n_samples, n_variables) by default (or transposed if sample_axis=1). Each row should sum to 1 (i.e., be a probability vector).
max_iter (int, optional) – Maximum number of iterations for the fixed-point algorithm.
tol (float, optional) – Tolerance for convergence - algorithm stops when max change in α is below this.
sample_axis (int, optional) – Axis containing samples (default: 0). Use 1 if data is (n_variables, n_samples).
- Returns:
Estimated concentration parameters (α) of shape (n_variables,).
- Return type:
jnp.ndarray
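Minka's fixed-point update described above can be sketched in NumPy/SciPy (an illustrative re-implementation under the stated equations, not SCRIBE's JAX code; `inv_digamma` approximates ψ⁻¹ with Minka's initialization plus Newton refinement):

```python
import numpy as np
from scipy.special import digamma, polygamma

def inv_digamma(y, n_newton=5):
    # Minka's initialization for the inverse digamma, refined by Newton steps
    x = np.where(y >= -2.22, np.exp(y) + 0.5, -1.0 / (y - digamma(1.0)))
    for _ in range(n_newton):
        x = x - (digamma(x) - y) / polygamma(1, x)
    return x

def fit_dirichlet_minka_np(samples, max_iter=1000, tol=1e-7):
    """Illustrative NumPy version of the fixed-point iteration described above."""
    log_x_bar = np.log(samples).mean(axis=0)  # the sufficient statistic <ln x_j>
    alpha = np.ones(samples.shape[1])         # simple uniform initialization
    for _ in range(max_iter):
        # Fixed-point update: alpha_j <- psi^-1( psi(alpha_0) + <ln x_j> )
        new_alpha = inv_digamma(digamma(alpha.sum()) + log_x_bar)
        if np.max(np.abs(new_alpha - alpha)) < tol:
            return new_alpha
        alpha = new_alpha
    return alpha
```

On synthetic Dirichlet data the iteration typically converges in a handful of steps and recovers the generating concentrations to within sampling error.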
- class scribe.stats.BetaPrime(concentration1, concentration0, validate_args=None)[source]
Bases: Distribution
Beta Prime distribution (odds-of-Beta convention).
Convention
If p ~ Beta(α, β) and φ = (1 - p) / p (odds of “success” 1 - p), then φ ~ BetaPrime(α, β) in THIS CLASS.
Implementation detail
- Mathematically, φ has the standard Beta-prime with swapped parameters:
φ ~ BetaPrime_std(β, α).
This class accepts (α, β) at the call site and internally uses (β, α), so that your models can pass (α, β) unchanged. This is necessary because the NumPyro NegativeBinomial distribution expects the probs parameter to be the failure probability p, so that the odds ratio φ = (1 - p) / p is consistent with the parameterization of the BetaPrime distribution.
Density (with user parameters α, β)
f(φ; α, β) = φ^(β - 1) * (1 + φ)^(-(α + β)) / B(β, α), φ > 0
Note the Beta function arguments B(β, α).
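The stated convention can be checked by Monte Carlo against SciPy's standard Beta-prime parameterization (a sketch, independent of this class: `scipy.stats.betaprime(a, b)` is the standard form, so the swapped-parameter claim means φ should match `betaprime(β, α)`):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
alpha, beta = 5.0, 3.0             # (α, β) as passed to this class
p = rng.beta(alpha, beta, size=200_000)
phi = (1.0 - p) / p                # odds of 1 - p, as in the convention above
# Per the implementation detail, phi follows the STANDARD Beta-prime with
# swapped parameters, i.e. BetaPrime_std(beta, alpha)
ref = stats.betaprime(beta, alpha)
```

Summary statistics of `phi` should agree with `ref` up to Monte Carlo error.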
- param concentration1:
α (matches the Beta prior’s first shape)
- type concentration1:
jnp.ndarray
- param concentration0:
β (matches the Beta prior’s second shape)
- type concentration0:
jnp.ndarray
- arg_constraints: dict[str, Any] = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0)}
- support = Positive(lower_bound=0.0)
- has_rsample = False
- sample(key, sample_shape=())[source]
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters:
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- Return type:
jnp.ndarray
- log_prob(value)[source]
Evaluates the log probability density for a batch of samples given by value.
- Parameters:
value – A batch of samples from the distribution.
- Returns:
an array with shape value.shape[:-self.event_shape]
- Return type:
ArrayLike
- property mean
Mean of the distribution.
- property variance
Variance of the distribution.
- property mode
- property concentration1
Access to concentration1 parameter (α) for NumPyro compatibility.
- property concentration0
Access to concentration0 parameter (β) for NumPyro compatibility.
- class scribe.stats.LowRankLogisticNormal(loc, cov_factor, cov_diag, validate_args=None)[source]
Bases: Distribution
Low-rank Logistic-Normal distribution for compositional data.
This distribution models D-dimensional probability vectors (on the simplex) using a (D-1)-dimensional low-rank multivariate normal distribution in log-ratio space. It uses the Additive Log-Ratio (ALR) transformation.
Mathematical Definition
Let y ∈ ℝ^(D-1) ~ MVN(μ, Σ) where Σ = WW^T + diag(D) is low-rank. The ALR transformation maps y to the simplex Δ^D:
xᵢ = exp(yᵢ) / (1 + Σⱼ exp(yⱼ)) for i = 1, …, D-1
x_D = 1 / (1 + Σⱼ exp(yⱼ))
The inverse transformation (simplex to log-ratio space) is:
yᵢ = log(xᵢ / x_D) for i = 1, …, D-1
The Jacobian of the transformation has log-determinant:
log|det(J)| = -Σᵢ₌₁^D log(xᵢ)
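The ALR forward and inverse maps above can be sketched directly in NumPy (illustrative helper functions, not part of the class API):

```python
import numpy as np

def alr_forward(y):
    """Map y in R^(D-1) to the simplex (the last component is the reference)."""
    e = np.exp(y)
    denom = 1.0 + e.sum()
    return np.concatenate([e / denom, [1.0 / denom]])

def alr_inverse(x):
    """Map a simplex point back to log-ratio space: y_i = log(x_i / x_D)."""
    return np.log(x[:-1] / x[-1])

y = np.array([0.3, -1.2, 0.7, 0.0])  # D - 1 = 4, so D = 5
x = alr_forward(y)
```

The round trip `alr_inverse(alr_forward(y))` recovers `y`, which is exactly why this transformation admits a log_prob while softmax does not.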
Low-Rank Covariance Structure
The covariance matrix has the form:
Σ = WW^T + diag(D)
- where:
W is a (D-1) × rank factor matrix
D is a (D-1)-dimensional diagonal vector
Memory: O((D-1) × rank) vs O((D-1)²) for full covariance
This is critical for large D (e.g., 30K+ genes) where storing full covariance is prohibitive.
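The memory savings come from never materializing Σ: a draw from MVN(μ, WWᵀ + diag(d)) only needs matrix-vector products, as in this NumPy sketch (illustrative, with small sizes; SCRIBE delegates this to the low-rank MVN machinery):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank = 6, 2                        # illustrative sizes
W = rng.normal(size=(dim, rank)) * 0.3  # low-rank factor
d = np.full(dim, 0.5)                   # diagonal component
mu = np.zeros(dim)

# Draw y = mu + W @ eps1 + sqrt(d) * eps2 without forming the full covariance;
# by construction Cov(y) = W @ W.T + diag(d)
n = 50_000
eps1 = rng.normal(size=(rank, n))
eps2 = rng.normal(size=(dim, n))
y = mu[:, None] + W @ eps1 + np.sqrt(d)[:, None] * eps2
```

The empirical covariance of `y` matches WWᵀ + diag(d) up to sampling noise, at O(dim × rank) storage.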
Asymmetry and Reference Component
The ALR transformation treats the last component (x_D) as a reference. This means the distribution is NOT symmetric under permutation of components. If you need symmetry, use SoftmaxNormal instead (but note that SoftmaxNormal cannot compute log_prob).
- param loc:
Location parameter μ ∈ ℝ^(D-1) (mean in log-ratio space)
- type loc:
jnp.ndarray
- param cov_factor:
Low-rank factor matrix W of shape (D-1, rank)
- type cov_factor:
jnp.ndarray
- param cov_diag:
Diagonal component D of shape (D-1,)
- type cov_diag:
jnp.ndarray
- param validate_args:
Whether to validate input arguments
- type validate_args:
bool, optional
Examples
>>> from jax import random
>>> import jax.numpy as jnp
>>> # Create a low-rank logistic-normal for 5-dimensional simplex
>>> D = 5
>>> rank = 2
>>> loc = jnp.zeros(D - 1)
>>> cov_factor = jnp.ones((D - 1, rank)) * 0.1
>>> cov_diag = jnp.ones(D - 1) * 0.5
>>> dist = LowRankLogisticNormal(loc, cov_factor, cov_diag)
>>> # Sample from the distribution (returns D-dimensional simplex points)
>>> samples = dist.sample(random.PRNGKey(0), (100,))
>>> samples.shape
(100, 5)
>>> # Samples sum to 1
>>> jnp.allclose(samples.sum(axis=-1), 1.0)
True
>>> # Evaluate log probability
>>> log_p = dist.log_prob(samples[0])
References
Aitchison, J., & Shen, S. M. (1980). Logistic-normal distributions: Some properties and uses. Biometrika, 67(2), 261-272.
Aitchison, J. (1986). The Statistical Analysis of Compositional Data. Chapman & Hall.
See also
SoftmaxNormal
Symmetric alternative using softmax (no log_prob available)
- arg_constraints: dict[str, Any] = {'cov_diag': Positive(lower_bound=0.0), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': RealVector(Real(), 1)}
- support = Simplex()
- has_rsample = False
- sample(key, sample_shape=())[source]
Sample from the distribution.
Returns samples on the D-dimensional simplex.
- class scribe.stats.SoftmaxNormal(loc, cov_factor, cov_diag, validate_args=None)[source]
Bases: Distribution
Softmax-Normal distribution for compositional data (symmetric).
This distribution models D-dimensional probability vectors (on the simplex) using a D-dimensional low-rank multivariate normal distribution with a softmax transformation. Unlike LowRankLogisticNormal (which uses ALR), this treats all components symmetrically.
Mathematical Definition
Let y ∈ ℝ^D ~ MVN(μ, Σ) where Σ = WW^T + diag(D) is low-rank. The softmax transformation maps y to the simplex Δ^D:
xᵢ = exp(yᵢ) / Σⱼ exp(yⱼ) for i = 1, …, D
Symmetry and Invariance
The softmax transformation is:
- Symmetric: All components treated equally (no reference component)
- Translation-invariant: softmax(y + c·1) = softmax(y) for any constant c
This translation invariance means the transformation is SINGULAR - you cannot uniquely invert it. Therefore, log_prob() is not available.
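The singularity is easy to see numerically (a minimal NumPy check, not using this class): shifting every coordinate by the same constant leaves the softmax output unchanged, so the map cannot be inverted uniquely.

```python
import numpy as np

def softmax(y):
    z = np.exp(y - y.max())  # subtract the max for numerical stability
    return z / z.sum()

y = np.array([0.5, -1.0, 2.0])
shifted = softmax(y + 3.7)   # adding any constant leaves the output unchanged
```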
Low-Rank Covariance Structure
The covariance matrix has the form:
Σ = WW^T + diag(D)
- where:
W is a D × rank factor matrix
D is a D-dimensional diagonal vector
Memory: O(D × rank) vs O(D²) for full covariance
When to Use
- Use SoftmaxNormal when:
You want symmetric treatment of all components
You only need sampling (not log_prob evaluation)
You’re summarizing/visualizing posterior distributions
- Use LowRankLogisticNormal when:
You need to evaluate log_prob() for observed data
You’re using the distribution as a likelihood in Bayesian inference
Asymmetry (reference component) is acceptable
- param loc:
Location parameter μ ∈ ℝ^D (mean in log-space)
- type loc:
jnp.ndarray
- param cov_factor:
Low-rank factor matrix W of shape (D, rank)
- type cov_factor:
jnp.ndarray
- param cov_diag:
Diagonal component D of shape (D,)
- type cov_diag:
jnp.ndarray
- param validate_args:
Whether to validate input arguments
- type validate_args:
bool, optional
Examples
>>> import jax
>>> from jax import random
>>> import jax.numpy as jnp
>>> # Create a softmax-normal for 5-dimensional simplex
>>> D = 5
>>> rank = 2
>>> loc = jnp.zeros(D)
>>> cov_factor = jnp.ones((D, rank)) * 0.1
>>> cov_diag = jnp.ones(D) * 0.5
>>> dist = SoftmaxNormal(loc, cov_factor, cov_diag)
>>> # Sample from the distribution
>>> samples = dist.sample(random.PRNGKey(0), (100,))
>>> samples.shape
(100, 5)
>>> # Samples sum to 1
>>> jnp.allclose(samples.sum(axis=-1), 1.0)
True
>>> # Access underlying log-space distribution
>>> log_samples = dist.base_dist.sample(random.PRNGKey(1), (100,))
>>> # Apply softmax manually
>>> manual_samples = jax.nn.softmax(log_samples, axis=-1)
See also
LowRankLogisticNormal
ALR-based alternative with log_prob() available
- arg_constraints: dict[str, Any] = {'cov_diag': Positive(lower_bound=0.0), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': RealVector(Real(), 1)}
- support = Simplex()
- has_rsample = False
- sample(key, sample_shape=())[source]
Sample from the distribution.
Returns samples on the D-dimensional simplex.
- log_prob(value)[source]
Evaluate log probability density.
NOT IMPLEMENTED: The softmax transformation is singular (adding a constant to all log-space coordinates doesn’t change the output), so the Jacobian determinant is zero and log_prob is undefined.
- Parameters:
value (jnp.ndarray) – Points on the simplex
- Raises:
NotImplementedError – Always raised. Use LowRankLogisticNormal if you need log_prob(), or access base_dist.log_prob() for log-space density.