Skip to content

mcmc

mcmc

Markov Chain Monte Carlo (MCMC) module for single-cell RNA sequencing data analysis.

This module implements MCMC inference for SCRIBE models using NumPyro's No-U-Turn Sampler (NUTS).

MCMCInferenceEngine

Handles MCMC inference execution.

run_inference staticmethod

run_inference(model_config, count_data, n_cells, n_genes, n_samples=2000, n_warmup=1000, n_chains=1, seed=42, mcmc_kwargs=None, annotation_prior_logits=None, dataset_indices=None, init_values=None)

Execute MCMC inference using NUTS.

PARAMETER DESCRIPTION
model_config

Model configuration object.

TYPE: ModelConfig

count_data

Processed count data (cells as rows).

TYPE: ndarray

n_cells

Number of cells.

TYPE: int

n_genes

Number of genes.

TYPE: int

n_samples

Number of MCMC samples.

TYPE: int DEFAULT: 2_000

n_warmup

Number of warmup samples.

TYPE: int DEFAULT: 1_000

n_chains

Number of parallel chains.

TYPE: int DEFAULT: 1

seed

Random seed for reproducibility.

TYPE: int DEFAULT: 42

mcmc_kwargs

Keyword arguments for the NUTS kernel (e.g., target_accept_prob, max_tree_depth).

TYPE: Optional[dict] DEFAULT: None

annotation_prior_logits

Prior logits for annotation-guided mixture models.

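TYPE: Optional[ndarray] DEFAULT: None

dataset_indices

Dataset labels forwarded unchanged to the model function as its dataset_indices keyword argument (presumably one integer index per cell for multi-dataset models).

TYPE: Optional[ndarray] DEFAULT: None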

init_values

Constrained-space values to initialize MCMC chains via init_to_value. Typically obtained from compute_init_values(svi_results.get_map(...), model_config). When provided, an init_strategy is constructed and merged into the NUTS kernel kwargs. If mcmc_kwargs already contains an init_strategy, a warning is emitted and the existing strategy is overridden.

TYPE: Optional[Dict[str, ndarray]] DEFAULT: None

RETURNS DESCRIPTION
MCMC

Results from the MCMC run containing samples and diagnostics.

Source code in src/scribe/mcmc/inference_engine.py
@staticmethod
def run_inference(
    model_config: ModelConfig,
    count_data: jnp.ndarray,
    n_cells: int,
    n_genes: int,
    n_samples: int = 2_000,
    n_warmup: int = 1_000,
    n_chains: int = 1,
    seed: int = 42,
    mcmc_kwargs: Optional[dict] = None,
    annotation_prior_logits: Optional[jnp.ndarray] = None,
    dataset_indices: Optional[jnp.ndarray] = None,
    init_values: Optional[Dict[str, jnp.ndarray]] = None,
) -> Any:
    """Execute MCMC inference using NUTS.

    Parameters
    ----------
    model_config : ModelConfig
        Model configuration object.
    count_data : jnp.ndarray
        Processed count data (cells as rows).
    n_cells : int
        Number of cells.
    n_genes : int
        Number of genes.
    n_samples : int, default=2_000
        Number of MCMC samples (passed to ``MCMC`` as ``num_samples``).
    n_warmup : int, default=1_000
        Number of warmup samples (passed as ``num_warmup``).
    n_chains : int, default=1
        Number of parallel chains (passed as ``num_chains``).
    seed : int, default=42
        Random seed for reproducibility; used to build the PRNG key.
    mcmc_kwargs : Optional[dict], default=None
        Keyword arguments for the NUTS kernel (e.g.,
        ``target_accept_prob``, ``max_tree_depth``). The dict is copied,
        so the caller's object is never mutated.
    annotation_prior_logits : Optional[jnp.ndarray], default=None
        Prior logits for annotation-guided mixture models.
    dataset_indices : Optional[jnp.ndarray], default=None
        Dataset labels forwarded unchanged to the model function as its
        ``dataset_indices`` keyword argument. Presumably one integer
        index per cell for multi-dataset models (TODO confirm against
        the model implementation).
    init_values : Optional[Dict[str, jnp.ndarray]], default=None
        Constrained-space values to initialize MCMC chains via
        ``init_to_value``.  Typically obtained from
        ``compute_init_values(svi_results.get_map(...), model_config)``.
        When provided, an ``init_strategy`` is constructed and merged
        into the NUTS kernel kwargs.  If ``mcmc_kwargs`` already
        contains an ``init_strategy``, a warning is emitted and the
        existing strategy is overridden.

    Returns
    -------
    numpyro.infer.MCMC
        The MCMC object after ``run`` has completed, holding the samples
        and diagnostics.

    Warns
    -----
    UserWarning
        If ``init_values`` is given while ``mcmc_kwargs`` already
        contains an ``init_strategy``.
    """
    # Get model function (no guide needed for MCMC)
    model, _, _ = get_model_and_guide(model_config, guide_families=None)

    # Build effective NUTS kwargs, optionally injecting init_to_value.
    # Copy first so the caller's mcmc_kwargs dict is left untouched.
    effective_kwargs: Dict[str, Any] = dict(mcmc_kwargs or {})
    if init_values is not None:
        from numpyro.infer.initialization import init_to_value

        if "init_strategy" in effective_kwargs:
            # stacklevel=2 attributes the warning to the caller.
            warnings.warn(
                "init_values overrides the existing init_strategy "
                "in mcmc_kwargs.",
                UserWarning,
                stacklevel=2,
            )
        effective_kwargs["init_strategy"] = init_to_value(
            values=init_values
        )

    # Create NUTS sampler with the (possibly augmented) kwargs
    nuts_kernel = NUTS(model, **effective_kwargs)

    # Create MCMC instance
    mcmc = MCMC(
        nuts_kernel,
        num_samples=n_samples,
        num_warmup=n_warmup,
        num_chains=n_chains,
    )

    # Create random number generator key
    rng_key = random.PRNGKey(seed)

    # Prepare model arguments (forwarded as keyword arguments to the model)
    model_args = {
        "n_cells": n_cells,
        "n_genes": n_genes,
        "counts": count_data,
        "model_config": model_config,
        "annotation_prior_logits": annotation_prior_logits,
        "dataset_indices": dataset_indices,
    }

    # Run inference
    mcmc.run(rng_key, **model_args)

    return mcmc

ScribeMCMCResults dataclass

ScribeMCMCResults(samples, n_cells, n_genes, model_type, model_config, prior_params, obs=None, var=None, uns=None, n_obs=None, n_vars=None, predictive_samples=None, n_components=None, denoised_counts=None, _n_cells_per_dataset=None, _dataset_indices=None, _promoted_dataset_keys=None, _mcmc=None)

Bases: ParameterExtractionMixin, GeneSubsettingMixin, ComponentMixin, DatasetMixin, ModelHelpersMixin, SamplingMixin, LikelihoodMixin, NormalizationMixin, MixtureAnalysisMixin

SCRIBE MCMC results.

Stores posterior samples and provides analysis methods via mixins. The underlying numpyro.infer.MCMC object is wrapped (composition) rather than inherited, so gene/component subsetting always returns another ScribeMCMCResults instance.

ATTRIBUTE DESCRIPTION
samples

Raw posterior samples keyed by parameter name.

TYPE: Dict

n_cells

Number of cells in the dataset.

TYPE: int

n_genes

Number of genes in the dataset.

TYPE: int

model_type

Model identifier (e.g. "nbdm", "zinb_mix").

TYPE: str

model_config

Configuration used for inference.

TYPE: ModelConfig

prior_params

Prior parameter values used during inference.

TYPE: Dict[str, Any]

obs

Cell-level metadata from adata.obs.

TYPE: Optional[DataFrame]

var

Gene-level metadata from adata.var.

TYPE: Optional[DataFrame]

uns

Unstructured metadata from adata.uns.

TYPE: Optional[Dict]

n_obs

Number of observations (cells).

TYPE: Optional[int]

n_vars

Number of variables (genes).

TYPE: Optional[int]

predictive_samples

Predictive samples from the get_ppc_samples method.

TYPE: Optional[ndarray]

n_components

Number of mixture components (None for non-mixture models).

TYPE: Optional[int]

denoised_counts

Denoised counts from the denoise_counts method.

TYPE: Optional[ndarray]

_mcmc

Wrapped numpyro.infer.MCMC object for diagnostics. None on subsets produced by gene/component indexing.

TYPE: Optional[Any]

posterior_samples property

posterior_samples

Posterior samples (read-only property).

concat classmethod

concat(results_list, *, align_genes='assume_aligned', join='cells', check_model=True, validation='var_only')

Concatenate multiple MCMC results objects along the cell axis.

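The method supports combining objects that represent the same model and gene space while differing in cell count. Cell-specific posterior samples are concatenated along the cell axis. Non-cell-specific samples are checked for deep equality only when validation="strict"; with validation="var_only" only their keys are checked. When the inputs carry no per-dataset metadata, the result is promoted to a multi-dataset object and each input's non-cell-specific samples are stacked along a new dataset axis.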

PARAMETER DESCRIPTION
results_list

Results objects to concatenate. At least two elements are required.

TYPE: list of ScribeMCMCResults

align_genes

Gene-alignment strategy. "strict" requires matching gene sets; if all objects include var, differing gene order is resolved by reordering to match the first object. "assume_aligned" skips gene-set/order validation and assumes all inputs are already aligned.

TYPE: (strict, assume_aligned) DEFAULT: "assume_aligned"

join

Concatenation axis. Only cell-axis concatenation is supported.

TYPE: cells DEFAULT: "cells"

check_model

If True, require exact agreement for model-level fields (model_type, model_config, and prior_params).

TYPE: bool DEFAULT: True

validation

Validation policy for non-cell-specific fields. "strict" enforces deep equality checks for shared sample sites and metadata. "var_only" performs fast key-level checks and relies on the user-trusted model fit plus gene validation from var (or n_genes when var is missing), taking non-cell-specific values from the first object.

TYPE: (strict, var_only) DEFAULT: "var_only"

RETURNS DESCRIPTION
ScribeMCMCResults

Concatenated MCMC results with _mcmc=None.

RAISES DESCRIPTION
ValueError

If inputs are empty, incompatible, or use unsupported options.

TypeError

If inputs are not all ScribeMCMCResults instances.

Source code in src/scribe/mcmc/results.py
@classmethod
def concat(
    cls,
    results_list: List["ScribeMCMCResults"],
    *,
    align_genes: str = "assume_aligned",
    join: str = "cells",
    check_model: bool = True,
    validation: str = "var_only",
) -> "ScribeMCMCResults":
    """Concatenate multiple MCMC results objects along the cell axis.

    The method supports combining objects that represent the same model and
    gene space while differing in cell count. Cell-specific posterior
    samples are concatenated along the cell axis. Non-cell-specific
    samples are checked for deep equality only when
    ``validation="strict"``; with ``validation="var_only"`` only their
    keys are checked and the values of the first object are used.

    When the inputs carry no per-dataset metadata (e.g. every input is a
    single-dataset result), the output is promoted to a multi-dataset
    result: each input becomes one dataset, ``model_config.n_datasets`` is
    set to the number of inputs, per-cell dataset indices are built, and
    every non-cell-specific sample site is stacked along a new axis 1
    (after the sample axis) so each input's values remain recoverable.

    Parameters
    ----------
    results_list : list of ScribeMCMCResults
        Results objects to concatenate. At least two elements are required.
    align_genes : {"strict", "assume_aligned"}, default="assume_aligned"
        Gene-alignment strategy. ``"strict"`` requires matching gene sets;
        if all objects include ``var``, differing gene order is resolved by
        reordering to match the first object. ``"assume_aligned"`` skips
        gene-set/order validation and assumes all inputs are already aligned.
    join : {"cells"}, default="cells"
        Concatenation axis. Only cell-axis concatenation is supported.
    check_model : bool, default=True
        If ``True``, require exact agreement for model-level fields
        (``model_type``, ``model_config``, and ``prior_params``).
    validation : {"strict", "var_only"}, default="var_only"
        Validation policy for non-cell-specific fields.
        ``"strict"`` enforces deep equality checks for shared sample sites
        and metadata. ``"var_only"`` performs fast key-level checks and
        relies on the user-trusted model fit plus gene validation from
        ``var`` (or ``n_genes`` when ``var`` is missing), taking
        non-cell-specific values from the first object.

    Returns
    -------
    ScribeMCMCResults
        Concatenated MCMC results with ``_mcmc=None``; ``predictive_samples``
        and ``denoised_counts`` are reset to ``None``.

    Raises
    ------
    ValueError
        If fewer than two inputs are given, inputs are incompatible, or
        unsupported options are used.
    TypeError
        If ``results_list`` is itself a single results object, or if its
        entries are not all ``ScribeMCMCResults`` instances.
    """
    # Guard against accidentally passing a single results object instead of
    # a list/tuple. Because results support ``__getitem__``, iterating over
    # a single object can trigger expensive implicit indexing.
    if isinstance(results_list, cls):
        raise TypeError(
            "results_list must be a sequence of results, e.g. "
            "ScribeMCMCResults.concat([res_a, res_b])."
        )
    if not results_list or len(results_list) < 2:
        raise ValueError(
            "results_list must contain at least two elements. "
            "Note: concat is a classmethod — call "
            "ScribeMCMCResults.concat([res_a, res_b]), not "
            "res_a.concat([res_b])."
        )
    if join != "cells":
        raise ValueError("Only join='cells' is currently supported.")
    if align_genes not in {"strict", "assume_aligned"}:
        raise ValueError(
            "align_genes must be one of {'strict', 'assume_aligned'}."
        )
    if validation not in {"strict", "var_only"}:
        raise ValueError(
            "validation must be one of {'strict', 'var_only'}."
        )

    for idx, res in enumerate(results_list):
        if not isinstance(res, cls):
            raise TypeError(
                f"All entries must be {cls.__name__}; got {type(res)} at index {idx}."
            )

    reference = results_list[0]
    if check_model:
        _validate_mcmc_model_compatibility(reference, results_list)

    aligned_results = (
        _align_mcmc_genes_to_reference(
            reference=reference, results_list=results_list
        )
        if align_genes == "strict"
        else results_list
    )
    first = aligned_results[0]

    _validate_equal_mcmc_sample_sizes(aligned_results)

    # Classify cell-specific sample sites from ParamSpec metadata so only
    # cell-axis variables are concatenated along axis 1.
    cell_sample_keys = _build_cell_specific_keys(
        first.model_config.param_specs or [],
        first.samples,
    )
    samples = _concat_mcmc_samples(
        sample_dicts=[res.samples for res in aligned_results],
        cell_specific_keys=cell_sample_keys,
        strict=(validation == "strict"),
    )

    obs = _concat_optional_obs([res.obs for res in aligned_results])
    var = first.var.copy() if first.var is not None else None
    uns = _merge_optional_uns(
        [res.uns for res in aligned_results],
        strict=(validation == "strict"),
    )

    n_cells_total = int(sum(res.n_cells for res in aligned_results))
    # Prefer the explicit n_obs counts; fall back to the concatenated obs
    # frame when any input lacks n_obs.
    n_obs_total = (
        int(sum(res.n_obs for res in aligned_results))
        if all(res.n_obs is not None for res in aligned_results)
        else (obs.shape[0] if obs is not None else None)
    )

    # --- dataset metadata: merge existing or promote single-dataset ---
    n_cells_per_dataset = _merge_cells_per_dataset(
        [
            getattr(res, "_n_cells_per_dataset", None)
            for res in aligned_results
        ]
    )
    dataset_indices = _concat_dataset_indices(
        [
            getattr(res, "_dataset_indices", None)
            for res in aligned_results
        ]
    )

    # When all inputs are single-dataset and we are combining more than
    # one, promote to a multi-dataset result so that ``get_dataset(i)``
    # can retrieve the i-th original result's cells.
    n_inputs = len(aligned_results)
    combined_config = first.model_config
    promoted_dataset_keys = None
    if n_cells_per_dataset is None and n_inputs > 1:
        n_cells_per_dataset = jnp.array(
            [int(res.n_cells) for res in aligned_results],
            dtype=jnp.int32,
        )
        dataset_indices = jnp.concatenate(
            [
                jnp.full(int(res.n_cells), i, dtype=jnp.int32)
                for i, res in enumerate(aligned_results)
            ]
        )
        combined_config = first.model_config.model_copy(
            update={"n_datasets": n_inputs}
        )

        # Stack non-cell-specific samples along a new dataset axis (1,
        # after the sample axis at 0) so get_dataset(i) can recover
        # per-dataset values.
        promoted_dataset_keys = set()
        for key in samples:
            if key not in cell_sample_keys:
                stacked = jnp.stack(
                    [res.samples[key] for res in aligned_results],
                    axis=1,
                )
                samples[key] = stacked
                promoted_dataset_keys.add(key)

    return cls(
        samples=samples,
        n_cells=n_cells_total,
        n_genes=first.n_genes,
        model_type=first.model_type,
        model_config=combined_config,
        prior_params=first.prior_params,
        obs=obs,
        var=var,
        uns=uns,
        n_obs=n_obs_total,
        n_vars=first.n_genes,
        predictive_samples=None,
        n_components=first.n_components,
        denoised_counts=None,
        _n_cells_per_dataset=n_cells_per_dataset,
        _dataset_indices=dataset_indices,
        _promoted_dataset_keys=promoted_dataset_keys,
        _mcmc=None,
    )

__post_init__

__post_init__()

Validate model configuration and set derived attributes.

Source code in src/scribe/mcmc/results.py
def __post_init__(self):
    """Fill in derived attributes, then validate the model configuration.

    An unset ``n_components`` is inherited from
    ``model_config.n_components`` (only when the latter is not ``None``).
    Validation is delegated to ``self._validate_model_config()``.
    """
    if self.n_components is None:
        inherited_components = self.model_config.n_components
        if inherited_components is not None:
            self.n_components = inherited_components

    self._validate_model_config()

from_mcmc classmethod

from_mcmc(mcmc, n_cells, n_genes, model_type, model_config, prior_params, **kwargs)

Create results from an existing numpyro.infer.MCMC instance.

Extracts samples once and stores the MCMC object for diagnostics.

PARAMETER DESCRIPTION
mcmc

Completed MCMC run.

TYPE: MCMC

n_cells

Number of cells.

TYPE: int

n_genes

Number of genes.

TYPE: int

model_type

Model identifier.

TYPE: str

model_config

Model configuration.

TYPE: ModelConfig

prior_params

Prior parameter values.

TYPE: Dict[str, Any]

**kwargs

Forwarded to the dataclass constructor (e.g. obs, var).

DEFAULT: {}

RETURNS DESCRIPTION
ScribeMCMCResults
Source code in src/scribe/mcmc/results.py
@classmethod
def from_mcmc(
    cls,
    mcmc,
    n_cells: int,
    n_genes: int,
    model_type: str,
    model_config: ModelConfig,
    prior_params: Dict[str, Any],
    **kwargs,
) -> "ScribeMCMCResults":
    """Build a results object from a finished ``numpyro.infer.MCMC`` run.

    Samples are fetched a single time with ``group_by_chain=False`` and
    stored on the instance; the MCMC object itself is retained on
    ``_mcmc`` so diagnostics can be queried later.

    Parameters
    ----------
    mcmc : numpyro.infer.MCMC
        An MCMC object whose ``run`` has completed.
    n_cells : int
        Number of cells.
    n_genes : int
        Number of genes.
    model_type : str
        Model identifier.
    model_config : ModelConfig
        Model configuration.
    prior_params : Dict[str, Any]
        Prior parameter values.
    **kwargs
        Extra dataclass fields passed straight to the constructor
        (for example ``obs`` or ``var``).

    Returns
    -------
    ScribeMCMCResults
    """
    flat_samples = mcmc.get_samples(group_by_chain=False)
    core_fields = dict(
        samples=flat_samples,
        n_cells=n_cells,
        n_genes=n_genes,
        model_type=model_type,
        model_config=model_config,
        prior_params=prior_params,
        _mcmc=mcmc,
    )
    return cls(**core_fields, **kwargs)

from_anndata classmethod

from_anndata(mcmc, adata, model_type, model_config, prior_params, **kwargs)

Create results from an MCMC instance and AnnData object.

PARAMETER DESCRIPTION
mcmc

Completed MCMC run.

TYPE: MCMC

adata

AnnData object with cell/gene metadata.

TYPE: AnnData

model_type

Model identifier.

TYPE: str

model_config

Model configuration.

TYPE: ModelConfig

prior_params

Prior parameter values.

TYPE: Dict[str, Any]

**kwargs

Forwarded to the dataclass constructor.

DEFAULT: {}

RETURNS DESCRIPTION
ScribeMCMCResults
Source code in src/scribe/mcmc/results.py
@classmethod
def from_anndata(
    cls,
    mcmc,
    adata,
    model_type: str,
    model_config: ModelConfig,
    prior_params: Dict[str, Any],
    **kwargs,
) -> "ScribeMCMCResults":
    """Build a results object from an MCMC run plus an AnnData object.

    Cell and gene counts come from ``adata.n_obs`` / ``adata.n_vars``, and
    copies of ``adata.obs``, ``adata.var`` and ``adata.uns`` are attached
    so later edits to ``adata`` do not leak into the results.

    Parameters
    ----------
    mcmc : numpyro.infer.MCMC
        An MCMC object whose ``run`` has completed.
    adata : AnnData
        Source of cell/gene metadata.
    model_type : str
        Model identifier.
    model_config : ModelConfig
        Model configuration.
    prior_params : Dict[str, Any]
        Prior parameter values.
    **kwargs
        Additional fields passed through to ``from_mcmc``.

    Returns
    -------
    ScribeMCMCResults
    """
    cell_count = adata.n_obs
    gene_count = adata.n_vars
    return cls.from_mcmc(
        mcmc=mcmc,
        n_cells=cell_count,
        n_genes=gene_count,
        model_type=model_type,
        model_config=model_config,
        prior_params=prior_params,
        obs=adata.obs.copy(),
        var=adata.var.copy(),
        uns=adata.uns.copy(),
        n_obs=cell_count,
        n_vars=gene_count,
        **kwargs,
    )

get_posterior_samples

get_posterior_samples(descriptive_names=False)

Return posterior samples.

MCMC samples already contain canonical parameters (p, r, mixing_weights, etc.) because derived parameters are registered as numpyro.deterministic sites and unconstrained specs sample via TransformedDistribution in constrained space.

PARAMETER DESCRIPTION
descriptive_names

If True, rename dict keys from internal short names to user-friendly descriptive names.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Dict

Parameter name -> sample array.

Source code in src/scribe/mcmc/results.py
def get_posterior_samples(self, descriptive_names: bool = False) -> Dict:
    """Return the stored posterior samples, optionally with descriptive keys.

    No parameter conversion happens here. The stored MCMC samples already
    include the canonical parameters (``p``, ``r``, ``mixing_weights``,
    ...) because derived quantities are recorded as
    ``numpyro.deterministic`` sites and unconstrained specs are sampled
    through a ``TransformedDistribution`` in constrained space.

    Parameters
    ----------
    descriptive_names : bool, default=False
        When True, keys are renamed from internal short names to
        user-friendly descriptive names via ``rename_dict_keys``.

    Returns
    -------
    Dict
        Mapping of parameter name to sample array.
    """
    # Local import; presumably avoids a package-level import cycle (confirm).
    from ..models.config.parameter_mapping import rename_dict_keys

    stored = self.samples
    return rename_dict_keys(stored, descriptive_names)

get_samples

get_samples(group_by_chain=False)

Return samples with optional chain grouping.

PARAMETER DESCRIPTION
group_by_chain

Preserve the chain dimension (requires the original MCMC object).

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
Dict

Parameter samples.

Source code in src/scribe/mcmc/results.py
def get_samples(self, group_by_chain: bool = False) -> Dict:
    """Return posterior samples, flat or grouped by chain.

    Parameters
    ----------
    group_by_chain : bool, default=False
        When True, keep the per-chain dimension. This needs the wrapped
        MCMC object, which is absent on subsets.

    Returns
    -------
    Dict
        Parameter samples.

    Raises
    ------
    RuntimeError
        If chain grouping is requested but ``_mcmc`` is ``None``.
    """
    if not group_by_chain:
        return self.samples

    if self._mcmc is None:
        raise RuntimeError(
            "group_by_chain requires the original MCMC object "
            "(not available on subsets)."
        )
    return self._mcmc.get_samples(group_by_chain=True)

print_summary

print_summary(**kwargs)

Print MCMC summary statistics (delegates to the wrapped MCMC).

RAISES DESCRIPTION
RuntimeError

If the MCMC object is not available (e.g. on subsets).

Source code in src/scribe/mcmc/results.py
def print_summary(self, **kwargs):
    """Print MCMC summary statistics by delegating to the wrapped MCMC.

    Keyword arguments are forwarded unchanged to
    ``numpyro.infer.MCMC.print_summary``.

    Raises
    ------
    RuntimeError
        If no wrapped MCMC object is present (e.g. on subsets).
    """
    wrapped = self._mcmc
    if wrapped is not None:
        wrapped.print_summary(**kwargs)
        return
    raise RuntimeError(
        "print_summary requires the original MCMC object "
        "(not available on subsets)."
    )

get_extra_fields

get_extra_fields(**kwargs)

Return MCMC extra fields (e.g. potential_energy, diverging).

Returns an empty dict when the MCMC object is not available (subsets).

Source code in src/scribe/mcmc/results.py
def get_extra_fields(self, **kwargs) -> Dict:
    """Return MCMC extra fields such as ``potential_energy`` or ``diverging``.

    Keyword arguments are forwarded to the wrapped MCMC object's
    ``get_extra_fields``. When no MCMC object is held (subsets), an empty
    dict is returned instead of raising.
    """
    wrapped = self._mcmc
    return {} if wrapped is None else wrapped.get_extra_fields(**kwargs)

__getstate__

__getstate__()

Return pickle-safe state for ScribeMCMCResults.

Notes

The wrapped _mcmc object retains local closure functions from model building and is intentionally dropped to ensure portability.

Source code in src/scribe/mcmc/results.py
def __getstate__(self) -> Dict[str, Any]:
    """Produce a pickle-friendly snapshot of the instance state.

    Notes
    -----
    ``_mcmc`` is replaced by ``None`` because the wrapped MCMC object
    holds local closures from model construction that cannot be pickled
    portably. ``model_config`` is passed through
    ``make_model_config_pickle_safe`` before being stored.
    """
    snapshot = {**self.__dict__, "_mcmc": None}
    snapshot["model_config"] = make_model_config_pickle_safe(
        snapshot.get("model_config")
    )
    return snapshot

__setstate__

__setstate__(state)

Restore instance state after unpickling.

Source code in src/scribe/mcmc/results.py
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore instance attributes from an unpickled state mapping."""
    vars(self).update(state)

MCMCResultsFactory

Factory for creating MCMC results objects.

create_results staticmethod

create_results(mcmc_results, model_config, adata, count_data, n_cells, n_genes, model_type, n_components, prior_params)

Package MCMC results into a ScribeMCMCResults object.

PARAMETER DESCRIPTION
mcmc_results

Raw MCMC results from NumPyro.

TYPE: MCMC

model_config

Model configuration object.

TYPE: ModelConfig

adata

Original AnnData object (if provided).

TYPE: Optional[AnnData]

count_data

Processed count data.

TYPE: ndarray

n_cells

Number of cells.

TYPE: int

n_genes

Number of genes.

TYPE: int

model_type

Type of model.

TYPE: str

n_components

Number of mixture components.

TYPE: Optional[int]

prior_params

Dictionary of prior parameters.

TYPE: Dict[str, Any]

RETURNS DESCRIPTION
ScribeMCMCResults

Packaged results object.

Source code in src/scribe/mcmc/results_factory.py
@staticmethod
def create_results(
    mcmc_results: Any,
    model_config: ModelConfig,
    adata: Optional["AnnData"],
    count_data: jnp.ndarray,
    n_cells: int,
    n_genes: int,
    model_type: str,
    n_components: Optional[int],
    prior_params: Dict[str, Any],
) -> ScribeMCMCResults:
    """Wrap a finished MCMC run in a ``ScribeMCMCResults`` object.

    When ``adata`` is supplied, cell/gene counts and metadata are taken
    from it via ``ScribeMCMCResults.from_anndata``. Otherwise the explicit
    ``n_cells`` / ``n_genes`` are used via ``ScribeMCMCResults.from_mcmc``.

    Parameters
    ----------
    mcmc_results : numpyro.infer.MCMC
        Completed NumPyro MCMC object.
    model_config : ModelConfig
        Model configuration object.
    adata : Optional[AnnData]
        Original AnnData object, if one was provided.
    count_data : jnp.ndarray
        Processed count data. Accepted for interface compatibility; it is
        not used when building the results object.
    n_cells : int
        Number of cells (ignored when ``adata`` is given).
    n_genes : int
        Number of genes (ignored when ``adata`` is given).
    model_type : str
        Type of model.
    n_components : Optional[int]
        Number of mixture components.
    prior_params : Dict[str, Any]
        Dictionary of prior parameters.

    Returns
    -------
    ScribeMCMCResults
        Packaged results object.
    """
    common_fields = dict(
        mcmc=mcmc_results,
        model_type=model_type,
        model_config=model_config,
        n_components=n_components,
        prior_params=prior_params,
    )

    if adata is None:
        return ScribeMCMCResults.from_mcmc(
            n_cells=n_cells,
            n_genes=n_genes,
            **common_fields,
        )
    return ScribeMCMCResults.from_anndata(adata=adata, **common_fields)

clamp_init_values

clamp_init_values(init)

Clamp init values away from distribution support boundaries.

SVI MAP estimates (stored in float32) can land exactly on support boundaries — e.g. phi_capture = 0.0 or p = 1.0 — where the log-probability is -inf. This makes init_to_value reject the initialization.

PARAMETER DESCRIPTION
init

Init values keyed by parameter name.

TYPE: Dict[str, ndarray]

RETURNS DESCRIPTION
Dict[str, ndarray]

A shallow copy with boundary values nudged into the interior.

Source code in src/scribe/mcmc/_init_from_svi.py
def clamp_init_values(
    init: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
    """Nudge init values off the edges of their distribution supports.

    SVI MAP estimates stored in float32 can sit exactly on a support
    boundary (for instance ``phi_capture = 0.0`` or ``p = 1.0``). There
    the log-probability is ``-inf``, and ``init_to_value`` then rejects
    the initialization.

    Parameters
    ----------
    init : Dict[str, jnp.ndarray]
        Init values keyed by parameter name.

    Returns
    -------
    Dict[str, jnp.ndarray]
        A new dict (shallow copy). Entries whose ``_SUPPORT`` kind is
        ``"unit"`` are clipped to ``[_EPS, 1 - _EPS]``, entries of kind
        ``"positive"`` to ``[_EPS, inf)``, and all others are kept as-is.
    """
    # NOTE(review): for float32 arrays, 1.0 - _EPS rounds back to 1.0 when
    # _EPS is below roughly 3e-8, making the upper clip a no-op; confirm
    # the value of _EPS.
    def _nudge(param_name, value):
        support_kind = _SUPPORT.get(param_name)
        if support_kind == "unit":
            return jnp.clip(value, _EPS, 1.0 - _EPS)
        if support_kind == "positive":
            return jnp.clip(value, _EPS, None)
        return value

    return {
        param_name: _nudge(param_name, value)
        for param_name, value in init.items()
    }

compute_init_values

compute_init_values(svi_map, target_config)

Compute MCMC init values from SVI MAP estimates.

Ensures the returned dict contains constrained-space values for all sampled sites of the target model's parameterization. Missing parameters are derived from the canonical pair (p, r) which is always present when get_map(canonical=True) is used.

PARAMETER DESCRIPTION
svi_map

MAP estimates from SVI, typically obtained via svi_results.get_map(use_mean=True, canonical=True). Must contain at least "p" and "r".

TYPE: Dict[str, ndarray]

target_config

Model configuration for the target MCMC run.

TYPE: ModelConfig

RETURNS DESCRIPTION
Dict[str, ndarray]

Init values keyed by site name, all in constrained space. Includes the original SVI MAP entries plus any derived parameters needed by the target parameterization.

RAISES DESCRIPTION
ValueError

If canonical parameters p and r are missing from svi_map and cannot be derived.

Notes

init_to_value only initializes numpyro.sample sites. Extra keys in the returned dict (e.g. deterministic sites r when the target is mean_prob) are harmlessly ignored by NumPyro.

Hierarchical hyperparameters (logit_p_loc, log_phi_scale, etc.) live in different spaces across parameterizations and cannot be reliably converted. They are omitted and will fall back to init_to_uniform inside NumPyro.

Source code in src/scribe/mcmc/_init_from_svi.py
def compute_init_values(
    svi_map: Dict[str, jnp.ndarray],
    target_config: ModelConfig,
) -> Dict[str, jnp.ndarray]:
    """Compute MCMC init values from SVI MAP estimates.

    Ensures the returned dict contains constrained-space values for all
    sampled sites of the target model's parameterization.  Missing
    parameters are derived from the canonical pair ``(p, r)`` which is
    always present when ``get_map(canonical=True)`` is used.

    If ``p`` is absent but ``phi`` is present, ``p = 1 / (1 + phi)``.
    If ``r`` is absent but ``mu`` is present, ``r = mu * (1 - p) / p``.
    ``p`` is clamped into ``[_EPS, 1 - _EPS]`` *before* it is used to
    derive ``r``, so a boundary MAP value cannot produce an infinite or
    NaN ``r``.

    Parameters
    ----------
    svi_map : Dict[str, jnp.ndarray]
        MAP estimates from SVI, typically obtained via
        ``svi_results.get_map(use_mean=True, canonical=True)``.
        Must contain at least ``"p"`` and ``"r"`` (or enough information
        to derive them).
    target_config : ModelConfig
        Model configuration for the target MCMC run.

    Returns
    -------
    Dict[str, jnp.ndarray]
        Init values keyed by site name, all in constrained space.
        Includes the original SVI MAP entries plus any derived
        parameters needed by the target parameterization.

    Raises
    ------
    ValueError
        If canonical parameters ``p`` and ``r`` are missing from
        *svi_map* and cannot be derived.

    Notes
    -----
    ``init_to_value`` only initializes ``numpyro.sample`` sites.  Extra
    keys in the returned dict (e.g. deterministic sites ``r`` when the
    target is ``mean_prob``) are harmlessly ignored by NumPyro.

    Hierarchical hyperparameters (``logit_p_loc``, ``log_phi_scale``,
    etc.) live in different spaces across parameterizations and cannot
    be reliably converted.  They are omitted and will fall back to
    ``init_to_uniform`` inside NumPyro.
    """
    init = dict(svi_map)
    target_param = target_config.parameterization

    # ------------------------------------------------------------------
    # Ensure canonical (p, r) are present
    # ------------------------------------------------------------------
    # phi = (1 - p) / p  =>  p = 1 / (1 + phi)
    if "p" not in init and "phi" in init:
        init["p"] = 1.0 / (1.0 + init["phi"])

    # Clamp p away from {0, 1} BEFORE it is used to derive r.  An SVI MAP
    # value sitting exactly on a boundary would otherwise make
    # mu * (1 - p) / p infinite (p = 0) or NaN (p = 0 and mu = 0).
    # NOTE(review): for float32 arrays, 1.0 - _EPS rounds back to 1.0
    # when _EPS is below roughly 3e-8, disabling the upper clamp; confirm
    # the value of _EPS.
    if "p" in init:
        init["p"] = jnp.clip(init["p"], _EPS, 1.0 - _EPS)

    # mu = r * p / (1 - p)  =>  r = mu * (1 - p) / p.  No separate
    # (mu, phi) route is needed: any phi has already been turned into p.
    if "r" not in init and "mu" in init and "p" in init:
        p = init["p"]
        init["r"] = jnp.clip(init["mu"] * (1.0 - p) / p, _EPS, None)

    if "p" not in init or "r" not in init:
        raise ValueError(
            "SVI MAP must contain canonical parameters 'p' and 'r' "
            "(or enough information to derive them).  "
            "Use svi_results.get_map(canonical=True)."
        )

    # A user-supplied r may itself sit on its lower boundary (0); keep it
    # strictly positive so derived quantities stay finite.
    init["r"] = jnp.clip(init["r"], _EPS, None)

    p = init["p"]
    r = init["r"]

    # ------------------------------------------------------------------
    # Derive missing core parameters for the target parameterization.
    # All derived values are clamped away from distribution support
    # boundaries so that NumPyro's init_to_value can compute finite
    # log-probabilities.
    # ------------------------------------------------------------------
    if (
        target_param in _MEAN_ODDS_PARAMETERIZATIONS
        or target_param in _MEAN_PROB_PARAMETERIZATIONS
    ):
        if "mu" not in init:
            init["mu"] = jnp.clip(r * p / (1.0 - p), _EPS, None)

    if target_param in _MEAN_ODDS_PARAMETERIZATIONS:
        if "phi" not in init:
            init["phi"] = jnp.clip((1.0 - p) / p, _EPS, None)

    # ------------------------------------------------------------------
    # Handle VCP capture-parameter name differences
    # ------------------------------------------------------------------
    _convert_capture_params(init, target_config)

    return init