mc

Bayesian model comparison for SCRIBE.

This module provides scalable, fully Bayesian model comparison tools based on out-of-sample predictive accuracy. It implements two complementary criteria:

  • WAIC (Widely Applicable Information Criterion): a fast, analytical approximation to LOO-CV computed entirely from the posterior samples already available after fitting.
  • PSIS-LOO (Pareto-Smoothed Importance Sampling LOO): a more reliable criterion that applies Pareto smoothing to the raw IS weights, with a per-observation diagnostic k̂.
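
Both criteria operate on the same raw material: a matrix of per-observation log-likelihoods evaluated at each posterior sample. A minimal NumPy sketch of the standard WAIC formulas (illustrative only, not the module's JAX-accelerated implementation):

```python
import numpy as np

def waic_sketch(log_lik):
    """WAIC from an (S, C) log-likelihood matrix (samples x observations)."""
    S, C = log_lik.shape
    # lppd: log pointwise predictive density, the log of the posterior-mean likelihood
    lppd_i = np.logaddexp.reduce(log_lik, axis=0) - np.log(S)
    # p_waic_2: posterior variance of the log-likelihood, per observation
    p_waic_i = np.var(log_lik, axis=0, ddof=1)
    elppd = lppd_i.sum() - p_waic_i.sum()
    return {"lppd": lppd_i.sum(), "p_waic_2": p_waic_i.sum(), "waic_2": -2.0 * elppd}

rng = np.random.default_rng(0)
stats = waic_sketch(rng.normal(-3.0, 0.5, size=(200, 50)))
```

Higher elppd (equivalently, lower waic_2) indicates better estimated out-of-sample predictive accuracy.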

In addition, the module provides:

  • Gene-level comparison: per-gene elpd differences between two models, with standard errors and z-scores.
  • Model stacking: optimal predictive ensemble weights via convex optimization of the LOO log-score.
  • Goodness-of-fit diagnostics: per-gene randomized quantile residuals (RQR) that assess how well a single fitted model describes each gene's count distribution, with expression-scale-invariant summary statistics.
  • PPC-based goodness-of-fit: full posterior predictive checks that compare observed histograms to PPC credible bands, producing calibration failure rates and L1 density distances for higher-resolution gene filtering.
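
Randomized quantile residuals map each observed count through its model CDF, randomized uniformly across the CDF jump at the observed value, so that a well-specified model yields standard-normal residuals. A hedged sketch for a single negative-binomial gene (the r, p values here are placeholders, not SCRIBE's fitted parameters):

```python
import numpy as np
from scipy import stats

def rqr_sketch(counts, r, p, rng):
    """Randomized quantile residuals for NB-distributed counts."""
    lo = stats.nbinom.cdf(counts - 1, r, p)  # F(y-1); equals 0 when y == 0
    hi = stats.nbinom.cdf(counts, r, p)      # F(y)
    u = rng.uniform(lo, hi)                  # randomize within the CDF jump
    return stats.norm.ppf(u)                 # ~ N(0, 1) if the model is correct

rng = np.random.default_rng(1)
y = stats.nbinom.rvs(5.0, 0.3, size=2000, random_state=rng)
z = rqr_sketch(y, 5.0, 0.3, rng)
```

Because the residuals are on a standard-normal scale regardless of expression level, summary statistics of z are expression-scale invariant.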
Quick start

>>> from scribe.mc import compare_models
>>> mc = compare_models(
...     [results_nbdm, results_hierarchical],
...     counts=counts,
...     model_names=["NBDM", "Hierarchical"],
...     gene_names=gene_names,
...     compute_gene_liks=True,
... )
>>> print(mc.summary())  # ranked comparison table
>>> print(mc.diagnostics())  # PSIS k̂ diagnostics
>>> mc.rank()  # pandas DataFrame
>>> mc.gene_level_comparison("NBDM", "Hierarchical")  # per-gene DataFrame

Class hierarchy
  • ScribeModelComparisonResults — stores raw log-likelihood matrices and provides lazy-computed WAIC, PSIS-LOO, stacking, and gene-level methods.
Factory
  • compare_models() — accepts a list of fitted results objects, computes log-likelihoods for each model, and returns a ScribeModelComparisonResults.
Low-level functions
  • waic() / compute_waic_stats() — JAX-accelerated WAIC.
  • compute_psis_loo() — NumPy/SciPy PSIS-LOO with Pareto fitting.
  • gene_level_comparison() — per-gene elpd differences.
  • compute_stacking_weights() — stacking weight optimization.
  • compute_quantile_residuals() — randomized quantile residuals.
  • goodness_of_fit_scores() — per-gene fit diagnostics from residuals.
  • compute_gof_mask() — boolean gene mask from fit quality.
  • ppc_goodness_of_fit_scores() — PPC-based per-gene calibration and L1 scoring.
  • compute_ppc_gof_mask() — PPC-based boolean gene mask with gene batching.

See paper/_model_comparison.qmd and paper/_goodness_of_fit.qmd for full mathematical derivations.

ScribeModelComparisonResults dataclass

ScribeModelComparisonResults(model_names, log_liks_cell, log_liks_gene=None, gene_names=None, n_cells=0, n_genes=0, active_components=None, dtype=float64)

Structured results for Bayesian model comparison across K models.

Stores raw posterior log-likelihood matrices for each model and provides methods for computing WAIC, PSIS-LOO, model ranking, and gene-level comparisons. All expensive computations are performed lazily and cached.

PARAMETER DESCRIPTION
model_names

Human-readable names for the K models in the comparison.

TYPE: list of str

log_liks_cell

List of K arrays, each of shape (S, C), containing the per-cell log-likelihoods under each model. S is the number of posterior samples, C is the number of cells.

TYPE: list of jnp.ndarray

log_liks_gene

List of K arrays of shape (S, G), containing per-gene log-likelihoods (summed over cells) for each model. Required for :meth:gene_level_comparison. Can be computed by calling :func:compare_models with compute_gene_liks=True.

TYPE: list of jnp.ndarray DEFAULT: None

gene_names

Names for the G genes. Used in gene-level comparison output.

TYPE: list of str DEFAULT: None

n_cells

Number of cells (observations).

TYPE: int DEFAULT: 0

n_genes

Number of genes.

TYPE: int DEFAULT: 0

active_components

Per-model boolean masks recording which mixture components survived the dead-component pruning step (see :func:compare_models parameter component_threshold). Entry k has shape (K_original_k,) and is set to None when no pruning was applied (non-mixture model or all components active). None for the whole list when pruning was not requested.

TYPE: list of np.ndarray or None DEFAULT: None

dtype

Precision used for PSIS-LOO computations.

TYPE: numpy dtype DEFAULT: np.float64

ATTRIBUTE DESCRIPTION
K

Number of models.

TYPE: int

Examples:

>>> from scribe.mc import compare_models
>>> mc = compare_models(
...     [results_nbdm, results_hierarchical],
...     counts=counts,
...     model_names=["NBDM", "Hierarchical"],
...     gene_names=gene_names,
...     compute_gene_liks=True,
... )
>>> mc.rank()
>>> mc.summary()
>>> mc.gene_level_comparison("NBDM", "Hierarchical")

K property

K

Number of models being compared.

waic

waic(model_idx=None)

Compute WAIC statistics for one or all models.

Results are cached after the first call; repeated calls are free.

PARAMETER DESCRIPTION
model_idx

If provided, return WAIC statistics only for model model_idx. If None (default), return a list of dicts for all K models.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
dict or list of dict

Each dict contains keys: lppd, p_waic_1, p_waic_2, elppd_waic_1, elppd_waic_2, waic_1, waic_2.

Examples:

>>> stats_all = mc.waic()
>>> stats_first = mc.waic(model_idx=0)
Source code in src/scribe/mc/results.py
def waic(self, model_idx: Optional[int] = None) -> Union[dict, List[dict]]:
    """Compute WAIC statistics for one or all models.

    Results are cached after the first call; repeated calls are free.

    Parameters
    ----------
    model_idx : int, optional
        If provided, return WAIC statistics only for model ``model_idx``.
        If ``None`` (default), return a list of dicts for all K models.

    Returns
    -------
    dict or list of dict
        Each dict contains keys: ``lppd``, ``p_waic_1``, ``p_waic_2``,
        ``elppd_waic_1``, ``elppd_waic_2``, ``waic_1``, ``waic_2``.

    Examples
    --------
    >>> stats_all = mc.waic()
    >>> stats_first = mc.waic(model_idx=0)
    """
    if self._waic_cache is None:
        self._waic_cache = [
            {
                k: (float(v) if jnp.ndim(v) == 0 else v)
                for k, v in compute_waic_stats(ll, aggregate=True).items()
            }
            for ll in self.log_liks_cell
        ]
    if model_idx is not None:
        return self._waic_cache[model_idx]
    return self._waic_cache

psis_loo

psis_loo(model_idx=None)

Compute PSIS-LOO statistics for one or all models.

PSIS-LOO is computed using NumPy/SciPy (Pareto fitting is not JIT-compilable). Results are cached after the first call.

PARAMETER DESCRIPTION
model_idx

If provided, return PSIS-LOO statistics only for model model_idx. If None, return a list for all K models.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
dict or list of dict

Each dict contains: elpd_loo, p_loo, looic, elpd_loo_i (per-observation), k_hat (per-observation), lppd, n_bad (number of observations with k̂ ≥ 0.7).

Examples:

>>> loo_all = mc.psis_loo()
>>> print(loo_all[0]["n_bad"])
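
The per-observation k_hat array can be bucketed with the conventional PSIS thresholds; a small sketch (the 0.7 cutoff matches the n_bad count returned here):

```python
import numpy as np

def khat_report(k_hat):
    """Bucket PSIS k̂ values using the conventional reliability thresholds."""
    k_hat = np.asarray(k_hat)
    return {
        "good (k < 0.5)": int((k_hat < 0.5).sum()),
        "ok (0.5 <= k < 0.7)": int(((k_hat >= 0.5) & (k_hat < 0.7)).sum()),
        "bad (k >= 0.7)": int((k_hat >= 0.7).sum()),  # the n_bad count
    }

report = khat_report([0.1, 0.45, 0.62, 0.71, 1.2])
```

Observations in the "bad" bucket have importance weights with unreliable variance, so their elpd_loo_i contributions should be treated with caution.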
Source code in src/scribe/mc/results.py
def psis_loo(
    self, model_idx: Optional[int] = None
) -> Union[dict, List[dict]]:
    """Compute PSIS-LOO statistics for one or all models.

    PSIS-LOO is computed using NumPy/SciPy (Pareto fitting is not JIT-
    compilable).  Results are cached after the first call.

    Parameters
    ----------
    model_idx : int, optional
        If provided, return PSIS-LOO statistics only for model
        ``model_idx``.  If ``None``, return a list for all K models.

    Returns
    -------
    dict or list of dict
        Each dict contains: ``elpd_loo``, ``p_loo``, ``looic``,
        ``elpd_loo_i`` (per-observation), ``k_hat`` (per-observation),
        ``lppd``, ``n_bad`` (number of observations with k̂ ≥ 0.7).

    Examples
    --------
    >>> loo_all = mc.psis_loo()
    >>> print(loo_all[0]["n_bad"])
    """
    if self._psis_loo_cache is None:
        self._psis_loo_cache = [
            compute_psis_loo(np.asarray(ll), dtype=self.dtype)
            for ll in self.log_liks_cell
        ]
    if model_idx is not None:
        return self._psis_loo_cache[model_idx]
    return self._psis_loo_cache

stacking_weights

stacking_weights(n_restarts=5, seed=42)

Compute optimal stacking weights from PSIS-LOO estimates.

The stacking weights maximize the LOO log predictive score of the model ensemble. They are computed once and cached.
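
The objective being maximized is the LOO log-score of the weighted ensemble, sum_i log sum_k w_k exp(elpd_i^k), over the probability simplex. A hedged sketch using a softmax parameterization (an illustration of the objective, not the module's compute_stacking_weights):

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import logsumexp, softmax

def stacking_sketch(elpd_i):
    """Maximize sum_i log(sum_k w_k * exp(elpd_i[k, i])) over the simplex."""
    elpd_i = np.asarray(elpd_i)       # shape (K, n): per-model pointwise elpd
    K = elpd_i.shape[0]

    def neg_score(theta):
        logw = np.log(softmax(theta))  # softmax keeps w on the simplex
        return -logsumexp(elpd_i + logw[:, None], axis=0).sum()

    res = minimize(neg_score, np.zeros(K), method="Nelder-Mead")
    return softmax(res.x)

# Model 0 dominates pointwise, so it should receive nearly all the weight
w = stacking_sketch([[-1.0] * 100, [-3.0] * 100])
```

Because the log-score is concave in w, random restarts (the n_restarts parameter) guard against numerical stalls rather than distinct local optima.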

PARAMETER DESCRIPTION
n_restarts

Number of random restarts for the convex optimization.

TYPE: int DEFAULT: 5

seed

Random seed.

TYPE: int DEFAULT: 42

RETURNS DESCRIPTION
np.ndarray, shape (K,)

Optimal stacking weights summing to 1.

Source code in src/scribe/mc/results.py
def stacking_weights(
    self,
    n_restarts: int = 5,
    seed: int = 42,
) -> np.ndarray:
    """Compute optimal stacking weights from PSIS-LOO estimates.

    The stacking weights maximize the LOO log predictive score of the
    model ensemble.  They are computed once and cached.

    Parameters
    ----------
    n_restarts : int, default=5
        Number of random restarts for the convex optimization.
    seed : int, default=42
        Random seed.

    Returns
    -------
    np.ndarray, shape ``(K,)``
        Optimal stacking weights summing to 1.
    """
    if self._stacking_weights_cache is None:
        loo_results = self.psis_loo()
        loo_log_i = [r["elpd_loo_i"] for r in loo_results]
        self._stacking_weights_cache = compute_stacking_weights(
            loo_log_i, n_restarts=n_restarts, seed=seed
        )
    return self._stacking_weights_cache

rank

rank(criterion='psis_loo', include_stacking=True)

Rank models by predictive performance.

Produces a summary DataFrame analogous to arviz.compare(), with columns for elpd, effective parameter count, elpd difference from the best model, standard error of the difference, and model weights.

PARAMETER DESCRIPTION
criterion

Criterion to use for ranking. One of:

  • 'psis_loo': PSIS-LOO elpd (recommended).
  • 'waic_2': WAIC using variance-based penalty.
  • 'waic_1': WAIC using bias-corrected penalty.

TYPE: str DEFAULT: 'psis_loo'

include_stacking

If True, include stacking weights in the output.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
DataFrame

Rows are models, sorted by elpd descending (best first). Columns:

  • model: Model name.
  • elpd: Expected log predictive density.
  • p_eff: Effective number of parameters.
  • elpd_diff: Difference in elpd from the best model (0 for the best).
  • elpd_diff_se: Standard error of the elpd difference (from pointwise CLT).
  • z_score: elpd_diff / elpd_diff_se.
  • weight_pseudo_bma: Pseudo-BMA (AIC-style) weight.
  • weight_stacking: Optimal stacking weight (only if include_stacking=True).
  • n_bad_k: Number of observations with k̂ ≥ 0.7 (PSIS-LOO only).

Examples:

>>> df = mc.rank()
>>> print(df[["model", "elpd", "elpd_diff", "weight_stacking"]])
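
The pseudo-BMA column follows the usual Akaike-style transform: with IC = -2 * elpd, the weight exp(-ΔIC/2) reduces to exp(Δelpd). A sketch of that formula (not the module's pseudo_bma_weights itself):

```python
import numpy as np

def pseudo_bma_sketch(elpd):
    """Akaike-style weights from elpd values: w_k ∝ exp(elpd_k - max elpd)."""
    elpd = np.asarray(elpd, dtype=float)
    # Subtracting the max is the standard overflow guard and leaves w unchanged
    rel = np.exp(elpd - elpd.max())
    return rel / rel.sum()

w = pseudo_bma_sketch([-1200.0, -1201.0, -1210.0])
```

Note that a one-nat elpd gap already shrinks the weight by a factor of e, which is why stacking weights are generally preferred for building ensembles.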
Source code in src/scribe/mc/results.py
def rank(
    self,
    criterion: str = "psis_loo",
    include_stacking: bool = True,
) -> pd.DataFrame:
    """Rank models by predictive performance.

    Produces a summary DataFrame analogous to ``arviz.compare()``, with
    columns for elpd, effective parameter count, elpd difference from the
    best model, standard error of the difference, and model weights.

    Parameters
    ----------
    criterion : str, default='psis_loo'
        Criterion to use for ranking.  One of:
        - ``'psis_loo'``: PSIS-LOO elpd (recommended).
        - ``'waic_2'``: WAIC using variance-based penalty.
        - ``'waic_1'``: WAIC using bias-corrected penalty.
    include_stacking : bool, default=True
        If ``True``, include stacking weights in the output.

    Returns
    -------
    pd.DataFrame
        Rows are models, sorted by elpd descending (best first).
        Columns:

        ``model``
            Model name.
        ``elpd``
            Expected log predictive density.
        ``p_eff``
            Effective number of parameters.
        ``elpd_diff``
            Difference in elpd from the best model (0 for the best).
        ``elpd_diff_se``
            Standard error of the elpd difference (from pointwise CLT).
        ``z_score``
            ``elpd_diff / elpd_diff_se``.
        ``weight_pseudo_bma``
            Pseudo-BMA (AIC-style) weight.
        ``weight_stacking``
            Optimal stacking weight (only if ``include_stacking=True``).
        ``n_bad_k``
            Number of observations with k̂ ≥ 0.7 (PSIS-LOO only).

    Examples
    --------
    >>> df = mc.rank()
    >>> print(df[["model", "elpd", "elpd_diff", "weight_stacking"]])
    """
    if criterion == "psis_loo":
        loo_results = self.psis_loo()
        elpd_values = np.array([r["elpd_loo"] for r in loo_results])
        p_eff_values = np.array([r["p_loo"] for r in loo_results])
        elpd_pointwise = [r["elpd_loo_i"] for r in loo_results]
        n_bad = [r["n_bad"] for r in loo_results]
    elif criterion in ("waic_2", "waic_1"):
        waic_results = self.waic()
        elppd_key = (
            "elppd_waic_2" if criterion == "waic_2" else "elppd_waic_1"
        )
        p_key = "p_waic_2" if criterion == "waic_2" else "p_waic_1"
        elpd_values = np.array([r[elppd_key] for r in waic_results])
        p_eff_values = np.array([r[p_key] for r in waic_results])
        # For SE computation: use per-observation lppd differences
        # We need pointwise WAIC contributions, so recompute with
        # aggregate=False
        elpd_pointwise = []
        for ll in self.log_liks_cell:
            pw_stats = compute_waic_stats(ll, aggregate=False)
            elpd_pointwise.append(np.asarray(pw_stats[elppd_key]))
        n_bad = [0] * self.K  # n_bad is only meaningful for PSIS-LOO
    else:
        raise ValueError(
            f"Unknown criterion '{criterion}'. "
            "Use 'psis_loo', 'waic_2', or 'waic_1'."
        )

    # Best model index (highest elpd)
    best_idx = int(np.argmax(elpd_values))
    best_pointwise = elpd_pointwise[best_idx]

    # Pairwise elpd differences and SE (relative to best model)
    elpd_diff = elpd_values - elpd_values[best_idx]  # best model has diff=0

    # SE of the difference via pointwise CLT
    elpd_diff_se = np.zeros(self.K)
    z_scores = np.zeros(self.K)
    for k in range(self.K):
        if k == best_idx:
            continue
        d_i = elpd_pointwise[k] - best_pointwise
        se = float(np.sqrt(np.sum((d_i - d_i.mean()) ** 2)))
        elpd_diff_se[k] = se
        z_scores[k] = elpd_diff[k] / se if se > 0 else 0.0

    # Pseudo-BMA weights (from WAIC or LOO)
    # Use the negative elpd scaled as WAIC: IC = -2 * elpd
    ic_values = -2.0 * elpd_values
    wt_pbma = np.asarray(pseudo_bma_weights(jnp.array(ic_values)))

    # Stacking weights (optional, more expensive)
    wt_stack = None
    if include_stacking:
        try:
            wt_stack = self.stacking_weights()
        except Exception:
            # Stacking can fail if LOO densities are degenerate
            wt_stack = wt_pbma.copy()

    # Assemble DataFrame
    records = []
    for k in range(self.K):
        rec = {
            "model": self.model_names[k],
            "elpd": float(elpd_values[k]),
            "p_eff": float(p_eff_values[k]),
            "elpd_diff": float(elpd_diff[k]),
            "elpd_diff_se": float(elpd_diff_se[k]),
            "z_score": float(z_scores[k]),
            "weight_pseudo_bma": float(wt_pbma[k]),
        }
        if include_stacking and wt_stack is not None:
            rec["weight_stacking"] = float(wt_stack[k])
        if criterion == "psis_loo":
            rec["n_bad_k"] = int(n_bad[k])
        records.append(rec)

    df = pd.DataFrame(records)
    df = df.sort_values("elpd", ascending=False).reset_index(drop=True)
    return df

gene_level_comparison

gene_level_comparison(model_A, model_B, gene_names=None, criterion='waic_2')

Compare two models gene by gene.

Computes per-gene elpd differences, standard errors, and z-scores using gene-level log-likelihoods (summed over cells).

Requires that :func:compare_models was called with compute_gene_liks=True, otherwise raises RuntimeError.
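
Given the (S, G) matrices, the per-gene criterion is just a column-wise WAIC difference. A hedged sketch of the underlying computation (the returned DataFrame's exact columns may differ):

```python
import numpy as np

def gene_elpd_diff_sketch(ll_A, ll_B):
    """Per-gene elppd_waic_2 difference between two (S, G) log-lik matrices."""
    def elppd(ll):
        S = ll.shape[0]
        lppd = np.logaddexp.reduce(ll, axis=0) - np.log(S)  # per-gene lppd
        return lppd - np.var(ll, axis=0, ddof=1)            # variance penalty
    return elppd(ll_A) - elppd(ll_B)  # > 0 where model A fits the gene better

rng = np.random.default_rng(2)
ll_A = rng.normal(-50.0, 1.0, size=(300, 10))
ll_B = ll_A - 2.0  # model B uniformly worse by 2 nats per sample
diff = gene_elpd_diff_sketch(ll_A, ll_B)
```

Genes with large positive or negative differences are the ones driving the overall model ranking.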

PARAMETER DESCRIPTION
model_A

Index or name of model A in the comparison.

TYPE: int or str

model_B

Index or name of model B.

TYPE: int or str

gene_names

Override the stored gene names.

TYPE: list of str DEFAULT: None

criterion

WAIC variant to use for per-gene elpd.

TYPE: str DEFAULT: 'waic_2'

RETURNS DESCRIPTION
DataFrame

Per-gene comparison table from :func:~scribe.mc._gene_level.gene_level_comparison.

RAISES DESCRIPTION
RuntimeError

If gene-level log-likelihoods are not available.

Examples:

>>> df = mc.gene_level_comparison("NBDM", "Hierarchical")
>>> print(df.head(10))
Source code in src/scribe/mc/results.py
def gene_level_comparison(
    self,
    model_A: Union[int, str],
    model_B: Union[int, str],
    gene_names: Optional[List[str]] = None,
    criterion: str = "waic_2",
) -> pd.DataFrame:
    """Compare two models gene by gene.

    Computes per-gene elpd differences, standard errors, and z-scores
    using gene-level log-likelihoods (summed over cells).

    Requires that :func:`compare_models` was called with
    ``compute_gene_liks=True``, otherwise raises ``RuntimeError``.

    Parameters
    ----------
    model_A : int or str
        Index or name of model A in the comparison.
    model_B : int or str
        Index or name of model B.
    gene_names : list of str, optional
        Override the stored gene names.
    criterion : str, default='waic_2'
        WAIC variant to use for per-gene elpd.

    Returns
    -------
    pd.DataFrame
        Per-gene comparison table from
        :func:`~scribe.mc._gene_level.gene_level_comparison`.

    Raises
    ------
    RuntimeError
        If gene-level log-likelihoods are not available.

    Examples
    --------
    >>> df = mc.gene_level_comparison("NBDM", "Hierarchical")
    >>> print(df.head(10))
    """
    if self.log_liks_gene is None:
        raise RuntimeError(
            "Gene-level log-likelihoods are not available.  "
            "Re-run compare_models() with compute_gene_liks=True."
        )

    # Resolve model indices
    idx_A = _resolve_model_idx(model_A, self.model_names)
    idx_B = _resolve_model_idx(model_B, self.model_names)

    names = gene_names or self.gene_names

    return gene_level_comparison(
        log_liks_A=np.asarray(self.log_liks_gene[idx_A]),
        log_liks_B=np.asarray(self.log_liks_gene[idx_B]),
        gene_names=names,
        label_A=self.model_names[idx_A],
        label_B=self.model_names[idx_B],
        criterion=criterion,
    )

diagnostics

diagnostics(model_idx=None)

Format PSIS-LOO diagnostics (k̂ summary) for one or all models.

PARAMETER DESCRIPTION
model_idx

If provided, show diagnostics only for model model_idx. Otherwise show diagnostics for all models.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
str

Multi-line diagnostic summary.

Source code in src/scribe/mc/results.py
def diagnostics(self, model_idx: Optional[int] = None) -> str:
    """Format PSIS-LOO diagnostics (k̂ summary) for one or all models.

    Parameters
    ----------
    model_idx : int, optional
        If provided, show diagnostics only for model ``model_idx``.
        Otherwise show diagnostics for all models.

    Returns
    -------
    str
        Multi-line diagnostic summary.
    """
    loo_results = self.psis_loo()
    if model_idx is not None:
        indices = [model_idx]
    else:
        indices = list(range(self.K))

    parts = []
    for k in indices:
        header = f"\n--- {self.model_names[k]} ---"
        parts.append(header)
        parts.append(psis_loo_summary(loo_results[k]))
    return "\n".join(parts)

summary

summary(criterion='psis_loo', include_stacking=True)

Format a ranked comparison table as a string.

PARAMETER DESCRIPTION
criterion

Ranking criterion: 'psis_loo', 'waic_2', or 'waic_1'.

TYPE: str DEFAULT: 'psis_loo'

include_stacking

Whether to include stacking weights.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
str

Formatted comparison table.

Examples:

>>> print(mc.summary())
Source code in src/scribe/mc/results.py
def summary(
    self,
    criterion: str = "psis_loo",
    include_stacking: bool = True,
) -> str:
    """Format a ranked comparison table as a string.

    Parameters
    ----------
    criterion : str, default='psis_loo'
        Ranking criterion: ``'psis_loo'``, ``'waic_2'``, or ``'waic_1'``.
    include_stacking : bool, default=True
        Whether to include stacking weights.

    Returns
    -------
    str
        Formatted comparison table.

    Examples
    --------
    >>> print(mc.summary())
    """
    df = self.rank(criterion=criterion, include_stacking=include_stacking)
    header = f"Model Comparison ({criterion.upper()})\n" + "=" * 60 + "\n"
    return header + df.to_string(index=False)

__repr__

__repr__()

Concise representation of the model comparison.

When dead-component pruning was applied, the repr shows the original and surviving component counts per model, e.g. 'ZINBVCP(4→2)'.

Source code in src/scribe/mc/results.py
def __repr__(self) -> str:
    """Concise representation of the model comparison.

    When dead-component pruning was applied, the repr shows the original
    and surviving component counts per model, e.g. ``'ZINBVCP(4→2)'``.
    """
    # Build per-model labels, annotating pruned mixture models
    labels = []
    for k, name in enumerate(self.model_names):
        mask = (
            self.active_components[k]
            if self.active_components is not None
            else None
        )
        if mask is not None:
            k_orig = int(mask.shape[0])
            k_pruned = int(mask.sum())
            labels.append(f"{name}({k_orig}\u2192{k_pruned})")
        else:
            labels.append(name)

    return (
        f"ScribeModelComparisonResults("
        f"K={self.K}, "
        f"n_cells={self.n_cells}, "
        f"n_genes={self.n_genes}, "
        f"models={labels})"
    )

compare_models

compare_models(results_list, counts, model_names=None, gene_names=None, n_samples=1000, rng_key=None, batch_size=None, posterior_sample_chunk_size=8, compute_gene_liks=False, ignore_nans=False, component_threshold=0.0, r_floor=1e-06, p_floor=1e-06, dtype_lik=float32, dtype_psis=float64)

Create a model comparison results object from a list of fitted models.

For each model, this function:

  1. Ensures posterior samples are available (calls get_posterior_samples if needed for SVI models).
  2. Computes the per-cell log-likelihood matrix of shape (S, C) using the model's log_likelihood method.
  3. Optionally computes per-gene log-likelihood matrices of shape (S, G) when compute_gene_liks=True.
  4. Returns a :class:ScribeModelComparisonResults that provides lazy WAIC, PSIS-LOO, stacking, and gene-level comparison methods.
PARAMETER DESCRIPTION
results_list

List of K fitted model objects to compare.

TYPE: list of ScribeSVIResults or ScribeMCMCResults

counts

Observed count matrix (cells × genes).

TYPE: array-like, shape (C, G)

model_names

Human-readable names for each model. Defaults to ["model_0", "model_1", ...].

TYPE: list of str DEFAULT: None

gene_names

Gene names for gene-level comparisons.

TYPE: list of str DEFAULT: None

n_samples

Number of posterior samples to draw for SVI models that do not yet have posterior_samples populated.

TYPE: int DEFAULT: 1000

rng_key

Random key for SVI posterior sampling. Defaults to jax.random.PRNGKey(0) if None.

TYPE: PRNGKey DEFAULT: None

batch_size

Mini-batch size for log-likelihood computation. None uses the full dataset (fast but memory-intensive).

TYPE: int DEFAULT: None

posterior_sample_chunk_size

Posterior-sample chunk size passed to results.log_likelihood to bound peak memory. Smaller values reduce memory pressure (important for large cell-by-gene matrices) at the cost of longer runtime. Set to None or 0 to evaluate all posterior samples in one vmap.

TYPE: int DEFAULT: 8

compute_gene_liks

If True, also compute per-gene log-likelihoods (shape (S, G)) for gene-level model comparison. Doubles the computation time.

TYPE: bool DEFAULT: False

ignore_nans

If True, discard posterior samples that produce NaN log-likelihoods. Useful when the model occasionally produces degenerate samples.

TYPE: bool DEFAULT: False

component_threshold

Dead-component pruning threshold for mixture models. Any mixture component whose posterior-mean mixing weight is strictly below this fraction is removed before log-likelihood computation; the remaining weights are renormalized to sum to one. This prevents dead components from inflating p_waic_2 and the PSIS-LOO effective parameter count.

Implementation note: Component pruning currently relies on runtime tensor-shape inference plus explicit mixture metadata from param_specs to keep canonical tensors (e.g. p/r) aligned after subsetting. A potential long-term refactor is to drive this entirely from derived-parameter lineage metadata (DerivedParam dependencies), if that metadata is persisted as a first-class runtime contract.

  • 0.0 (default): no pruning; behavior is unchanged.
  • Typical value: 0.01 (prune components below 1 %).
  • Non-mixture models are always passed through unmodified.
  • The pruning decision is stored in :attr:ScribeModelComparisonResults.active_components for transparency.
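
The pruning step itself reduces to thresholding posterior-mean mixing weights and renormalizing the survivors; a sketch (illustrative, omitting the tensor-subsetting machinery described above):

```python
import numpy as np

def prune_components_sketch(mixing_weights, threshold):
    """Drop components whose posterior-mean mixing weight is below threshold."""
    w_mean = np.asarray(mixing_weights).mean(axis=0)  # (S, K) -> (K,)
    active = w_mean >= threshold                      # boolean mask, kept for transparency
    w_kept = w_mean[active]
    return active, w_kept / w_kept.sum()              # renormalize to sum to one

rng = np.random.default_rng(3)
# 3 components; the last is effectively dead (~0.25% posterior-mean weight)
w = rng.dirichlet([50.0, 30.0, 0.2], size=500)
active, w_new = prune_components_sketch(w, threshold=0.01)
```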

TYPE: float DEFAULT: 0.0

r_floor

Minimum value clamped onto the NB dispersion parameter r before evaluating log-likelihoods. Posterior samples from a wide variational guide (e.g. high guide rank) can produce r values that underflow to zero in float32, causing lgamma(r) = NaN and discarding the entire sample. A small positive floor prevents this at negligible cost. Set to 0.0 to disable.

TYPE: float DEFAULT: 1e-6

p_floor

Epsilon applied to the success probability p (or effective probability p_hat for VCP models), clipping it to the open interval (p_floor, 1 - p_floor) before evaluating log-likelihoods.

Two float32 degenerate cases this guards against:

  1. Hierarchical models: phi_g → 0 causes p_g = 1/(1+0) = 1.0 exactly in float32, making r * log(1 - p) = NaN.
  2. VCP models: phi_capture → ∞ causes p_capture = 0, which then gives p_hat = 0 and NB(r,0).log_prob(0) = NaN.

Set to 0.0 to disable.
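
Both floors are elementwise clips applied before the likelihood evaluation; a NumPy sketch of the guards (the module applies them inside its log-likelihood path):

```python
import numpy as np

def apply_floors_sketch(r, p, r_floor=1e-6, p_floor=1e-6):
    """Clamp degenerate float32 samples before evaluating NB log-likelihoods."""
    # Avoid r == 0, which makes lgamma(r) blow up and poisons the whole sample
    r = np.maximum(np.asarray(r, dtype=np.float32), r_floor)
    # Keep p strictly inside (0, 1) so log(p) and log(1 - p) stay finite
    p = np.clip(np.asarray(p, dtype=np.float32), p_floor, 1.0 - p_floor)
    return r, p

# Degenerate samples: r underflowed to 0, p saturated to exactly 1.0
r, p = apply_floors_sketch([0.0, 2.5], [1.0, 0.3])
```

Well-behaved samples pass through essentially unchanged, which is why the floors cost nothing in the typical case.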

TYPE: float DEFAULT: 1e-6

dtype_lik

Precision for log-likelihood computation.

TYPE: dtype DEFAULT: jnp.float32

dtype_psis

Precision for PSIS-LOO computation. Double precision is recommended for reliable Pareto fitting.

TYPE: numpy dtype DEFAULT: np.float64

RETURNS DESCRIPTION
ScribeModelComparisonResults

Structured comparison results with lazy-computed WAIC, PSIS-LOO, and stacking weights.

Examples:

>>> from scribe.mc import compare_models
>>> mc = compare_models(
...     [results_nbdm, results_hierarchical],
...     counts=counts,
...     model_names=["NBDM", "Hierarchical"],
...     gene_names=gene_names,
...     compute_gene_liks=True,
... )
>>> print(mc.summary())
>>> print(mc.diagnostics())
Source code in src/scribe/mc/results.py
def compare_models(
    results_list,
    counts: Union[np.ndarray, jnp.ndarray],
    model_names: Optional[List[str]] = None,
    gene_names: Optional[List[str]] = None,
    n_samples: int = 1000,
    rng_key=None,
    batch_size: Optional[int] = None,
    posterior_sample_chunk_size: Optional[int] = 8,
    compute_gene_liks: bool = False,
    ignore_nans: bool = False,
    component_threshold: float = 0.0,
    r_floor: float = 1e-6,
    p_floor: float = 1e-6,
    dtype_lik: jnp.dtype = jnp.float32,
    dtype_psis: type = np.float64,
) -> ScribeModelComparisonResults:
    """Create a model comparison results object from a list of fitted models.

    For each model, this function:

    1. Ensures posterior samples are available (calls
       ``get_posterior_samples`` if needed for SVI models).
    2. Computes the per-cell log-likelihood matrix of shape ``(S, C)``
       using the model's ``log_likelihood`` method.
    3. Optionally computes per-gene log-likelihood matrices of shape
       ``(S, G)`` when ``compute_gene_liks=True``.
    4. Returns a :class:`ScribeModelComparisonResults` that provides lazy
       WAIC, PSIS-LOO, stacking, and gene-level comparison methods.

    Parameters
    ----------
    results_list : list of ScribeSVIResults or ScribeMCMCResults
        List of K fitted model objects to compare.
    counts : array-like, shape ``(C, G)``
        Observed count matrix (cells × genes).
    model_names : list of str, optional
        Human-readable names for each model.  Defaults to
        ``["model_0", "model_1", ...]``.
    gene_names : list of str, optional
        Gene names for gene-level comparisons.
    n_samples : int, default=1000
        Number of posterior samples to draw for SVI models that do not yet
        have ``posterior_samples`` populated.
    rng_key : jax.random.PRNGKey, optional
        Random key for SVI posterior sampling.  Defaults to
        ``jax.random.PRNGKey(0)`` if ``None``.
    batch_size : int, optional
        Mini-batch size for log-likelihood computation.  ``None`` uses the
        full dataset (fast but memory-intensive).
    posterior_sample_chunk_size : int, optional, default=8
        Posterior-sample chunk size passed to ``results.log_likelihood`` to
        bound peak memory. Smaller values reduce memory pressure (important for
        large cell-by-gene matrices) at the cost of longer runtime. Set to
        ``None`` or ``0`` to evaluate all posterior samples in one ``vmap``.
    compute_gene_liks : bool, default=False
        If ``True``, also compute per-gene log-likelihoods (shape ``(S, G)``)
        for gene-level model comparison.  Doubles the computation time.
    ignore_nans : bool, default=False
        If ``True``, discard posterior samples that produce NaN log-likelihoods.
        Useful when the model occasionally produces degenerate samples.
    component_threshold : float, default=0.0
        Dead-component pruning threshold for mixture models.  Any mixture
        component whose **posterior-mean mixing weight** is strictly below
        this fraction is removed before log-likelihood computation; the
        remaining weights are renormalized to sum to one.  This prevents dead
        components from inflating ``p_waic_2`` and the PSIS-LOO effective
        parameter count.

        Implementation note:
        Component pruning currently relies on runtime tensor-shape inference
        plus explicit mixture metadata from ``param_specs`` to keep canonical
        tensors (e.g. ``p``/``r``) aligned after subsetting.  A potential
        long-term refactor is to drive this entirely from derived-parameter
        lineage metadata (``DerivedParam`` dependencies), if that metadata is
        persisted as a first-class runtime contract.

        - ``0.0`` (default): no pruning; behavior is unchanged.
        - Typical value: ``0.01`` (prune components below 1 %).
        - Non-mixture models are always passed through unmodified.
        - The pruning decision is stored in
          :attr:`ScribeModelComparisonResults.active_components` for
          transparency.
    r_floor : float, default=1e-6
        Minimum value clamped onto the NB dispersion parameter ``r`` before
        evaluating log-likelihoods.  Posterior samples from a wide
        variational guide (e.g. high guide rank) can produce ``r`` values
        that underflow to zero in ``float32``, causing ``lgamma(r) = NaN``
        and discarding the entire sample.  A small positive floor prevents
        this at negligible cost.  Set to ``0.0`` to disable.
    p_floor : float, default=1e-6
        Epsilon applied to the success probability ``p`` (or effective
        probability ``p_hat`` for VCP models), clipping it to the open
        interval ``(p_floor, 1 - p_floor)`` before evaluating log-likelihoods.

        Two float32 degenerate cases this guards against:

        1. **Hierarchical models** — ``phi_g → 0`` causes ``p_g = 1/(1+0)
           = 1.0`` exactly in float32, making ``r * log(1 - p) = NaN``.
        2. **VCP models** — ``phi_capture → ∞`` causes ``p_capture = 0``,
           which then gives ``p_hat = 0`` and ``NB(r,0).log_prob(0) = NaN``.

        Set to ``0.0`` to disable.
    dtype_lik : jnp.dtype, default=jnp.float32
        Precision for log-likelihood computation.
    dtype_psis : numpy dtype, default=np.float64
        Precision for PSIS-LOO computation.  Double precision is recommended
        for reliable Pareto fitting.

    Returns
    -------
    ScribeModelComparisonResults
        Structured comparison results with lazy-computed WAIC, PSIS-LOO,
        and stacking weights.

    Examples
    --------
    >>> from scribe.mc import compare_models
    >>> mc = compare_models(
    ...     [results_nbdm, results_hierarchical],
    ...     counts=counts,
    ...     model_names=["NBDM", "Hierarchical"],
    ...     gene_names=gene_names,
    ...     compute_gene_liks=True,
    ... )
    >>> print(mc.summary())
    >>> print(mc.diagnostics())
    """
    from jax import random as jrandom

    K = len(results_list)

    # Default model names
    if model_names is None:
        model_names = [f"model_{k}" for k in range(K)]
    if len(model_names) != K:
        raise ValueError(
            f"model_names has length {len(model_names)} but results_list has {K} models."
        )

    # Default RNG key for SVI sampling
    if rng_key is None:
        rng_key = jrandom.PRNGKey(0)

    # Ensure counts is a JAX array
    counts = jnp.asarray(counts, dtype=dtype_lik)
    n_cells, n_genes = counts.shape

    # Compute per-cell log-likelihoods for each model
    log_liks_cell = []
    log_liks_gene = [] if compute_gene_liks else None
    # Track dead-component pruning info (None per model = no pruning applied)
    active_masks: List[Optional[np.ndarray]] = []
    any_pruned = False

    # Split RNG keys so each model gets an independent key
    rng_keys = jrandom.split(rng_key, K)

    for k, results in enumerate(results_list):
        name = model_names[k]
        print(f"Computing log-likelihoods for {name}...")

        # Ensure posterior samples are available
        if getattr(results, "posterior_samples", None) is None:
            try:
                # SVI: requires rng_key and n_samples
                results.get_posterior_samples(rng_keys[k], n_samples)
            except TypeError:
                # MCMC: no arguments needed (samples come from MCMC run)
                results.get_posterior_samples()

        # Dead-component pruning: replace results with a leaner effective model
        # when mixture components fall below component_threshold.
        results_eff, active_mask = _prune_dead_components(
            results, component_threshold
        )
        active_masks.append(active_mask)
        if active_mask is not None:
            any_pruned = True
            k_orig = int(active_mask.shape[0])
            k_kept = int(active_mask.sum())
            print(
                f"  Pruned {k_orig - k_kept} dead component(s) for {name} "
                f"({k_orig}{k_kept} components)."
            )

        # Per-cell log-likelihoods: shape (S, C)
        ll_cell = _get_log_liks(
            results_eff,
            counts,
            "cell",
            batch_size,
            posterior_sample_chunk_size,
            dtype_lik,
            ignore_nans,
            r_floor,
            p_floor,
        )
        log_liks_cell.append(ll_cell)

        # Per-gene log-likelihoods: shape (S, G) — optional
        if compute_gene_liks:
            ll_gene = _get_log_liks(
                results_eff,
                counts,
                "gene",
                batch_size,
                posterior_sample_chunk_size,
                dtype_lik,
                ignore_nans,
                r_floor,
                p_floor,
            )
            log_liks_gene.append(ll_gene)

    return ScribeModelComparisonResults(
        model_names=model_names,
        log_liks_cell=log_liks_cell,
        log_liks_gene=log_liks_gene,
        gene_names=gene_names,
        n_cells=n_cells,
        n_genes=n_genes,
        # Store masks only when at least one model was pruned, otherwise None
        active_components=active_masks if any_pruned else None,
        dtype=dtype_psis,
    )
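The pruning rule described under ``component_threshold`` can be sketched in plain NumPy (hypothetical weight values; the real implementation applies the same rule to posterior-mean mixing weights inside ``_prune_dead_components``):

```python
import numpy as np

# Hypothetical posterior-mean mixing weights for a 4-component mixture
weights = np.array([0.55, 0.30, 0.149, 0.001])
component_threshold = 0.01

# Components strictly below the threshold are dropped ...
active = weights >= component_threshold   # last component is pruned
# ... and the surviving weights are renormalized to sum to one
pruned = weights[active] / weights[active].sum()
```

With the default ``component_threshold=0.0`` no weight can fall strictly below the threshold, so ``active`` is all-True and the model passes through unmodified.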

compute_waic_stats

compute_waic_stats(log_liks, aggregate=True, dtype=float32)

JIT-compiled computation of all WAIC statistics.

Computes lppd, both versions of the effective parameter count, and both WAIC variants from a log-likelihood matrix in a single forward pass.

PARAMETER DESCRIPTION
log_liks

Log-likelihoods for each posterior sample s and observation i. S is the number of posterior samples, n is the number of observations (cells or genes depending on context).

TYPE: jnp.ndarray, shape ``(S, n)``

aggregate

If True return scalar totals; if False return per-observation arrays of shape (n,).

TYPE: bool DEFAULT: True

dtype

Floating-point precision for all computations.

TYPE: dtype DEFAULT: jnp.float32

RETURNS DESCRIPTION
dict

Dictionary with keys:

lppd : Log pointwise predictive density.
p_waic_1 : Effective parameter count (bias-corrected version).
p_waic_2 : Effective parameter count (variance-based, preferred version).
elppd_waic_1 : Expected log pointwise predictive density under WAIC1.
elppd_waic_2 : Expected log pointwise predictive density under WAIC2.
waic_1 : WAIC on deviance scale, using p_waic_1.
waic_2 : WAIC on deviance scale, using p_waic_2 (recommended).

Source code in src/scribe/mc/_waic.py
@partial(jit, static_argnames=["aggregate", "dtype"])
def compute_waic_stats(
    log_liks: jnp.ndarray,
    aggregate: bool = True,
    dtype: jnp.dtype = jnp.float32,
) -> dict:
    """JIT-compiled computation of all WAIC statistics.

    Computes lppd, both versions of the effective parameter count, and both
    WAIC variants from a log-likelihood matrix in a single forward pass.

    Parameters
    ----------
    log_liks : jnp.ndarray, shape ``(S, n)``
        Log-likelihoods for each posterior sample ``s`` and observation ``i``.
        ``S`` is the number of posterior samples, ``n`` is the number of
        observations (cells or genes depending on context).
    aggregate : bool, default=True
        If ``True`` return scalar totals; if ``False`` return per-observation
        arrays of shape ``(n,)``.
    dtype : jnp.dtype, default=jnp.float32
        Floating-point precision for all computations.

    Returns
    -------
    dict
        Dictionary with keys:

        ``lppd``
            Log pointwise predictive density.
        ``p_waic_1``
            Effective parameter count (bias-corrected version).
        ``p_waic_2``
            Effective parameter count (variance-based, preferred version).
        ``elppd_waic_1``
            Expected log pointwise predictive density under WAIC1.
        ``elppd_waic_2``
            Expected log pointwise predictive density under WAIC2.
        ``waic_1``
            WAIC on deviance scale, using p_waic_1.
        ``waic_2``
            WAIC on deviance scale, using p_waic_2 (recommended).
    """
    # Compute per-observation lppd once; reuse for p_waic_1
    lppd_pw = _lppd(log_liks, aggregate=False, dtype=dtype)

    # Aggregate lppd
    lppd = jnp.sum(lppd_pw) if aggregate else lppd_pw

    # Effective parameter counts
    pw1 = _p_waic_1(log_liks, lppd_pointwise=lppd_pw, aggregate=aggregate, dtype=dtype)
    pw2 = _p_waic_2(log_liks, aggregate=aggregate, dtype=dtype)

    # elpd and WAIC
    elppd_1 = lppd - pw1
    elppd_2 = lppd - pw2
    waic_1 = -2.0 * elppd_1
    waic_2 = -2.0 * elppd_2

    return {
        "lppd": lppd,
        "p_waic_1": pw1,
        "p_waic_2": pw2,
        "elppd_waic_1": elppd_1,
        "elppd_waic_2": elppd_2,
        "waic_1": waic_1,
        "waic_2": waic_2,
    }
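The statistics above can be cross-checked against a plain NumPy sketch of the standard WAIC formulas (a reference implementation for illustration, not part of the module; the ``ddof`` choice in ``p_waic_2`` is a convention and may differ slightly from the module's):

```python
import numpy as np

def waic_reference(log_liks):
    """Plain-NumPy WAIC from a (S, n) log-likelihood matrix."""
    # lppd_i = log mean_s exp(log_liks[s, i]), via a stabilized logsumexp
    m = log_liks.max(axis=0)
    lppd_i = m + np.log(np.exp(log_liks - m).mean(axis=0))
    # p_waic_2: posterior variance of the log-likelihood per observation
    p2_i = log_liks.var(axis=0, ddof=1)
    elppd_i = lppd_i - p2_i
    return {
        "lppd": lppd_i.sum(),
        "p_waic_2": p2_i.sum(),
        "waic_2": -2.0 * elppd_i.sum(),
    }

# Constant log-liks have zero posterior variance, so waic_2 = -2 * lppd
stats = waic_reference(np.full((500, 1000), -2.0))
```

This reproduces the ``waic()`` docstring example: for a constant matrix of ``-2.0`` with 1000 observations, ``waic_2`` is ``4000.0``.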

waic

waic(log_liks, aggregate=True, dtype=float32)

Compute WAIC statistics from a posterior log-likelihood matrix.

This is the public entry point. All JIT-compiled computation is delegated to :func:compute_waic_stats.

PARAMETER DESCRIPTION
log_liks

Matrix of log-likelihoods: rows are posterior samples, columns are observations (cells when return_by="cell", genes when return_by="gene").

TYPE: jnp.ndarray, shape ``(S, n)``

aggregate

If True return scalar totals; if False return per-observation vectors useful for gene-level or cell-level comparisons.

TYPE: bool DEFAULT: True

dtype

Floating-point precision.

TYPE: dtype DEFAULT: jnp.float32

RETURNS DESCRIPTION
dict

Keys: lppd, p_waic_1, p_waic_2, elppd_waic_1, elppd_waic_2, waic_1, waic_2.

Examples:

>>> import jax.numpy as jnp
>>> from scribe.mc._waic import waic
>>> log_liks = jnp.ones((500, 1000)) * -2.0   # (S=500, n=1000)
>>> stats = waic(log_liks)
>>> print(stats["waic_2"])
4000.0
Source code in src/scribe/mc/_waic.py
def waic(
    log_liks: jnp.ndarray,
    aggregate: bool = True,
    dtype: jnp.dtype = jnp.float32,
) -> dict:
    """Compute WAIC statistics from a posterior log-likelihood matrix.

    This is the public entry point.  All JIT-compiled computation is delegated
    to :func:`compute_waic_stats`.

    Parameters
    ----------
    log_liks : jnp.ndarray, shape ``(S, n)``
        Matrix of log-likelihoods: rows are posterior samples, columns are
        observations (cells when ``return_by="cell"``, genes when
        ``return_by="gene"``).
    aggregate : bool, default=True
        If ``True`` return scalar totals; if ``False`` return per-observation
        vectors useful for gene-level or cell-level comparisons.
    dtype : jnp.dtype, default=jnp.float32
        Floating-point precision.

    Returns
    -------
    dict
        Keys: ``lppd``, ``p_waic_1``, ``p_waic_2``, ``elppd_waic_1``,
        ``elppd_waic_2``, ``waic_1``, ``waic_2``.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from scribe.mc._waic import waic
    >>> log_liks = jnp.ones((500, 1000)) * -2.0   # (S=500, n=1000)
    >>> stats = waic(log_liks)
    >>> print(stats["waic_2"])
    4000.0
    """
    return compute_waic_stats(log_liks, aggregate=aggregate, dtype=dtype)

pseudo_bma_weights

pseudo_bma_weights(waic_values, dtype=float32)

Compute pseudo-BMA model weights from an array of WAIC values.

The pseudo-BMA weight for model k is

w_k  ∝  exp(-0.5 * (WAIC_k - min_k WAIC_k))

which mimics the AIC weight formula and provides a simple summary of relative model quality.

PARAMETER DESCRIPTION
waic_values

WAIC values (on deviance scale, lower is better) for K models.

TYPE: jnp.ndarray, shape ``(K,)``

dtype

Floating-point precision.

TYPE: dtype DEFAULT: jnp.float32

RETURNS DESCRIPTION
jnp.ndarray, shape ``(K,)``

Normalized model weights summing to 1.

Examples:

>>> import jax.numpy as jnp
>>> from scribe.mc._waic import pseudo_bma_weights
>>> w = pseudo_bma_weights(jnp.array([200.0, 210.0, 215.0]))
>>> print(w.sum())
1.0
Source code in src/scribe/mc/_waic.py
def pseudo_bma_weights(
    waic_values: jnp.ndarray,
    dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
    """Compute pseudo-BMA model weights from an array of WAIC values.

    The pseudo-BMA weight for model k is

        w_k  ∝  exp(-0.5 * (WAIC_k - min_k WAIC_k))

    which mimics the AIC weight formula and provides a simple summary of
    relative model quality.

    Parameters
    ----------
    waic_values : jnp.ndarray, shape ``(K,)``
        WAIC values (on deviance scale, lower is better) for K models.
    dtype : jnp.dtype, default=jnp.float32
        Floating-point precision.

    Returns
    -------
    jnp.ndarray, shape ``(K,)``
        Normalized model weights summing to 1.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from scribe.mc._waic import pseudo_bma_weights
    >>> w = pseudo_bma_weights(jnp.array([200.0, 210.0, 215.0]))
    >>> print(w.sum())
    1.0
    """
    waic_values = jnp.asarray(waic_values, dtype=dtype)
    # Subtract minimum for numerical stability before exponentiating
    delta = waic_values - jnp.min(waic_values)
    raw = jnp.exp(-0.5 * delta)
    return raw / jnp.sum(raw)
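A quick numeric check of the weight formula, mirroring the function body in plain NumPy:

```python
import numpy as np

waic_values = np.array([200.0, 210.0, 215.0])
# Subtract the minimum before exponentiating, exactly as in the function
delta = waic_values - waic_values.min()   # [0.0, 10.0, 15.0]
raw = np.exp(-0.5 * delta)                # [1.0, e^-5, e^-7.5]
w = raw / raw.sum()
# The lowest-WAIC model dominates (w[0] is roughly 0.99) and w sums to 1
```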

compute_psis_loo

compute_psis_loo(log_liks, dtype=float64)

Compute PSIS-LOO statistics from a posterior log-likelihood matrix.

For each observation i, applies Pareto-smoothed importance sampling to approximate the LOO predictive density without refitting the model. The Pareto shape parameter k̂_i serves as a per-observation reliability diagnostic.

PARAMETER DESCRIPTION
log_liks

Log-likelihood matrix: rows are posterior samples, columns are observations (cells). Can be a JAX or NumPy array; internally converted to NumPy for Pareto fitting.

TYPE: array-like, shape ``(S, n)``

dtype

Numerical precision. Double precision is recommended for PSIS-LOO because the Pareto fitting can be sensitive to precision.

TYPE: numpy dtype DEFAULT: np.float64

RETURNS DESCRIPTION
dict

Keys:

elpd_loo : float. Total estimated expected log predictive density.
p_loo : float. Effective number of parameters: p_loo = lppd - elpd_loo.
looic : float. LOO information criterion on deviance scale: looic = -2 * elpd_loo.
elpd_loo_i : np.ndarray, shape (n,). Per-observation LOO log predictive density.
k_hat : np.ndarray, shape (n,). Per-observation Pareto shape diagnostic.
lppd : float. In-sample log pointwise predictive density (same definition as in WAIC, provided for convenience).
n_bad : int. Number of observations with k̂ ≥ 0.7 (unreliable LOO contributions).

Examples:

>>> import numpy as np
>>> from scribe.mc._psis_loo import compute_psis_loo
>>> rng = np.random.default_rng(0)
>>> log_liks = rng.normal(-3.0, 0.5, size=(500, 200))
>>> result = compute_psis_loo(log_liks)
>>> print(result["k_hat"].max())
Source code in src/scribe/mc/_psis_loo.py
def compute_psis_loo(
    log_liks: np.ndarray,
    dtype: type = np.float64,
) -> Dict[str, np.ndarray]:
    """Compute PSIS-LOO statistics from a posterior log-likelihood matrix.

    For each observation i, applies Pareto-smoothed importance sampling to
    approximate the LOO predictive density without refitting the model.  The
    Pareto shape parameter k̂_i serves as a per-observation reliability
    diagnostic.

    Parameters
    ----------
    log_liks : array-like, shape ``(S, n)``
        Log-likelihood matrix: rows are posterior samples, columns are
        observations (cells).  Can be a JAX or NumPy array; internally
        converted to NumPy for Pareto fitting.
    dtype : numpy dtype, default=np.float64
        Numerical precision.  Double precision is recommended for PSIS-LOO
        because the Pareto fitting can be sensitive to precision.

    Returns
    -------
    dict
        Keys:

        ``elpd_loo`` : float
            Total estimated expected log predictive density.
        ``p_loo`` : float
            Effective number of parameters:
            ``p_loo = lppd - elpd_loo``.
        ``looic`` : float
            LOO information criterion on deviance scale:
            ``looic = -2 * elpd_loo``.
        ``elpd_loo_i`` : np.ndarray, shape ``(n,)``
            Per-observation LOO log predictive density.
        ``k_hat`` : np.ndarray, shape ``(n,)``
            Per-observation Pareto shape diagnostic.
        ``lppd`` : float
            In-sample log pointwise predictive density (same definition as
            in WAIC, provided for convenience).
        ``n_bad`` : int
            Number of observations with k̂ ≥ 0.7 (unreliable LOO
            contributions).

    Examples
    --------
    >>> import numpy as np
    >>> from scribe.mc._psis_loo import compute_psis_loo
    >>> rng = np.random.default_rng(0)
    >>> log_liks = rng.normal(-3.0, 0.5, size=(500, 200))
    >>> result = compute_psis_loo(log_liks)
    >>> print(result["k_hat"].max())
    """
    # Convert to double-precision NumPy for stable Pareto fitting
    log_liks = np.asarray(log_liks, dtype=dtype)
    S, n = log_liks.shape

    # Storage for smoothed weights and diagnostics
    k_hat = np.zeros(n, dtype=dtype)
    elpd_loo_i = np.zeros(n, dtype=dtype)

    # Process each observation independently
    for i in range(n):
        # Raw log IS weights: log w_s = -log p(y_i | theta^s)
        raw_lw = -log_liks[:, i]

        # Pareto-smooth the weights
        smooth_lw, ki = _pareto_smooth_single(raw_lw)
        k_hat[i] = ki

        # Numerically stable log of the IS-weighted average:
        # log p_loo(y_i | y_{-i}) ≈
        #   log(sum_s exp(smooth_lw_s + log_lik_s)) - log(sum_s exp(smooth_lw_s))
        log_lik_i = log_liks[:, i]

        # Numerator: log sum_s exp(smooth_lw_s + log_lik_s)
        log_num_terms = smooth_lw + log_lik_i
        log_num_max = np.max(log_num_terms)
        log_numerator = log_num_max + np.log(
            np.sum(np.exp(log_num_terms - log_num_max))
        )

        # Denominator: log sum_s exp(smooth_lw_s)
        log_den_max = np.max(smooth_lw)
        log_denominator = log_den_max + np.log(
            np.sum(np.exp(smooth_lw - log_den_max))
        )

        elpd_loo_i[i] = log_numerator - log_denominator

    # In-sample lppd (same formula as WAIC, for reference)
    lse_max = np.max(log_liks, axis=0)          # shape (n,)
    lppd_i = lse_max + np.log(np.mean(np.exp(log_liks - lse_max), axis=0))
    lppd = float(np.sum(lppd_i))

    # Aggregate
    elpd_loo = float(np.sum(elpd_loo_i))
    p_loo = lppd - elpd_loo
    looic = -2.0 * elpd_loo

    return {
        "elpd_loo": elpd_loo,
        "p_loo": p_loo,
        "looic": looic,
        "elpd_loo_i": elpd_loo_i,
        "k_hat": k_hat,
        "lppd": lppd,
        "n_bad": int(np.sum(k_hat >= 0.7)),
    }
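The two stabilized sums inside the per-observation loop form a standard self-normalized importance-sampling estimate. With a small logsumexp helper (defined here for illustration), one column reduces to a two-line sketch; raw weights are used in place of the Pareto-smoothed ones, which recovers the notoriously unstable harmonic-mean estimator that PSIS smoothing is designed to repair:

```python
import numpy as np

def logsumexp(x):
    # Numerically stable log(sum(exp(x)))
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

rng = np.random.default_rng(0)
log_lik_i = rng.normal(-3.0, 0.5, size=500)   # log p(y_i | theta^s)
lw = -log_lik_i                               # raw log IS weights (unsmoothed)

# Self-normalized IS estimate of the LOO log predictive density
elpd_i = logsumexp(lw + log_lik_i) - logsumexp(lw)
```

In the module, ``lw`` would first pass through ``_pareto_smooth_single`` before this average is taken.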

psis_loo_summary

psis_loo_summary(result)

Format a human-readable summary of PSIS-LOO diagnostics.

PARAMETER DESCRIPTION
result

Output of :func:compute_psis_loo.

TYPE: dict

RETURNS DESCRIPTION
str

A multi-line summary string.

Examples:

>>> print(psis_loo_summary(result))
Source code in src/scribe/mc/_psis_loo.py
def psis_loo_summary(result: dict) -> str:
    """Format a human-readable summary of PSIS-LOO diagnostics.

    Parameters
    ----------
    result : dict
        Output of :func:`compute_psis_loo`.

    Returns
    -------
    str
        A multi-line summary string.

    Examples
    --------
    >>> print(psis_loo_summary(result))
    """
    k = result["k_hat"]
    n = len(k)
    n_ok = int(np.sum(k < 0.5))
    n_ok2 = int(np.sum((k >= 0.5) & (k < 0.7)))
    n_bad = int(np.sum(k >= 0.7))

    lines = [
        "PSIS-LOO Summary",
        "=" * 40,
        f"  elpd_loo : {result['elpd_loo']:.2f}",
        f"  p_loo    : {result['p_loo']:.2f}",
        f"  LOO-IC   : {result['looic']:.2f}",
        "",
        f"  Pareto k̂ diagnostics (n={n} observations):",
        f"    k̂ < 0.5   (excellent)    : {n_ok:5d}  ({100*n_ok/n:5.1f}%)",
        f"    0.5 ≤ k̂ < 0.7 (OK)      : {n_ok2:5d}  ({100*n_ok2/n:5.1f}%)",
        f"    k̂ ≥ 0.7   (problematic) : {n_bad:5d}  ({100*n_bad/n:5.1f}%)",
    ]
    if n_bad > 0:
        lines.append(
            f"\n  WARNING: {n_bad} observations have k̂ ≥ 0.7."
            " LOO estimates may be unreliable for these cells."
        )
    return "\n".join(lines)

gene_level_comparison

gene_level_comparison(log_liks_A, log_liks_B, gene_names=None, label_A='A', label_B='B', criterion='waic_2')

Compute per-gene model comparison statistics between two models.

For each gene g, computes the pointwise elpd difference between model A and model B. The difference is positive when model A provides better predictions than model B for gene g.

The standard error of the total elpd difference follows from the CLT applied to the pointwise gene-level differences (see @eq-mc-se-delta-elpd in the paper):

SE(delta_elpd) = sqrt(sum_g (d_g - d_bar)^2)

where d_g = elpd_g(A) - elpd_g(B) is the per-gene difference.

PARAMETER DESCRIPTION
log_liks_A

Gene-level log-likelihoods for model A. Each entry is the total log p(all counts for gene g | theta^s) = sum_c log p(u_{gc}|theta^s). Rows are posterior samples, columns are genes.

TYPE: array-like, shape ``(S, G)``

log_liks_B

Gene-level log-likelihoods for model B. Must match log_liks_A in shape.

TYPE: array-like, shape ``(S, G)``

gene_names

Names for the G genes. If None, generic names gene_0, ... are generated.

TYPE: list of str DEFAULT: None

label_A

Human-readable label for model A.

TYPE: str DEFAULT: 'A'

label_B

Human-readable label for model B.

TYPE: str DEFAULT: 'B'

criterion

Which WAIC variant to use for pointwise elpd values. Must be one of 'waic_1', 'waic_2', 'elppd_waic_1', 'elppd_waic_2'. 'waic_2' (variance-based penalty) is recommended.

TYPE: str DEFAULT: 'waic_2'

RETURNS DESCRIPTION
DataFrame

DataFrame with one row per gene and columns:

gene : Gene name.
elpd_A : Per-gene elpd for model A (negative half of WAIC).
elpd_B : Per-gene elpd for model B.
elpd_diff : Per-gene elpd difference (A - B); positive means A is better.
elpd_diff_se : Standard error of the per-gene elpd difference (assumes each gene is an independent observation; see note below).
z_score : z-score: elpd_diff / elpd_diff_se. Values |z| > 2 indicate a practically significant difference.
p_waic_A : Effective parameter count per gene for model A.
p_waic_B : Effective parameter count per gene for model B.
favors : Which model is favored: label_A if elpd_diff > 0, else label_B.

Notes

Because log_liks_A/B are already summed over cells (shape (S, G)), the only variability captured here is across posterior samples. The reported elpd_diff_se is therefore the posterior standard deviation of the per-gene elpd difference, not a frequentist SE: it reflects posterior (model) uncertainty but not sampling variability across cells.

Source code in src/scribe/mc/_gene_level.py
def gene_level_comparison(
    log_liks_A: np.ndarray,
    log_liks_B: np.ndarray,
    gene_names: Optional[List[str]] = None,
    label_A: str = "A",
    label_B: str = "B",
    criterion: str = "waic_2",
) -> pd.DataFrame:
    """Compute per-gene model comparison statistics between two models.

    For each gene g, computes the pointwise elpd difference between model A
    and model B.  The difference is positive when model A provides better
    predictions than model B for gene g.

    The standard error of the total elpd difference follows from the CLT
    applied to the pointwise gene-level differences (see @eq-mc-se-delta-elpd
    in the paper):

        SE(delta_elpd) = sqrt(sum_g (d_g - d_bar)^2)

    where d_g = elpd_g(A) - elpd_g(B) is the per-gene difference.

    Parameters
    ----------
    log_liks_A : array-like, shape ``(S, G)``
        Gene-level log-likelihoods for model A.  Each entry is the total
        log p(all counts for gene g | theta^s) = sum_c log p(u_{gc}|theta^s).
        Rows are posterior samples, columns are genes.
    log_liks_B : array-like, shape ``(S, G)``
        Gene-level log-likelihoods for model B.  Must match ``log_liks_A``
        in shape.
    gene_names : list of str, optional
        Names for the G genes.  If ``None``, generic names ``gene_0, ...`` are
        generated.
    label_A : str, default='A'
        Human-readable label for model A.
    label_B : str, default='B'
        Human-readable label for model B.
    criterion : str, default='waic_2'
        Which WAIC variant to use for pointwise elpd values.  Must be one of
        ``'waic_1'``, ``'waic_2'``, ``'elppd_waic_1'``, ``'elppd_waic_2'``.
        ``'waic_2'`` (variance-based penalty) is recommended.

    Returns
    -------
    pd.DataFrame
        DataFrame with one row per gene and columns:

        ``gene``
            Gene name.
        ``elpd_A``
            Per-gene elpd for model A (negative half of WAIC).
        ``elpd_B``
            Per-gene elpd for model B.
        ``elpd_diff``
            Per-gene elpd difference (A - B); positive means A is better.
        ``elpd_diff_se``
            Standard error of the per-gene elpd difference (assumes each gene
            is an independent observation; see note below).
        ``z_score``
            z-score: ``elpd_diff / elpd_diff_se``.  Values |z| > 2 indicate
            a practically significant difference.
        ``p_waic_A``
            Effective parameter count per gene for model A.
        ``p_waic_B``
            Effective parameter count per gene for model B.
        ``favors``
            Which model is favored: ``label_A`` if ``elpd_diff > 0``, else
            ``label_B``.

    Notes
    -----
    Because ``log_liks_A`` / ``log_liks_B`` are already summed over cells
    (shape ``(S, G)``), the only variability captured here is across
    posterior samples.  The reported ``elpd_diff_se`` is therefore the
    posterior standard deviation of the per-gene elpd difference, not a
    frequentist SE: it reflects posterior (model) uncertainty but not
    sampling variability across cells.
    """
    import jax.numpy as jnp

    log_liks_A = np.asarray(log_liks_A, dtype=np.float64)
    log_liks_B = np.asarray(log_liks_B, dtype=np.float64)

    S, G = log_liks_A.shape
    if log_liks_B.shape != (S, G):
        raise ValueError(
            f"Shape mismatch: log_liks_A has shape {log_liks_A.shape} but "
            f"log_liks_B has shape {log_liks_B.shape}."
        )

    # Generate gene names if not provided
    if gene_names is None:
        gene_names = [f"gene_{g}" for g in range(G)]
    elif len(gene_names) != G:
        raise ValueError(
            f"gene_names has length {len(gene_names)} but log_liks has G={G} genes."
        )

    # Compute per-gene WAIC stats (aggregate=False returns per-gene arrays)
    stats_A = compute_waic_stats(jnp.array(log_liks_A), aggregate=False)
    stats_B = compute_waic_stats(jnp.array(log_liks_B), aggregate=False)

    # Per-gene elpd (using the requested criterion)
    # The elppd_waic_X keys give per-gene elpd directly (with aggregate=False)
    if criterion.startswith("elppd_"):
        # 'elppd_waic_X' keys can be used directly
        elppd_key = criterion
    else:
        # Map 'waic_X' to the corresponding per-gene elpd key 'elppd_waic_X'
        elppd_key = f"elppd_{criterion}"
    # Fall back to the recommended variance-based variant if the key is unknown
    if elppd_key not in stats_A:
        elppd_key = "elppd_waic_2"

    elpd_A = np.asarray(stats_A[elppd_key])
    elpd_B = np.asarray(stats_B[elppd_key])
    p_waic_A = np.asarray(stats_A["p_waic_2"])
    p_waic_B = np.asarray(stats_B["p_waic_2"])

    # Pointwise difference per gene
    elpd_diff = elpd_A - elpd_B

    # Per-gene SE: posterior std-dev of (elpd_A_s - elpd_B_s) over samples
    # We compute sample-by-sample lppd contributions and take their difference
    # as a proxy for per-gene comparison uncertainty.
    # lppd contribution for sample s and gene g is just log p(gene g | theta^s)
    # (already in log_liks_A[:, g]).  The pointwise elpd difference is:
    #   d_g^(s) = log_liks_A[s, g] - log_liks_B[s, g]
    # SE = std over samples of d_g^(s)
    d_samples = log_liks_A - log_liks_B  # shape (S, G)
    elpd_diff_se = np.std(d_samples, axis=0, ddof=1)  # shape (G,)

    # z-score: how many SEs away from zero?
    # Guard against zero SE (all samples identical for some genes)
    z_score = np.where(elpd_diff_se > 0, elpd_diff / elpd_diff_se, 0.0)

    # Build DataFrame
    df = pd.DataFrame(
        {
            "gene": gene_names,
            f"elpd_{label_A}": elpd_A,
            f"elpd_{label_B}": elpd_B,
            "elpd_diff": elpd_diff,
            "elpd_diff_se": elpd_diff_se,
            "z_score": z_score,
            f"p_waic_{label_A}": p_waic_A,
            f"p_waic_{label_B}": p_waic_B,
            "favors": np.where(elpd_diff > 0, label_A, label_B),
        }
    )

    # Sort by absolute z-score descending (most decisive genes first)
    df = df.sort_values("z_score", key=np.abs, ascending=False).reset_index(drop=True)
    return df
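The SE and z-score computation at the heart of the function can be reproduced on synthetic data (hypothetical shapes and effect sizes; the per-sample difference mean is used here as a simple proxy for the elppd difference the function takes from the WAIC stats):

```python
import numpy as np

rng = np.random.default_rng(0)
S, G = 400, 5
# Identical models except gene 0, where model A is ~5 nats better per sample
log_liks_A = rng.normal(-100.0, 1.0, size=(S, G))
log_liks_B = log_liks_A.copy()
log_liks_B[:, 0] -= 5.0 + rng.normal(0.0, 0.5, size=S)

d_samples = log_liks_A - log_liks_B            # (S, G) per-sample differences
elpd_diff = d_samples.mean(axis=0)
elpd_diff_se = d_samples.std(axis=0, ddof=1)   # posterior std-dev, as in the code
z_score = np.where(elpd_diff_se > 0, elpd_diff / elpd_diff_se, 0.0)
# Gene 0 gets |z| >> 2 (A clearly favored); the identical genes fall under
# the zero-SE guard and get z == 0
```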

format_gene_comparison_table

format_gene_comparison_table(df, top_n=20, sort_by='z_score')

Format a gene-level comparison DataFrame as a human-readable table.

PARAMETER DESCRIPTION
df

Output of :func:gene_level_comparison.

TYPE: DataFrame

top_n

Number of top genes to display. Displays all genes if None.

TYPE: int DEFAULT: 20

sort_by

Column to sort by (descending by absolute value for z_score, descending for all other columns).

TYPE: str DEFAULT: 'z_score'

RETURNS DESCRIPTION
str

Formatted table string.

Source code in src/scribe/mc/_gene_level.py
def format_gene_comparison_table(
    df: pd.DataFrame,
    top_n: Optional[int] = 20,
    sort_by: str = "z_score",
) -> str:
    """Format a gene-level comparison DataFrame as a human-readable table.

    Parameters
    ----------
    df : pd.DataFrame
        Output of :func:`gene_level_comparison`.
    top_n : int, optional
        Number of top genes to display.  Displays all genes if ``None``.
    sort_by : str, default='z_score'
        Column to sort by (descending by absolute value for ``z_score``,
        descending for all other columns).

    Returns
    -------
    str
        Formatted table string.
    """
    display_df = df.copy()

    # Sort
    if sort_by == "z_score":
        display_df = display_df.sort_values(
            "z_score", key=np.abs, ascending=False
        )
    else:
        display_df = display_df.sort_values(sort_by, ascending=False)

    # Truncate
    if top_n is not None:
        display_df = display_df.head(top_n)

    # Select core columns for display
    core_cols = [
        "gene", "elpd_diff", "elpd_diff_se", "z_score", "favors"
    ]
    # Only keep columns that exist
    cols = [c for c in core_cols if c in display_df.columns]
    display_df = display_df[cols]

    # Format floats
    float_cols = ["elpd_diff", "elpd_diff_se", "z_score"]
    for col in float_cols:
        if col in display_df.columns:
            display_df[col] = display_df[col].map(lambda x: f"{x:.3f}")

    header = f"Gene-level model comparison (top {top_n} genes by |z-score|)\n"
    table = display_df.to_string(index=False)
    return header + table

compute_stacking_weights

compute_stacking_weights(loo_log_densities, n_restarts=5, seed=42)

Compute optimal model stacking weights via convex optimization.

Solves the stacking problem:

w* = argmax_{w in Delta^{K-1}} sum_i log sum_k w_k * exp(loo_i_k)

using scipy's SLSQP solver with multiple random restarts. The log-score objective is concave in w, so any local optimum is global; the restarts guard against numerical convergence issues rather than genuine local optima.

PARAMETER DESCRIPTION
loo_log_densities

List of K arrays, each of shape (n,), containing the per-observation LOO log predictive densities log p_loo(y_i | y_{-i}, M_k) for model k. These are the elpd_loo_i arrays from :func:~scribe.mc._psis_loo.compute_psis_loo.

TYPE: list of np.ndarray

n_restarts

Number of random initializations. The best solution across restarts is returned.

TYPE: int DEFAULT: 5

seed

Random seed for reproducibility.

TYPE: int DEFAULT: 42

RETURNS DESCRIPTION
np.ndarray, shape ``(K,)``

Optimal stacking weights, summing to 1. A weight near 0 means the corresponding model contributes negligibly to the optimal ensemble.

Examples:

>>> import numpy as np
>>> from scribe.mc._stacking import compute_stacking_weights
>>> rng = np.random.default_rng(0)
>>> K, n = 3, 200
>>> # Model 1 is better; it has higher LOO densities
>>> loo1 = rng.normal(-2.0, 0.3, n)
>>> loo2 = rng.normal(-2.5, 0.3, n)
>>> loo3 = rng.normal(-3.0, 0.3, n)
>>> w = compute_stacking_weights([loo1, loo2, loo3])
>>> print(w)  # Should be concentrated on model 1
Source code in src/scribe/mc/_stacking.py
def compute_stacking_weights(
    loo_log_densities: List[np.ndarray],
    n_restarts: int = 5,
    seed: int = 42,
) -> np.ndarray:
    """Compute optimal model stacking weights via convex optimization.

    Solves the stacking problem:

        w* = argmax_{w in Delta^{K-1}} sum_i log sum_k w_k * exp(loo_i_k)

    using scipy's SLSQP solver with multiple random restarts to guard against
    premature convergence (the problem is convex, so any local optimum is
    global).

    Parameters
    ----------
    loo_log_densities : list of np.ndarray
        List of K arrays, each of shape ``(n,)``, containing the per-observation
        LOO log predictive densities ``log p_loo(y_i | y_{-i}, M_k)`` for model k.
        These are the ``elpd_loo_i`` arrays from :func:`~scribe.mc._psis_loo.compute_psis_loo`.
    n_restarts : int, default=5
        Number of random initializations.  The best solution across restarts
        is returned.
    seed : int, default=42
        Random seed for reproducibility.

    Returns
    -------
    np.ndarray, shape ``(K,)``
        Optimal stacking weights, summing to 1.  A weight near 0 means the
        corresponding model contributes negligibly to the optimal ensemble.

    Examples
    --------
    >>> import numpy as np
    >>> from scribe.mc._stacking import compute_stacking_weights
    >>> rng = np.random.default_rng(0)
    >>> K, n = 3, 200
    >>> # Model 1 is better; it has higher LOO densities
    >>> loo1 = rng.normal(-2.0, 0.3, n)
    >>> loo2 = rng.normal(-2.5, 0.3, n)
    >>> loo3 = rng.normal(-3.0, 0.3, n)
    >>> w = compute_stacking_weights([loo1, loo2, loo3])
    >>> print(w)  # Should be concentrated on model 1
    """
    # Stack into (n, K) matrix
    log_loo_i = np.column_stack([np.asarray(l, dtype=np.float64) for l in loo_log_densities])
    K = log_loo_i.shape[1]

    rng = np.random.default_rng(seed)

    # Simplex constraints
    constraints = {"type": "eq", "fun": lambda w: np.sum(w) - 1.0}
    bounds = [(1e-6, 1.0)] * K

    best_val = np.inf
    best_w = np.ones(K) / K  # uniform fallback

    for _ in range(n_restarts):
        # Random initialization on the simplex
        w0 = rng.dirichlet(np.ones(K))
        # scipy calls fun(x, *args) and jac(x, *args), so log_loo_i is
        # passed as the second positional argument to both functions.
        result = minimize(
            fun=_log_mix_loo,
            x0=w0,
            args=(log_loo_i,),
            jac=_grad_log_mix_loo,
            method="SLSQP",
            bounds=bounds,
            constraints=constraints,
            options={"ftol": 1e-12, "maxiter": 1000, "disp": False},
        )
        if result.fun < best_val:
            best_val = result.fun
            best_w = result.x

    # Project back to simplex (clip negatives from numerical noise)
    best_w = np.clip(best_w, 0.0, 1.0)
    best_w /= best_w.sum()
    return best_w
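
The helpers `_log_mix_loo` and `_grad_log_mix_loo` passed to `minimize` are not shown here. A minimal sketch of what they plausibly compute (the negative LOO log-score of the weighted mixture, and its gradient), verified against a finite-difference approximation; the names and exact signatures are assumptions:

```python
import numpy as np
from scipy.special import logsumexp


def log_mix_loo(w, log_loo_i):
    # Negative LOO log-score of the w-weighted mixture:
    #   -sum_i log sum_k w_k * exp(loo_ik)
    return -np.sum(logsumexp(log_loo_i + np.log(w), axis=1))


def grad_log_mix_loo(w, log_loo_i):
    # Gradient: d/dw_k = -sum_i exp(loo_ik) / sum_j w_j * exp(loo_ij)
    log_mix = logsumexp(log_loo_i + np.log(w), axis=1, keepdims=True)
    return -np.sum(np.exp(log_loo_i - log_mix), axis=0)


rng = np.random.default_rng(0)
log_loo_i = rng.normal(-2.0, 0.3, size=(50, 3))  # (n, K)
w = np.array([0.5, 0.3, 0.2])

# Finite-difference check of the analytical gradient
eps = 1e-6
g = grad_log_mix_loo(w, log_loo_i)
for k in range(3):
    w_pert = w.copy()
    w_pert[k] += eps
    fd = (log_mix_loo(w_pert, log_loo_i) - log_mix_loo(w, log_loo_i)) / eps
    assert abs(fd - g[k]) < 1e-3
```

Because `minimize` minimizes, the objective is negated; the `args=(log_loo_i,)` in the source supplies the second positional argument to both functions.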

stacking_summary

stacking_summary(weights, model_names=None)

Format a human-readable summary of stacking weights.

PARAMETER DESCRIPTION
weights

Stacking weights.

TYPE: np.ndarray, shape ``(K,)``

model_names

Names for the K models.

TYPE: list of str DEFAULT: None

RETURNS DESCRIPTION
str

Formatted summary string.

Source code in src/scribe/mc/_stacking.py
def stacking_summary(
    weights: np.ndarray,
    model_names: Optional[List[str]] = None,
) -> str:
    """Format a human-readable summary of stacking weights.

    Parameters
    ----------
    weights : np.ndarray, shape ``(K,)``
        Stacking weights.
    model_names : list of str, optional
        Names for the K models.

    Returns
    -------
    str
        Formatted summary string.
    """
    K = len(weights)
    if model_names is None:
        model_names = [f"Model {k}" for k in range(K)]

    # Sort by weight descending
    order = np.argsort(weights)[::-1]
    lines = ["Stacking Weights", "=" * 30]
    for k in order:
        lines.append(f"  {model_names[k]:30s}  {weights[k]:.4f}")
    return "\n".join(lines)

compute_quantile_residuals

compute_quantile_residuals(counts, r, p, rng_key, mixing_weights=None, epsilon=1e-06)

Compute randomized quantile residuals for NB or NB-mixture models.

For each cell-gene pair, the observed count is mapped through the model's predictive CDF (randomized for discrete data) and then through the inverse normal CDF. Under a correctly specified model the residuals are i.i.d. standard normal.

PARAMETER DESCRIPTION
counts

Observed UMI count matrix. Rows are cells, columns are genes.

TYPE: jnp.ndarray, shape ``(C, G)``

r

NB dispersion parameter.

  • Single-component: shape (G,) — one dispersion per gene.
  • Mixture: shape (K, G) — one dispersion per component per gene, where K is the number of mixture components.

TYPE: ndarray

p

NB success probability.

  • Scalar or shape (1,): shared across genes and components.
  • Shape (K,): one per component (shared across genes).
  • Shape (G,): one per gene (single-component only).
  • Shape (K, G): one per component per gene.

TYPE: ndarray

rng_key

JAX PRNG key for the uniform randomization step.

TYPE: ndarray

mixing_weights

Mixture component weights.

  • None for single-component models.
  • Shape (K,) for global weights.
  • Shape (C, K) for per-cell weights.

TYPE: ndarray or None DEFAULT: None

epsilon

Clipping bound: PIT values are clamped to (epsilon, 1-epsilon) before applying the inverse normal CDF, preventing infinite residuals from floating-point boundary cases.

TYPE: float DEFAULT: 1e-6

RETURNS DESCRIPTION
jnp.ndarray, shape ``(C, G)``

Randomized quantile residuals. Under the true model, each entry is approximately drawn from N(0, 1).

Source code in src/scribe/mc/_goodness_of_fit.py
def compute_quantile_residuals(
    counts: jnp.ndarray,
    r: jnp.ndarray,
    p: jnp.ndarray,
    rng_key: jnp.ndarray,
    mixing_weights: Optional[jnp.ndarray] = None,
    epsilon: float = 1e-6,
) -> jnp.ndarray:
    """Compute randomized quantile residuals for NB or NB-mixture models.

    For each cell-gene pair, the observed count is mapped through the
    model's predictive CDF (randomized for discrete data) and then
    through the inverse normal CDF.  Under a correctly specified model
    the residuals are i.i.d. standard normal.

    Parameters
    ----------
    counts : jnp.ndarray, shape ``(C, G)``
        Observed UMI count matrix.  Rows are cells, columns are genes.
    r : jnp.ndarray
        NB dispersion parameter.

        * Single-component: shape ``(G,)`` — one dispersion per gene.
        * Mixture: shape ``(K, G)`` — one dispersion per component per
          gene, where ``K`` is the number of mixture components.
    p : jnp.ndarray
        NB success probability.

        * Scalar or shape ``(1,)``: shared across genes and components.
        * Shape ``(K,)``: one per component (shared across genes).
        * Shape ``(G,)``: one per gene (single-component only).
        * Shape ``(K, G)``: one per component per gene.
    rng_key : jnp.ndarray
        JAX PRNG key for the uniform randomization step.
    mixing_weights : jnp.ndarray or None, optional
        Mixture component weights.

        * ``None`` for single-component models.
        * Shape ``(K,)`` for global weights.
        * Shape ``(C, K)`` for per-cell weights.
    epsilon : float, default=1e-6
        Clipping bound: PIT values are clamped to ``(epsilon, 1-epsilon)``
        before applying the inverse normal CDF, preventing infinite
        residuals from floating-point boundary cases.

    Returns
    -------
    jnp.ndarray, shape ``(C, G)``
        Randomized quantile residuals.  Under the true model, each
        entry is approximately drawn from N(0, 1).
    """
    counts = jnp.asarray(counts)
    r = jnp.asarray(r)
    p = jnp.asarray(p)
    C, G = counts.shape

    if mixing_weights is not None:
        # ---- Mixture model: marginal CDF ----
        cdf_upper = _marginal_nb_cdf(counts, r, p, mixing_weights)
        # CDF at (counts - 1); F(-1) = 0 by convention
        counts_minus_1 = jnp.maximum(counts - 1, 0)
        cdf_lower_raw = _marginal_nb_cdf(counts_minus_1, r, p, mixing_weights)
        cdf_lower = jnp.where(counts == 0, 0.0, cdf_lower_raw)
    else:
        # ---- Single-component model ----
        nb = dist.NegativeBinomialProbs(r, p)
        cdf_upper = nb.cdf(counts)
        counts_minus_1 = jnp.maximum(counts - 1, 0)
        cdf_lower_raw = nb.cdf(counts_minus_1)
        cdf_lower = jnp.where(counts == 0, 0.0, cdf_lower_raw)

    # Randomize: v ~ Uniform(cdf_lower, cdf_upper)
    u = random.uniform(rng_key, shape=(C, G))
    v = cdf_lower + u * (cdf_upper - cdf_lower)

    # Clip to (epsilon, 1 - epsilon) to avoid infinite residuals
    v = jnp.clip(v, epsilon, 1.0 - epsilon)

    # Transform to the normal scale
    q = norm.ppf(v)
    return q
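
The same randomized-PIT construction can be reproduced outside SCRIBE with SciPy's `nbinom` (whose `(n, p)` parameterization differs from NumPyro's `NegativeBinomialProbs`; the sketch only needs internal self-consistency):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
r, p = 5.0, 0.4  # toy NB parameters, SciPy parameterization
counts = stats.nbinom.rvs(r, p, size=(1000, 1), random_state=rng)

# Randomized PIT: v ~ Uniform(F(y - 1), F(y)), with F(-1) = 0
cdf_upper = stats.nbinom.cdf(counts, r, p)
cdf_lower = np.where(counts == 0, 0.0, stats.nbinom.cdf(counts - 1, r, p))
u = rng.uniform(size=counts.shape)
v = np.clip(cdf_lower + u * (cdf_upper - cdf_lower), 1e-6, 1.0 - 1e-6)
residuals = stats.norm.ppf(v)

# Under the generating model the residuals are approximately N(0, 1)
print(round(residuals.mean(), 2), round(residuals.var(ddof=1), 2))
```

The clipping step mirrors the `epsilon` argument above: without it, a PIT value of exactly 0 or 1 would map to an infinite residual.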

goodness_of_fit_scores

goodness_of_fit_scores(residuals)

Compute per-gene goodness-of-fit summary statistics from residuals.

Under a correctly specified model the residuals are i.i.d. N(0, 1) for each gene. These statistics measure departures from that reference.

PARAMETER DESCRIPTION
residuals

Randomized quantile residual matrix (output of compute_quantile_residuals).

TYPE: jnp.ndarray, shape ``(C, G)``

RETURNS DESCRIPTION
dict

Dictionary with per-gene arrays of shape (G,):

  • mean: Sample mean of residuals. Should be near 0.
  • variance: Sample variance (Bessel-corrected). Should be near 1. Values >> 1 indicate the model underestimates gene variability; values << 1 indicate overestimation.
  • tail_excess: Fraction of |residual| > 2 minus the N(0,1) expectation (0.0455). Should be near 0.
  • ks_distance: Kolmogorov-Smirnov distance between the empirical residual distribution and the standard normal CDF. An omnibus measure of departure from N(0, 1).

Notes

Computational cost is O(C * G) for mean, variance, and tail excess. The KS distance additionally requires sorting each gene's residuals, adding an O(G * C log C) term that remains inexpensive for typical single-cell dataset sizes.

Source code in src/scribe/mc/_goodness_of_fit.py
def goodness_of_fit_scores(
    residuals: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
    """Compute per-gene goodness-of-fit summary statistics from residuals.

    Under a correctly specified model the residuals are i.i.d. N(0, 1)
    for each gene.  These statistics measure departures from that reference.

    Parameters
    ----------
    residuals : jnp.ndarray, shape ``(C, G)``
        Randomized quantile residual matrix (output of
        ``compute_quantile_residuals``).

    Returns
    -------
    dict
        Dictionary with per-gene arrays of shape ``(G,)``:

        ``mean``
            Sample mean of residuals.  Should be near 0.
        ``variance``
            Sample variance (Bessel-corrected).  Should be near 1.
            Values >> 1 indicate the model underestimates gene variability;
            values << 1 indicate overestimation.
        ``tail_excess``
            Fraction of |residual| > 2 minus the N(0,1) expectation
            (0.0455).  Should be near 0.
        ``ks_distance``
            Kolmogorov--Smirnov distance between the empirical residual
            distribution and the standard normal CDF.  An omnibus measure
            of departure from N(0, 1).

    Notes
    -----
    Computational cost is O(C * G) for mean, variance, and tail excess.
    The KS distance additionally requires sorting each gene's residuals,
    adding an O(G * C log C) term that remains inexpensive for typical
    single-cell dataset sizes.
    """
    C = residuals.shape[0]

    # Location miscalibration
    mean = jnp.mean(residuals, axis=0)

    # Scale miscalibration (Bessel-corrected)
    variance = jnp.var(residuals, axis=0, ddof=1)

    # Tail excess: fraction of |q| > 2, centered at the N(0,1) expectation
    tail_frac = jnp.mean(jnp.abs(residuals) > 2.0, axis=0)
    expected_tail = 2.0 * (1.0 - norm.cdf(2.0))  # ~0.0455
    tail_excess = tail_frac - expected_tail

    # KS distance vs N(0,1)
    ks_distance = _ks_distance_normal(residuals)

    return {
        "mean": mean,
        "variance": variance,
        "tail_excess": tail_excess,
        "ks_distance": ks_distance,
    }
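
A NumPy/SciPy sketch of the same four statistics on synthetic residuals, with one deliberately miscalibrated gene (the KS step uses `scipy.stats.kstest` in place of the internal `_ks_distance_normal` helper):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)
residuals = rng.standard_normal((2000, 4))  # (C, G) well-calibrated genes
residuals[:, 3] *= 1.6  # gene 3: model underestimates variability

mean = residuals.mean(axis=0)
variance = residuals.var(axis=0, ddof=1)
# Tail excess relative to the N(0,1) two-sided 2-sigma mass (~0.0455)
tail_excess = (np.abs(residuals) > 2.0).mean(axis=0) - 2.0 * (
    1.0 - stats.norm.cdf(2.0)
)
# KS distance of each gene's residuals against N(0, 1)
ks_distance = np.array(
    [stats.kstest(residuals[:, g], "norm").statistic for g in range(4)]
)

print(variance.round(2))  # gene 3 stands out on every statistic
```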

compute_gof_mask

compute_gof_mask(counts, results, component=None, rng_key=None, counts_for_map=None, min_variance=0.5, max_variance=1.5, max_ks=None, epsilon=1e-06)

Build a per-gene goodness-of-fit boolean mask from a fitted model.

This is the high-level entry point analogous to scribe.de.compute_expression_mask. It extracts MAP parameters from the results object, computes randomized quantile residuals, and returns a boolean mask indicating which genes are adequately described by the model.

Under a correctly specified model, per-gene residual variance is approximately 1. Variance substantially above 1 indicates the model underestimates gene variability (e.g., missing zero-inflation or overdispersion); variance substantially below 1 indicates the model overestimates variability (e.g., prior too diffuse).

PARAMETER DESCRIPTION
counts

Observed UMI count matrix used for residual computation.

TYPE: jnp.ndarray, shape ``(C, G)``

results

Fitted model results object. Must support get_map() and expose n_components.

TYPE: ScribeSVIResults or ScribeMCMCResults

component

For mixture models, if specified, slice to a single component before computing MAP. If None and the model is a mixture, the full marginal mixture CDF is used (recommended).

TYPE: int or None DEFAULT: None

rng_key

JAX PRNG key for the randomization step. If None, a default key PRNGKey(0) is used.

TYPE: ndarray or None DEFAULT: None

counts_for_map

Count matrix to pass to get_map() for models with amortized capture probability. If None, counts is used.

TYPE: ndarray or None DEFAULT: None

min_variance

Lower bound on the residual variance per gene. Genes with s_g^2 <= min_variance are masked out (False). This catches genes where the model overestimates variability. Set to 0.0 to disable the lower-bound check.

TYPE: float DEFAULT: 0.5

max_variance

Upper bound on the residual variance per gene. Genes with s_g^2 >= max_variance are masked out (False). This catches genes where the model underestimates variability.

TYPE: float DEFAULT: 1.5

max_ks

If provided, upper bound on the KS distance per gene. Genes exceeding this are also masked out. If None, only the variance criteria are applied.

TYPE: float or None DEFAULT: None

epsilon

Clipping bound for the PIT values (see compute_quantile_residuals).

TYPE: float DEFAULT: 1e-6

RETURNS DESCRIPTION
np.ndarray, shape ``(G,)``

Boolean mask: True for genes passing the fit criteria, False for poorly fit genes.

Source code in src/scribe/mc/_goodness_of_fit.py
def compute_gof_mask(
    counts: jnp.ndarray,
    results,
    component: Optional[int] = None,
    rng_key: Optional[jnp.ndarray] = None,
    counts_for_map: Optional[jnp.ndarray] = None,
    min_variance: float = 0.5,
    max_variance: float = 1.5,
    max_ks: Optional[float] = None,
    epsilon: float = 1e-6,
) -> np.ndarray:
    """Build a per-gene goodness-of-fit boolean mask from a fitted model.

    This is the high-level entry point analogous to
    ``scribe.de.compute_expression_mask``.  It extracts MAP parameters
    from the results object, computes randomized quantile residuals, and
    returns a boolean mask indicating which genes are adequately described
    by the model.

    Under a correctly specified model, per-gene residual variance is
    approximately 1.  Variance substantially above 1 indicates the model
    **underestimates** gene variability (e.g., missing zero-inflation or
    overdispersion); variance substantially below 1 indicates the model
    **overestimates** variability (e.g., prior too diffuse).

    Parameters
    ----------
    counts : jnp.ndarray, shape ``(C, G)``
        Observed UMI count matrix used for residual computation.
    results : ScribeSVIResults or ScribeMCMCResults
        Fitted model results object.  Must support ``get_map()`` and
        expose ``n_components``.
    component : int or None, optional
        For mixture models, if specified, slice to a single component
        before computing MAP.  If ``None`` and the model is a mixture,
        the full marginal mixture CDF is used (recommended).
    rng_key : jnp.ndarray or None, optional
        JAX PRNG key for the randomization step.  If ``None``, a default
        key ``PRNGKey(0)`` is used.
    counts_for_map : jnp.ndarray or None, optional
        Count matrix to pass to ``get_map()`` for models with amortized
        capture probability.  If ``None``, ``counts`` is used.
    min_variance : float, default=0.5
        Lower bound on the residual variance per gene.  Genes with
        ``s_g^2 <= min_variance`` are masked out (``False``).  This
        catches genes where the model overestimates variability.
        Set to 0.0 to disable the lower-bound check.
    max_variance : float, default=1.5
        Upper bound on the residual variance per gene.  Genes with
        ``s_g^2 >= max_variance`` are masked out (``False``).  This
        catches genes where the model underestimates variability.
    max_ks : float or None, optional
        If provided, upper bound on the KS distance per gene.  Genes
        exceeding this are also masked out.  If ``None``, only the
        variance criteria are applied.
    epsilon : float, default=1e-6
        Clipping bound for the PIT values (see
        ``compute_quantile_residuals``).

    Returns
    -------
    np.ndarray, shape ``(G,)``
        Boolean mask: ``True`` for genes passing the fit criteria,
        ``False`` for poorly fit genes.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(0)

    map_counts = counts_for_map if counts_for_map is not None else counts

    is_mixture = (
        getattr(results, "n_components", None) is not None
        and results.n_components > 1
    )

    if component is not None:
        # Slice to a single component — treated as single-component
        comp_results = results.get_component(component)
        map_est = comp_results.get_map(
            use_mean=True, canonical=True, verbose=False, counts=map_counts
        )
        r, p = _extract_r_p(map_est)
        residuals = compute_quantile_residuals(
            counts, r, p, rng_key, mixing_weights=None, epsilon=epsilon
        )
    elif is_mixture:
        # Full mixture: use marginal CDF
        map_est = results.get_map(
            use_mean=True, canonical=True, verbose=False, counts=map_counts
        )
        r, p, mixing_weights = _extract_mixture_params(map_est)
        residuals = compute_quantile_residuals(
            counts, r, p, rng_key,
            mixing_weights=mixing_weights, epsilon=epsilon,
        )
    else:
        # Single-component model
        map_est = results.get_map(
            use_mean=True, canonical=True, verbose=False, counts=map_counts
        )
        r, p = _extract_r_p(map_est)
        residuals = compute_quantile_residuals(
            counts, r, p, rng_key, mixing_weights=None, epsilon=epsilon
        )

    # Compute summary scores
    scores = goodness_of_fit_scores(residuals)

    # Build mask: variance must be within [min_variance, max_variance]
    mask = np.asarray(scores["variance"] < max_variance)
    if min_variance > 0.0:
        mask = mask & np.asarray(scores["variance"] > min_variance)

    # Optionally add KS criterion
    if max_ks is not None:
        mask = mask & np.asarray(scores["ks_distance"] < max_ks)

    return mask
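
The thresholding at the end of the function reduces to elementwise comparisons; a toy example with hypothetical per-gene scores and the default bounds:

```python
import numpy as np

# Hypothetical per-gene scores for four genes
variance = np.array([0.95, 1.80, 0.30, 1.05])
ks_distance = np.array([0.02, 0.10, 0.08, 0.03])
min_variance, max_variance, max_ks = 0.5, 1.5, 0.05

mask = variance < max_variance   # reject underestimated variability
mask &= variance > min_variance  # reject overestimated variability
mask &= ks_distance < max_ks     # optional omnibus criterion

print(mask)  # genes 1 (variance too high) and 2 (too low) fail
```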

ppc_goodness_of_fit_scores

ppc_goodness_of_fit_scores(ppc_samples, obs_counts, credible_level=95, max_bin=None)

Compute PPC-based per-gene goodness-of-fit scores.

For each gene the function compares the observed count histogram to posterior-predictive credible bands and produces two complementary metrics.

PARAMETER DESCRIPTION
ppc_samples

Posterior predictive count samples. S is the number of posterior draws, C the number of cells, G the number of genes.

TYPE: jnp.ndarray, shape ``(S, C, G)``

obs_counts

Observed UMI count matrix for the same cells and genes.

TYPE: jnp.ndarray, shape ``(C, G)``

credible_level

Width of the pointwise credible band (percentage). Default: 95.

TYPE: int DEFAULT: 95

max_bin

If set, histogram bins above this value are collapsed. Helps bound computation for heavy-tailed genes.

TYPE: int or None DEFAULT: None

RETURNS DESCRIPTION
dict

Dictionary with per-gene arrays of shape (G,):

  • calibration_failure: Fraction of non-empty observed-histogram bins whose density falls outside the credible_level credible band. Under a well-specified model this should be close to 1 - credible_level / 100.
  • l1_distance: Sum of absolute differences between observed density and PPC median density across bins. Captures the magnitude of histogram-level misfit.

See Also

  • compute_ppc_gof_mask: High-level mask builder that wraps this scorer.
  • goodness_of_fit_scores: RQR-based alternative.
  • scribe.stats.histogram.compute_histogram_credible_regions: Underlying credible-region computation.

Source code in src/scribe/mc/_goodness_of_fit.py
def ppc_goodness_of_fit_scores(
    ppc_samples: jnp.ndarray,
    obs_counts: jnp.ndarray,
    credible_level: int = 95,
    max_bin: Optional[int] = None,
) -> Dict[str, np.ndarray]:
    """Compute PPC-based per-gene goodness-of-fit scores.

    For each gene the function compares the observed count histogram to
    posterior-predictive credible bands and produces two complementary
    metrics.

    Parameters
    ----------
    ppc_samples : jnp.ndarray, shape ``(S, C, G)``
        Posterior predictive count samples.  ``S`` is the number of
        posterior draws, ``C`` the number of cells, ``G`` the number of
        genes.
    obs_counts : jnp.ndarray, shape ``(C, G)``
        Observed UMI count matrix for the same cells and genes.
    credible_level : int, optional
        Width of the pointwise credible band (percentage).  Default: 95.
    max_bin : int or None, optional
        If set, histogram bins above this value are collapsed.  Helps
        bound computation for heavy-tailed genes.

    Returns
    -------
    dict
        Dictionary with per-gene arrays of shape ``(G,)``:

        ``calibration_failure``
            Fraction of non-empty observed-histogram bins whose density
            falls outside the ``credible_level`` credible band.  Under a
            well-specified model this should be close to
            ``1 - credible_level / 100``.
        ``l1_distance``
            Sum of absolute differences between observed density and PPC
            median density across bins.  Captures the magnitude of
            histogram-level misfit.

    See Also
    --------
    compute_ppc_gof_mask : High-level mask builder that wraps this scorer.
    goodness_of_fit_scores : RQR-based alternative.
    scribe.stats.histogram.compute_histogram_credible_regions :
        Underlying credible-region computation.
    """
    from scribe.stats.histogram import compute_histogram_credible_regions

    obs_counts = np.asarray(obs_counts)
    G = obs_counts.shape[1]

    cal_failures = np.empty(G, dtype=np.float64)
    l1_distances = np.empty(G, dtype=np.float64)

    for g in range(G):
        # PPC samples for this gene: (S, C)
        gene_ppc = ppc_samples[:, :, g]

        # Compute credible regions from PPC samples
        cr = compute_histogram_credible_regions(
            gene_ppc,
            credible_regions=[credible_level],
            normalize=True,
            max_bin=max_bin,
        )

        bin_edges = cr["bin_edges"]
        region = cr["regions"][credible_level]
        lower = region["lower"]
        upper = region["upper"]
        median = region["median"]

        # Observed histogram with the same bin edges, normalized
        obs_hist, _ = np.histogram(obs_counts[:, g], bins=bin_edges)
        obs_total = obs_hist.sum()
        if obs_total > 0:
            obs_density = obs_hist / obs_total
        else:
            obs_density = obs_hist.astype(np.float64)

        # Calibration failure rate: fraction of non-empty observed bins
        # that fall outside the credible band
        nonempty = obs_density > 0
        n_nonempty = nonempty.sum()
        if n_nonempty > 0:
            outside = (obs_density[nonempty] < lower[nonempty]) | (
                obs_density[nonempty] > upper[nonempty]
            )
            cal_failures[g] = outside.sum() / n_nonempty
        else:
            cal_failures[g] = 0.0

        # L1 distance between observed and PPC median density
        l1_distances[g] = np.sum(np.abs(obs_density - median))

    return {
        "calibration_failure": cal_failures,
        "l1_distance": l1_distances,
    }
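
A single-gene NumPy sketch of both metrics, using a toy Poisson model in place of the NB posterior predictive and per-draw histograms in place of `compute_histogram_credible_regions`:

```python
import numpy as np

rng = np.random.default_rng(0)
S, C = 200, 500                      # posterior draws, cells
ppc = rng.poisson(3.0, size=(S, C))  # toy PPC samples for one gene
obs = rng.poisson(3.0, size=C)       # observed counts from the same model

bin_edges = np.arange(0, max(ppc.max(), obs.max()) + 2)
# Per-draw normalized histograms -> pointwise 95% band and median
dens = np.stack([np.histogram(s, bins=bin_edges)[0] / C for s in ppc])
lower, median, upper = np.percentile(dens, [2.5, 50.0, 97.5], axis=0)

obs_density = np.histogram(obs, bins=bin_edges)[0] / C
nonempty = obs_density > 0
outside = (obs_density[nonempty] < lower[nonempty]) | (
    obs_density[nonempty] > upper[nonempty]
)
calibration_failure = outside.mean()  # fraction of bins outside the band
l1_distance = np.abs(obs_density - median).sum()
```

Because the observed counts are drawn from the same model as the PPC, the calibration failure rate stays near the nominal 1 - 95/100 = 0.05 and the L1 distance stays small.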

compute_ppc_gof_mask

compute_ppc_gof_mask(counts, results, component=None, n_ppc_samples=500, gene_batch_size=50, rng_key=None, counts_for_ppc=None, cell_mask=None, max_calibration_failure=0.5, max_l1_distance=None, credible_level=95, cell_batch_size=500, max_bin=None, verbose=True, return_scores=False)

Build a per-gene PPC goodness-of-fit boolean mask.

This is the high-level entry point for PPC-based gene filtering. It generates posterior predictive samples in gene batches, scores each batch against the observed counts, and applies user-specified thresholds to produce a boolean mask.

PARAMETER DESCRIPTION
counts

Observed UMI counts for the cells classified into this model (or component). Used both for histogram comparison and for amortized capture models.

TYPE: jnp.ndarray, shape ``(C_model, G)``

results

Fitted model results object. Must expose get_posterior_ppc_samples and get_component.

TYPE: ScribeSVIResults

component

For mixture models, which component to evaluate. If None the results object is used directly.

TYPE: int or None DEFAULT: None

n_ppc_samples

Number of posterior draws. Default: 500.

TYPE: int DEFAULT: 500

gene_batch_size

Number of genes per batch. Controls peak memory. Default: 50.

TYPE: int DEFAULT: 50

rng_key

JAX PRNG key. Defaults to random.PRNGKey(0).

TYPE: ndarray or None DEFAULT: None

counts_for_ppc

Full count matrix (C_all, G) for amortized capture models. If None, counts is used.

TYPE: ndarray or None DEFAULT: None

cell_mask

Boolean mask (C_all,) to subset PPC samples to the cells in counts. Applied after generation.

TYPE: ndarray or None DEFAULT: None

max_calibration_failure

Upper bound on calibration failure rate. Genes exceeding this are masked out. Default: 0.5.

TYPE: float DEFAULT: 0.5

max_l1_distance

Upper bound on L1 density distance. If None only the calibration criterion is applied.

TYPE: float or None DEFAULT: None

credible_level

Credible band width (percentage) for calibration scoring. Default: 95.

TYPE: int DEFAULT: 95

cell_batch_size

Cell batch size passed to get_posterior_ppc_samples. Default: 500.

TYPE: int DEFAULT: 500

max_bin

Cap on histogram bin count (see ppc_goodness_of_fit_scores).

TYPE: int or None DEFAULT: None

verbose

Print progress messages. Default: True.

TYPE: bool DEFAULT: True

return_scores

If True also return the full per-gene score dictionary. Default: False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
ndarray or tuple[ndarray, dict]

Boolean mask of shape (G,) (True = gene passes). When return_scores is True, returns (mask, scores_dict) where scores_dict has keys 'calibration_failure' and 'l1_distance', each of shape (G,).

See Also

  • ppc_goodness_of_fit_scores: Low-level scorer.
  • compute_gof_mask: RQR-based alternative.

Source code in src/scribe/mc/_goodness_of_fit.py
def compute_ppc_gof_mask(
    counts: jnp.ndarray,
    results,
    component: Optional[int] = None,
    n_ppc_samples: int = 500,
    gene_batch_size: int = 50,
    rng_key: Optional[jnp.ndarray] = None,
    counts_for_ppc: Optional[jnp.ndarray] = None,
    cell_mask: Optional[np.ndarray] = None,
    max_calibration_failure: float = 0.5,
    max_l1_distance: Optional[float] = None,
    credible_level: int = 95,
    cell_batch_size: int = 500,
    max_bin: Optional[int] = None,
    verbose: bool = True,
    return_scores: bool = False,
) -> "np.ndarray | tuple[np.ndarray, Dict[str, np.ndarray]]":
    """Build a per-gene PPC goodness-of-fit boolean mask.

    This is the high-level entry point for PPC-based gene filtering.
    It generates posterior predictive samples in gene batches, scores
    each batch against the observed counts, and applies user-specified
    thresholds to produce a boolean mask.

    Parameters
    ----------
    counts : jnp.ndarray, shape ``(C_model, G)``
        Observed UMI counts for the cells classified into this model
        (or component).  Used both for histogram comparison and for
        amortized capture models.
    results : ScribeSVIResults
        Fitted model results object.  Must expose
        ``get_posterior_ppc_samples`` and ``get_component``.
    component : int or None, optional
        For mixture models, which component to evaluate.  If ``None``
        the results object is used directly.
    n_ppc_samples : int, optional
        Number of posterior draws.  Default: 500.
    gene_batch_size : int, optional
        Number of genes per batch.  Controls peak memory.  Default: 50.
    rng_key : jnp.ndarray or None, optional
        JAX PRNG key.  Defaults to ``random.PRNGKey(0)``.
    counts_for_ppc : jnp.ndarray or None, optional
        Full count matrix ``(C_all, G)`` for amortized capture models.
        If ``None``, ``counts`` is used.
    cell_mask : np.ndarray or None, optional
        Boolean mask ``(C_all,)`` to subset PPC samples to the cells
        in ``counts``.  Applied after generation.
    max_calibration_failure : float, optional
        Upper bound on calibration failure rate.  Genes exceeding this
        are masked out.  Default: 0.5.
    max_l1_distance : float or None, optional
        Upper bound on L1 density distance.  If ``None`` only the
        calibration criterion is applied.
    credible_level : int, optional
        Credible band width (percentage) for calibration scoring.
        Default: 95.
    cell_batch_size : int, optional
        Cell batch size passed to ``get_posterior_ppc_samples``.
        Default: 500.
    max_bin : int or None, optional
        Cap on histogram bin count (see ``ppc_goodness_of_fit_scores``).
    verbose : bool, optional
        Print progress messages.  Default: ``True``.
    return_scores : bool, optional
        If ``True`` also return the full per-gene score dictionary.
        Default: ``False``.

    Returns
    -------
    np.ndarray or tuple[np.ndarray, dict]
        Boolean mask of shape ``(G,)`` (``True`` = gene passes).
        When ``return_scores`` is ``True``, returns
        ``(mask, scores_dict)`` where ``scores_dict`` has keys
        ``'calibration_failure'`` and ``'l1_distance'``, each of
        shape ``(G,)``.

    See Also
    --------
    ppc_goodness_of_fit_scores : Low-level scorer.
    compute_gof_mask : RQR-based alternative.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(0)

    # Get the appropriate component result object
    comp = (
        results.get_component(component)
        if component is not None
        else results
    )

    ppc_counts = counts_for_ppc if counts_for_ppc is not None else counts
    n_genes = counts.shape[1]

    # Accumulate per-gene scores across batches
    all_cal = []
    all_l1 = []

    n_batches = (n_genes + gene_batch_size - 1) // gene_batch_size

    for batch_idx in range(n_batches):
        g_start = batch_idx * gene_batch_size
        g_end = min(g_start + gene_batch_size, n_genes)
        gene_indices = jnp.arange(g_start, g_end)

        if verbose:
            print(
                f"PPC GoF batch {batch_idx + 1}/{n_batches}: "
                f"genes [{g_start}, {g_end})"
            )

        # Split key per batch so results are reproducible
        rng_key, batch_key = random.split(rng_key)

        # Generate PPC samples for this gene batch: (S, C, G_batch)
        ppc = comp.get_posterior_ppc_samples(
            gene_indices=gene_indices,
            n_samples=n_ppc_samples,
            cell_batch_size=cell_batch_size,
            rng_key=batch_key,
            counts=ppc_counts,
            store_samples=False,
            verbose=False,
        )

        # Optionally subset cells
        if cell_mask is not None:
            ppc = ppc[:, cell_mask, :]

        # Score this batch
        batch_scores = ppc_goodness_of_fit_scores(
            ppc_samples=ppc,
            obs_counts=counts[:, g_start:g_end],
            credible_level=credible_level,
            max_bin=max_bin,
        )

        all_cal.append(batch_scores["calibration_failure"])
        all_l1.append(batch_scores["l1_distance"])

        # Free batch PPC memory
        del ppc

    # Clear cached posterior to free memory
    comp.posterior_samples = None

    # Concatenate across batches
    cal = np.concatenate(all_cal)
    l1 = np.concatenate(all_l1)

    # Build mask
    mask = cal <= max_calibration_failure
    if max_l1_distance is not None:
        mask = mask & (l1 <= max_l1_distance)

    if verbose:
        n_pass = mask.sum()
        print(
            f"PPC GoF mask: {n_pass}/{n_genes} genes pass "
            f"(calibration <= {max_calibration_failure}"
            + (f", L1 <= {max_l1_distance}" if max_l1_distance else "")
            + ")"
        )

    if return_scores:
        return mask, {
            "calibration_failure": cal,
            "l1_distance": l1,
        }
    return mask
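
The gene-batching loop relies on ceiling division, so the final batch may be short; a small sketch with hypothetical sizes shows that the half-open intervals [g_start, g_end) cover every gene exactly once:

```python
import numpy as np

n_genes, gene_batch_size = 137, 50
n_batches = (n_genes + gene_batch_size - 1) // gene_batch_size  # ceil division

batches = [
    (b * gene_batch_size, min((b + 1) * gene_batch_size, n_genes))
    for b in range(n_batches)
]
covered = np.concatenate([np.arange(s, e) for s, e in batches])
assert covered.tolist() == list(range(n_genes))
print(batches)  # [(0, 50), (50, 100), (100, 137)]
```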