svi

Stochastic Variational Inference (SVI) module for single-cell RNA sequencing data analysis.

This module runs SVI for SCRIBE models on top of NumPyro's SVI implementation.

class scribe.svi.SVIInferenceEngine[source]

Bases: object

Handles SVI inference execution.

static run_inference(model_config, count_data, n_cells, n_genes, optimizer=numpyro.optim.Adam(step_size=0.001), loss=numpyro.infer.TraceMeanField_ELBO(), n_steps=100000, batch_size=None, seed=42, stable_update=True)[source]

Execute SVI inference.

Parameters:
  • model_config (ModelConfig) – Model configuration object

  • count_data (jnp.ndarray) – Processed count data (cells as rows)

  • n_cells (int) – Number of cells

  • n_genes (int) – Number of genes

  • optimizer (numpyro.optim.optimizers, default=Adam(step_size=0.001)) – Optimizer for variational inference

  • loss (numpyro.infer.elbo, default=TraceMeanField_ELBO()) – Loss function for variational inference

  • n_steps (int, default=100_000) – Number of optimization steps

  • batch_size (Optional[int], default=None) – Mini-batch size. If None, uses full dataset.

  • seed (int, default=42) – Random seed for reproducibility

  • stable_update (bool, default=True) – Whether to use numerically stable parameter updates

Returns:

Results from SVI run containing optimized parameters and loss history

Return type:

numpyro.infer.svi.SVIRunResult
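When batch_size is set, each SVI step presumably evaluates the likelihood on a random subset of cells rather than the full dataset. The plain-JAX sketch below shows why this still targets the full-data objective: the mini-batch sum is rescaled by n_cells / batch_size, which makes it an unbiased estimate of the full sum. NumPyro's `numpyro.plate` applies the same rescaling when given a `subsample_size`. Whether SCRIBE subsamples exactly this way is an assumption, and `minibatch_log_likelihood` is a hypothetical helper, not part of scribe.svi.

```python
import jax
import jax.numpy as jnp


def minibatch_log_likelihood(key, per_cell_loglik, batch_size=None):
    """Estimate the full-data log-likelihood from a random subset of cells.

    per_cell_loglik: shape (n_cells,), log p(x_c | theta) for each cell.
    batch_size: number of cells to subsample; None means use every cell.
    """
    n_cells = per_cell_loglik.shape[0]
    if batch_size is None:
        # Full-dataset mode: no subsampling and no rescaling.
        return per_cell_loglik.sum()
    # Draw a mini-batch of distinct cell indices.
    idx = jax.random.choice(key, n_cells, shape=(batch_size,), replace=False)
    # Rescale by n_cells / batch_size so the estimate is unbiased.
    return (n_cells / batch_size) * per_cell_loglik[idx].sum()


# Toy per-cell log-likelihoods for 1,000 cells.
per_cell = -jnp.arange(1.0, 1001.0)
full = minibatch_log_likelihood(jax.random.PRNGKey(42), per_cell)

# Averaging many mini-batch estimates recovers the full-data value.
keys = jax.random.split(jax.random.PRNGKey(0), 2000)
estimates = jax.vmap(lambda k: minibatch_log_likelihood(k, per_cell, 100))(keys)
print(full, estimates.mean())
```

Any single mini-batch estimate is noisy; SVI tolerates this because the optimizer averages gradient noise over many steps.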

class scribe.svi.SVIResultsFactory[source]

Bases: object

Factory for creating SVI results objects.

static create_results(svi_results, model_config, adata, count_data, n_cells, n_genes, model_type, n_components, prior_params)[source]

Package SVI results into ScribeSVIResults object.

Parameters:
  • svi_results (Any) – Raw SVI results from NumPyro

  • model_config (ModelConfig) – Model configuration object

  • adata (Optional[AnnData]) – Original AnnData object (if provided)

  • count_data (jnp.ndarray) – Processed count data

  • n_cells (int) – Number of cells

  • n_genes (int) – Number of genes

  • model_type (str) – Type of model

  • n_components (Optional[int]) – Number of mixture components

  • prior_params (Dict[str, Any]) – Dictionary of prior parameters

Returns:

Packaged results object

Return type:

ScribeSVIResults

class scribe.svi.ScribeSVIResults(params, loss_history, n_cells, n_genes, model_type, model_config, prior_params, obs=None, var=None, uns=None, n_obs=None, n_vars=None, posterior_samples=None, predictive_samples=None, n_components=None)[source]

Bases: object

Base class for SCRIBE variational inference results.

This class stores the results from SCRIBE’s variational inference procedure, including model parameters, loss history, dataset dimensions, and model configuration. It can optionally store metadata from an AnnData object and posterior/predictive samples.

Parameters:
  • params (Dict)

  • loss_history (Array)

  • n_cells (int)

  • n_genes (int)

  • model_type (str)

  • model_config (ModelConfig)

  • prior_params (Dict[str, Any])

  • obs (DataFrame | None)

  • var (DataFrame | None)

  • uns (Dict | None)

  • n_obs (int | None)

  • n_vars (int | None)

  • posterior_samples (Dict | None)

  • predictive_samples (Dict | None)

  • n_components (int | None)

params

Dictionary of inferred model parameters from SCRIBE

Type:

Dict

loss_history

Array containing the ELBO loss values during training

Type:

jnp.ndarray

n_cells

Number of cells in the dataset

Type:

int

n_genes

Number of genes in the dataset

Type:

int

model_type

Type of model used for inference

Type:

str

model_config

Configuration object specifying model architecture and priors

Type:

ModelConfig

prior_params

Dictionary of prior parameter values used during inference

Type:

Dict[str, Any]

obs

Cell-level metadata from adata.obs, if provided

Type:

Optional[pd.DataFrame]

var

Gene-level metadata from adata.var, if provided

Type:

Optional[pd.DataFrame]

uns

Unstructured metadata from adata.uns, if provided

Type:

Optional[Dict]

n_obs

Number of observations (cells), if provided

Type:

Optional[int]

n_vars

Number of variables (genes), if provided

Type:

Optional[int]

posterior_samples

Samples of parameters from the posterior distribution, if generated

Type:

Optional[Dict]

predictive_samples

Predictive samples generated from the model, if generated

Type:

Optional[Dict]

n_components

Number of mixture components, if using a mixture model

Type:

Optional[int]

__getitem__(index)[source]

Enable indexing of a ScribeSVIResults object.

__post_init__()[source]

Validate model configuration and parameters.

assignment_entropy_map(counts, return_by='gene', batch_size=None, cells_axis=0, temperature=None, use_mean=True, verbose=True, dtype=jnp.float32)[source]

Compute the entropy of component assignments, evaluated at the MAP, for each cell or gene.

This method calculates the entropy of the posterior component assignment probabilities for each cell or gene, providing a measure of assignment uncertainty. Higher entropy values indicate more uncertainty in the component assignments, while lower values indicate more confident assignments.

The entropy is calculated as:

H = -∑(p_i * log(p_i))

where p_i are the normalized probabilities for each component.
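A minimal sketch of this entropy computation, assuming the per-component log-likelihoods of each observation (for example, at the MAP) are already available and already include any log mixing weights. `assignment_entropy` is a hypothetical helper, not the SCRIBE implementation. Normalizing in log space with `log_softmax` keeps every term finite even when the likelihood gaps are large.

```python
import jax
import jax.numpy as jnp


def assignment_entropy(component_log_liks, temperature=None):
    """Shannon entropy of component assignment probabilities.

    component_log_liks: shape (..., n_components), the log-likelihood of each
    observation under each mixture component.
    """
    if temperature is not None:
        # Temperature scaling: divide log-likelihoods by T before normalizing.
        component_log_liks = component_log_liks / temperature
    # Normalize over components in log space for numerical stability.
    log_p = jax.nn.log_softmax(component_log_liks, axis=-1)
    # H = -sum_i p_i * log(p_i)
    return -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)


log_liks = jnp.array([
    [0.0, 0.0, 0.0, 0.0],              # indistinguishable components
    [0.0, -1000.0, -1000.0, -1000.0],  # one clearly dominant component
])
print(assignment_entropy(log_liks))  # about log(4) for the first row, 0 for the second
```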

Parameters:
  • counts (jnp.ndarray) – The count matrix with shape (n_cells, n_genes).

  • return_by (str, default='gene') – Whether to return the entropy by cell or gene.

  • batch_size (Optional[int], default=None) – Size of mini-batches for likelihood computation

  • cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.

  • temperature (Optional[float], default=None) – If provided, applies temperature scaling to the log-likelihoods before computing entropy.

  • use_mean (bool, default=True) – If True, uses the mean of the posterior component probabilities instead of the MAP.

  • verbose (bool, default=True) – If True, prints a warning if NaNs were replaced with means

  • dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations

Returns:

The component entropy evaluated at the MAP. Shape: (n_cells,) if return_by='cell', or (n_genes,) if return_by='gene'.

Return type:

jnp.ndarray

Raises:

ValueError

  • If the model is not a mixture model

  • If posterior samples have not been generated yet

cell_type_probabilities(counts, batch_size=None, cells_axis=0, ignore_nans=False, dtype=jnp.float32, fit_distribution=True, temperature=None, weights=None, weight_type=None, verbose=True)[source]

Compute probabilistic cell type assignments and fit Dirichlet distributions to characterize assignment uncertainty.

For each cell, this method:
  1. Computes component-specific log-likelihoods using posterior samples

  2. Converts these to probability distributions over cell types

  3. Fits a Dirichlet distribution to characterize the uncertainty in these assignments
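The three steps can be sketched as follows, assuming the per-sample, per-component log-likelihoods from step 1 are already computed. Both helpers are hypothetical. For step 3 the sketch swaps in a simple method-of-moments Dirichlet fit rather than SCRIBE's own fitting routine. It also exposes a caveat: when probabilities are saturated at 0 or 1, their variance across samples is near zero and a moment-based fit becomes ill-conditioned.

```python
import jax
import jax.numpy as jnp


def assignment_probabilities(sample_log_liks, temperature=None):
    """Steps 1-2: per-sample log-likelihoods -> assignment probabilities.

    sample_log_liks: shape (n_samples, n_cells, n_components).
    """
    if temperature is not None:
        sample_log_liks = sample_log_liks / temperature
    # Softmax over the component axis gives one probability vector per sample and cell.
    return jax.nn.softmax(sample_log_liks, axis=-1)


def fit_dirichlet_moments(probs, eps=1e-12):
    """Step 3 stand-in: method-of-moments Dirichlet fit over the sample axis.

    probs: shape (n_samples, n_cells, n_components).
    Returns concentrations of shape (n_cells, n_components).
    """
    mean = probs.mean(axis=0)
    var = probs.var(axis=0)
    # For a Dirichlet, var_k = m_k (1 - m_k) / (alpha_0 + 1); solve for alpha_0
    # per component and average the estimates.
    alpha0 = jnp.mean(mean * (1.0 - mean) / (var + eps) - 1.0, axis=-1, keepdims=True)
    return alpha0 * mean


# Check the fit on synthetic probabilities drawn from a known Dirichlet (1 cell).
true_alpha = jnp.array([2.0, 5.0, 3.0])
fake_probs = jax.random.dirichlet(jax.random.PRNGKey(0), true_alpha, shape=(20000, 1))
print(fit_dirichlet_moments(fake_probs))  # approximately recovers true_alpha
```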

Parameters:
  • counts (jnp.ndarray) – Count data to evaluate assignments for

  • batch_size (Optional[int], default=None) – Size of mini-batches for likelihood computation

  • cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.

  • ignore_nans (bool, default=False) – If True, removes any samples that contain NaNs.

  • dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations

  • fit_distribution (bool, default=True) – If True, fits a Dirichlet distribution to the assignment probabilities

  • temperature (Optional[float], default=None) – If provided, apply temperature scaling to log probabilities

  • weights (Optional[jnp.ndarray], default=None) – Array used to weight genes when computing log likelihoods

  • weight_type (Optional[str], default=None) –

    How to apply weights. Must be one of:
    • 'multiplicative': multiply log probabilities by weights

    • 'additive': add weights to log probabilities

  • verbose (bool, default=True) – If True, prints progress messages
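One reading of the two weighting modes, sketched under the assumption that `weights` holds one value per gene and is applied to the per-gene, per-component log-probabilities before they are summed over genes. The helper name and the array layout are hypothetical, not SCRIBE's. With 'multiplicative' weights, a weight of 0 drops a gene from the assignment entirely.

```python
import jax.numpy as jnp


def weighted_component_log_liks(gene_log_probs, weights=None, weight_type=None):
    """Sum per-gene log-probabilities over genes, optionally weighting genes.

    gene_log_probs: shape (n_cells, n_genes, n_components).
    weights: shape (n_genes,), one weight per gene (assumed layout).
    """
    if weights is not None:
        w = weights[None, :, None]  # broadcast over cells and components
        if weight_type == "multiplicative":
            gene_log_probs = gene_log_probs * w
        elif weight_type == "additive":
            gene_log_probs = gene_log_probs + w
        else:
            raise ValueError("weight_type must be 'multiplicative' or 'additive'")
    return gene_log_probs.sum(axis=1)  # shape (n_cells, n_components)


# 1 cell, 2 genes, 2 components.
gene_log_probs = jnp.array([[[-1.0, -3.0], [-2.0, -0.5]]])
mask_second_gene = jnp.array([1.0, 0.0])
print(weighted_component_log_liks(gene_log_probs, mask_second_gene, "multiplicative"))
```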

Returns:

Dictionary containing:
  • 'concentration': Dirichlet concentration parameters for each cell. Shape: (n_cells, n_components). Only returned if fit_distribution is True.

  • 'mean_probabilities': Mean assignment probabilities for each cell. Shape: (n_cells, n_components). Only returned if fit_distribution is True.

  • 'sample_probabilities': Assignment probabilities for each posterior sample. Shape: (n_samples, n_cells, n_components)

Return type:

Dict[str, jnp.ndarray]

Raises:

ValueError

  • If the model is not a mixture model

  • If posterior samples have not been generated yet

Note

Log-likelihood differences between cell types are usually very large, so the resulting probabilities are almost always saturated at 0 or 1. The output is therefore of limited use, but it is included for completeness.
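The saturation follows from the softmax: a gap of D nats between two components becomes a probability ratio of exp(D). The toy log-likelihoods below are made up for illustration. Dividing them by a temperature T, as the `temperature` argument does, is one way to soften the result.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-component log-likelihoods for one cell, summed over genes.
log_liks = jnp.array([-10250.0, -10310.0, -10400.0])

# Gaps of 60 and 150 nats: the first component takes essentially all the mass.
probs = jax.nn.softmax(log_liks)

# Temperature T = 50 shrinks the gaps to 1.2 and 3 nats.
probs_tempered = jax.nn.softmax(log_liks / 50.0)
print(probs, probs_tempered)
```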

cell_type_probabilities_map(counts, batch_size=None, cells_axis=0, dtype=jnp.float32, temperature=None, weights=None, weight_type=None, use_mean=False, verbose=True)[source]

Compute probabilistic cell type assignments using MAP estimates of parameters.

For each cell, this method:
  1. Computes component-specific log-likelihoods using MAP parameter estimates

  2. Converts these to probability distributions over cell types

Parameters:
  • counts (jnp.ndarray) – Count data to evaluate assignments for

  • batch_size (Optional[int], default=None) – Size of mini-batches for likelihood computation

  • cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.

  • dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations

  • temperature (Optional[float], default=None) – If provided, apply temperature scaling to log probabilities

  • weights (Optional[jnp.ndarray], default=None) – Array used to weight genes when computing log likelihoods

  • weight_type (Optional[str], default=None) –

    How to apply weights. Must be one of:
    • 'multiplicative': multiply log probabilities by weights

    • 'additive': add weights to log probabilities

  • use_mean (bool, default=False) – If True, replaces undefined MAP values (NaN) with posterior means

  • verbose (bool, default=True) – If True, prints progress messages

Returns:

Dictionary containing:
  • 'probabilities': Assignment probabilities for each cell. Shape: (n_cells, n_components)

Return type:

Dict[str, jnp.ndarray]

Raises:

ValueError – If the model is not a mixture model

classmethod from_anndata(adata, params, loss_history, model_config, **kwargs)[source]

Create a ScribeSVIResults object from an AnnData object.

Parameters:
  • adata (Any)

  • params (Dict)

  • loss_history (Array)

  • model_config (ModelConfig)

get_component(component_index)[source]

Create a view of the results selecting a specific mixture component.

This method returns a new ScribeSVIResults object that contains parameter values for the specified component, allowing for further gene-based indexing. Only applicable to mixture models.

Parameters:

component_index (int) – Index of the component to select

Returns:

A new ScribeSVIResults object with parameters for the selected component

Return type:

ScribeSVIResults

Raises:

ValueError – If the model is not a mixture model

get_distributions(backend='numpyro', split=False)[source]

Get the variational distributions for all parameters.

This method delegates to the model-specific get_posterior_distributions function associated with the model's parameterization.

Parameters:
  • backend (str, default="numpyro") –

    Statistical package to use for distributions. Must be one of:
    • 'scipy': returns scipy.stats distributions

    • 'numpyro': returns numpyro.distributions

  • split (bool, default=False) – If True, returns lists of individual distributions for multidimensional parameters instead of batch distributions.

Returns:

Dictionary mapping parameter names to their distributions.

Return type:

Dict[str, Any]

Raises:

ValueError – If backend is not supported.

get_map(use_mean=False, canonical=True, verbose=True)[source]

Get the maximum a posteriori (MAP) estimates from the variational posterior.

Parameters:
  • use_mean (bool, default=False) – If True, replaces undefined MAP values (NaN) with posterior means

  • canonical (bool, default=True) – If True, includes canonical parameters (p, r) computed from other parameters for linked, odds_ratio, and unconstrained parameterizations

  • verbose (bool, default=True) – If True, prints a warning if NaNs were replaced with means

Returns:

Dictionary of MAP estimates for each parameter

Return type:

Dict[str, jnp.ndarray]
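A MAP value can be NaN because the MAP is the mode of each variational distribution, and some distributions have no unique interior mode for some parameter values. The sketch below uses a Beta distribution as an example and mimics the use_mean fallback. Which SCRIBE parameters use a Beta variational family depends on the parameterization, so treat that as an assumption; `beta_map` is a hypothetical helper.

```python
import jax.numpy as jnp


def beta_map(a, b, use_mean=False):
    """Mode of Beta(a, b), with NaN where no unique interior mode exists."""
    # The interior mode (a - 1) / (a + b - 2) exists only when a > 1 and b > 1.
    has_mode = (a > 1.0) & (b > 1.0)
    mode = jnp.where(has_mode, (a - 1.0) / (a + b - 2.0), jnp.nan)
    if use_mean:
        # Fall back to the posterior mean a / (a + b) wherever the mode is NaN.
        mode = jnp.where(jnp.isnan(mode), a / (a + b), mode)
    return mode


a = jnp.array([3.0, 0.5])
b = jnp.array([3.0, 2.0])
print(beta_map(a, b))                 # values: 0.5, nan
print(beta_map(a, b, use_mean=True))  # values: 0.5, 0.2
```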

get_posterior_samples(rng_key=random.PRNGKey(42), n_samples=100, store_samples=True)[source]

Sample parameters from the variational posterior distribution.

Parameters:
  • rng_key (PRNGKey)

  • n_samples (int)

  • store_samples (bool)

Return type:

Dict

get_ppc_samples(rng_key=random.PRNGKey(42), n_samples=100, batch_size=None, store_samples=True)[source]

Generate posterior predictive check samples.

Parameters:
  • rng_key (PRNGKey)

  • n_samples (int)

  • batch_size (int | None)

  • store_samples (bool)

Return type:

Dict

get_predictive_samples(rng_key=random.PRNGKey(42), batch_size=None, store_samples=True)[source]

Generate predictive samples using posterior parameter samples.

Parameters:
  • rng_key (PRNGKey)

  • batch_size (int | None)

  • store_samples (bool)

Return type:

Array
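Predictive sampling amounts to simulating a new count matrix from the likelihood once per posterior draw of the parameters. The sketch below assumes a negative binomial likelihood with a per-gene dispersion r and a per-sample probability p, parameterized so that the mean is r * p / (1 - p), and draws it as a Gamma-Poisson mixture. The parameter layout and the helper `sample_nb_counts` are assumptions for illustration, not SCRIBE's code.

```python
import jax
import jax.numpy as jnp


def sample_nb_counts(key, r, p, n_cells):
    """Draw one count matrix per posterior sample via a Gamma-Poisson mixture.

    r: shape (n_samples, n_genes), p: shape (n_samples,).
    Returns integer counts of shape (n_samples, n_cells, n_genes).
    """
    key_gamma, key_poisson = jax.random.split(key)
    shape = (r.shape[0], n_cells, r.shape[1])
    # Gamma(r, scale = p / (1 - p)) rates give NB counts with mean r * p / (1 - p).
    scale = (p / (1.0 - p))[:, None, None]
    rates = jax.random.gamma(key_gamma, jnp.broadcast_to(r[:, None, :], shape)) * scale
    return jax.random.poisson(key_poisson, rates)


r = jnp.full((50, 1), 2.0)  # 50 posterior samples, 1 gene
p = jnp.full((50,), 0.3)
ppc = sample_nb_counts(jax.random.PRNGKey(0), r, p, n_cells=2000)
print(ppc.shape, ppc.mean())  # mean should be near 2 * 0.3 / 0.7, about 0.857
```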

log_likelihood(counts, batch_size=None, return_by='cell', cells_axis=0, ignore_nans=False, split_components=False, weights=None, weight_type=None, dtype=jnp.float32)[source]

Compute the log-likelihood of the data under each posterior sample.

Parameters:
  • counts (jnp.ndarray) – Count data to evaluate likelihood on

  • batch_size (Optional[int], default=None) – Size of mini-batches used for likelihood computation

  • return_by (str, default='cell') –

    Specifies how to return the log probabilities. Must be one of:
    • 'cell': returns log probabilities summed over genes

    • 'gene': returns log probabilities summed over cells

  • cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.

  • ignore_nans (bool, default=False) – If True, removes any samples that contain NaNs.

  • split_components (bool, default=False) – If True, returns log likelihoods for each mixture component separately. Only applicable for mixture models.

  • weights (Optional[jnp.ndarray], default=None) – Array used to weight the log likelihoods (for mixture models).

  • weight_type (Optional[str], default=None) –

    How to apply weights. Must be one of:
    • 'multiplicative': multiply log probabilities by weights

    • 'additive': add weights to log probabilities

  • dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations

Returns:

Array of log likelihoods. Shape depends on model type, return_by and split_components parameters. For standard models:

  • 'cell': shape (n_samples, n_cells)

  • 'gene': shape (n_samples, n_genes)

For mixture models with split_components=False:
  • 'cell': shape (n_samples, n_cells)

  • 'gene': shape (n_samples, n_genes)

For mixture models with split_components=True:
  • 'cell': shape (n_samples, n_cells, n_components)

  • 'gene': shape (n_samples, n_genes, n_components)

Return type:

jnp.ndarray

Raises:

ValueError – If posterior samples have not been generated yet
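The shape rules above follow from evaluating the likelihood for every (sample, cell, gene) triple and then summing over one axis. A sketch for a non-mixture model, assuming a negative binomial likelihood whose mean is r * p / (1 - p), with r per gene and p per sample (an assumption; SCRIBE's parameterizations may differ). `nb_log_pmf` and this `log_likelihood` function are illustrative stand-ins, not the method itself. With split_components=True, a mixture model would keep a trailing component axis instead of summing it out.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def nb_log_pmf(x, r, p):
    """Negative binomial log-pmf with mean r * p / (1 - p)."""
    return (gammaln(x + r) - gammaln(r) - gammaln(x + 1.0)
            + r * jnp.log1p(-p) + x * jnp.log(p))


def log_likelihood(counts, r, p, return_by="cell"):
    """counts: (n_cells, n_genes); r: (n_samples, n_genes); p: (n_samples,)."""
    # Broadcast to (n_samples, n_cells, n_genes).
    lp = nb_log_pmf(counts[None, :, :], r[:, None, :], p[:, None, None])
    if return_by == "cell":
        return lp.sum(axis=-1)  # sum over genes -> (n_samples, n_cells)
    if return_by == "gene":
        return lp.sum(axis=1)   # sum over cells -> (n_samples, n_genes)
    raise ValueError("return_by must be 'cell' or 'gene'")


counts = jnp.array([[0.0, 3.0, 1.0], [2.0, 0.0, 5.0]])  # 2 cells, 3 genes
r = jnp.ones((4, 3))                                    # 4 posterior samples
p = jnp.full((4,), 0.5)
print(log_likelihood(counts, r, p, "cell").shape)  # (4, 2)
print(log_likelihood(counts, r, p, "gene").shape)  # (4, 3)
```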

log_likelihood_map(counts, batch_size=None, gene_batch_size=None, return_by='cell', cells_axis=0, split_components=False, weights=None, weight_type=None, use_mean=True, verbose=True, dtype=jnp.float32)[source]

Compute the log-likelihood of the data using MAP parameter estimates.

Parameters:
  • counts (jnp.ndarray) – Count data to evaluate likelihood on

  • batch_size (Optional[int], default=None) – Size of mini-batches used for likelihood computation

  • gene_batch_size (Optional[int], default=None) – Size of mini-batches used for likelihood computation by gene

  • return_by (str, default='cell') –

    Specifies how to return the log probabilities. Must be one of:
    • 'cell': returns log probabilities summed over genes

    • 'gene': returns log probabilities summed over cells

  • cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.

  • split_components (bool, default=False) – If True, returns log likelihoods for each mixture component separately. Only applicable for mixture models.

  • weights (Optional[jnp.ndarray], default=None) – Array used to weight the log likelihoods (for mixture models).

  • weight_type (Optional[str], default=None) –

    How to apply weights. Must be one of:
    • 'multiplicative': multiply log probabilities by weights

    • 'additive': add weights to log probabilities

  • use_mean (bool, default=True) – If True, replaces undefined MAP values (NaN) with posterior means

  • verbose (bool, default=True) – If True, prints a warning if NaNs were replaced with means

  • dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations

Returns:

Array of log likelihoods. Shape depends on model type, return_by and split_components parameters.

Return type:

jnp.ndarray
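Batching changes memory use, not the result: summing log-likelihood contributions chunk by chunk gives the same per-cell totals as a single full pass, up to floating-point rounding. The sketch below illustrates this for gene_batch_size using a Poisson likelihood with per-gene point-estimate rates as a simple stand-in for the model's actual likelihood. `cell_log_lik` is hypothetical, and how SCRIBE chunks internally is an assumption.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def cell_log_lik(counts, rate, gene_batch_size=None):
    """Per-cell Poisson log-likelihood, accumulated over chunks of genes.

    counts: (n_cells, n_genes); rate: (n_genes,) point estimates.
    """
    n_genes = counts.shape[1]
    step = n_genes if gene_batch_size is None else gene_batch_size
    total = jnp.zeros(counts.shape[0])
    for start in range(0, n_genes, step):
        x = counts[:, start:start + step]
        lam = rate[start:start + step]
        # Only an (n_cells, step) block of log-probabilities exists at a time.
        total = total + jnp.sum(x * jnp.log(lam) - lam - gammaln(x + 1.0), axis=1)
    return total


counts = jnp.array([[0.0, 1.0, 4.0, 2.0, 0.0], [3.0, 0.0, 1.0, 1.0, 2.0]])
rate = jnp.array([0.5, 1.0, 2.0, 1.5, 0.8])
print(cell_log_lik(counts, rate), cell_log_lik(counts, rate, gene_batch_size=2))
```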

mixture_component_entropy(counts, return_by='gene', batch_size=None, cells_axis=0, ignore_nans=False, temperature=None, dtype=jnp.float32)[source]

Compute the entropy of mixture component assignment probabilities.

This method calculates the Shannon entropy of the posterior component assignment probabilities for each observation (cell or gene), providing a quantitative measure of assignment uncertainty in mixture models.

The entropy quantifies how uncertain the model is about which component each observation belongs to:

  • Low entropy (≈0): High confidence in component assignment

  • High entropy (≈log(n_components)): High uncertainty in assignment

  • Maximum entropy: Uniform assignment probabilities across all components

The entropy is calculated as:

H = -∑(p_i * log(p_i))

where p_i are the posterior probabilities of assignment to component i.

Parameters:
  • counts (jnp.ndarray) – Input count data to evaluate component assignments for. Shape should be (n_cells, n_genes) if cells_axis=0, or (n_genes, n_cells) if cells_axis=1.

  • return_by (str, default='gene') –

    Specifies how to compute and return the entropy. Must be one of:
    • 'cell': Compute entropy of component assignments for each cell

    • 'gene': Compute entropy of component assignments for each gene

  • batch_size (Optional[int], default=None) – If provided, processes the data in batches of this size to reduce memory usage. Useful for large datasets.

  • cells_axis (int, default=0) –

    Specifies which axis in the input counts contains the cells:
    • 0: cells are rows (shape: n_cells × n_genes)

    • 1: cells are columns (shape: n_genes × n_cells)

  • ignore_nans (bool, default=False) – If True, excludes any samples containing NaN values from the entropy calculation.

  • temperature (Optional[float], default=None) – If provided, applies temperature scaling to the log-likelihoods before computing entropy. Temperature scaling modifies the sharpness of probability distributions by dividing log probabilities by a temperature parameter T.

  • dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations.

Returns:

Array of entropy values. Shape depends on return_by:
  • If return_by='cell': shape is (n_samples, n_cells)

  • If return_by='gene': shape is (n_samples, n_genes)

Higher values indicate more uncertainty in component assignments.

Return type:

jnp.ndarray

Raises:

ValueError – If the model is not a mixture model or if posterior samples haven’t been generated.

Notes

  • This method requires posterior samples to be available. Call get_posterior_samples() first if they haven’t been generated.

  • The entropy is computed using the full posterior predictive distribution, accounting for uncertainty in the model parameters.

  • For a mixture with K components, the maximum possible entropy is log(K).

  • Entropy values can be used to identify observations that are difficult to classify or that may belong to multiple components.
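Because the maximum entropy is log(K), dividing by log(K) gives a score between 0 and 1 that is comparable across models with different numbers of components. A sketch, assuming per-sample component log-likelihoods of shape (n_samples, n_cells, n_components) are at hand. Averaging the per-sample entropies and the 0.5 cutoff are illustrative choices, and the helper is hypothetical.

```python
import jax
import jax.numpy as jnp


def normalized_assignment_entropy(sample_log_liks):
    """Mean per-sample assignment entropy for each cell, divided by log(K).

    sample_log_liks: shape (n_samples, n_cells, n_components).
    Returns values in [0, 1], shape (n_cells,).
    """
    log_p = jax.nn.log_softmax(sample_log_liks, axis=-1)
    entropy = -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)  # (n_samples, n_cells)
    n_components = sample_log_liks.shape[-1]
    return entropy.mean(axis=0) / jnp.log(n_components)


# 10 posterior samples, 2 cells, 3 components: cell 0 ambiguous, cell 1 confident.
sample_log_liks = jnp.stack([
    jnp.zeros((10, 3)),
    jnp.tile(jnp.array([0.0, -50.0, -50.0]), (10, 1)),
], axis=1)
score = normalized_assignment_entropy(sample_log_liks)
ambiguous = score > 0.5  # threshold chosen arbitrarily for illustration
print(score, ambiguous)
```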

n_components: int | None = None
n_obs: int | None = None
n_vars: int | None = None
normalize_counts(rng_key=random.PRNGKey(42), n_samples_dirichlet=1, fit_distribution=False, store_samples=True, sample_axis=0, return_concentrations=False, backend='numpyro', verbose=True)[source]

Normalize counts using posterior samples of the r parameter.

This method takes posterior samples of the dispersion parameter (r) and uses them as concentration parameters for Dirichlet distributions to generate normalized expression profiles. For mixture models, normalization is performed per component, resulting in an extra dimension in the output.

Following the Dirichlet-multinomial derivation of the model, the r parameters act as the concentration parameters of a Dirichlet distribution from which normalized expression profiles can be drawn.

The method generates Dirichlet samples using all posterior samples of r, then fits a single Dirichlet distribution to all these samples (or one per component for mixture models).
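A sketch of the sampling step for a non-mixture model, assuming the posterior samples of r are stored as an array of shape (n_posterior_samples, n_genes). Each row is used directly as the concentration vector of a Dirichlet, so every draw is a vector of gene proportions that sums to 1. The output shapes follow those listed under Returns. The helper is hypothetical, and the fitting step (fit_dirichlet_minka) is not reproduced.

```python
import jax
import jax.numpy as jnp


def dirichlet_normalize(key, r_samples, n_samples_dirichlet=1):
    """Draw normalized expression profiles using r as Dirichlet concentrations.

    r_samples: shape (n_posterior_samples, n_genes), all entries positive.
    Returns shape (n_posterior_samples, n_genes) if n_samples_dirichlet == 1,
    else (n_posterior_samples, n_genes, n_samples_dirichlet).
    """
    if n_samples_dirichlet == 1:
        # One Dirichlet draw per posterior sample of r.
        return jax.random.dirichlet(key, r_samples)
    n_post = r_samples.shape[0]
    draws = jax.random.dirichlet(key, r_samples, shape=(n_samples_dirichlet, n_post))
    # (n_samples_dirichlet, n_post, n_genes) -> (n_post, n_genes, n_samples_dirichlet)
    return jnp.moveaxis(draws, 0, -1)


r_samples = jnp.array([[1.0, 2.0, 7.0], [1.5, 1.5, 7.0]])  # 2 posterior samples, 3 genes
profiles = dirichlet_normalize(jax.random.PRNGKey(42), r_samples, n_samples_dirichlet=20000)
print(profiles.shape)  # (2, 3, 20000)
```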

Parameters:
  • rng_key (random.PRNGKey, default=random.PRNGKey(42)) – JAX random number generator key

  • n_samples_dirichlet (int, default=1) – Number of samples to draw from each Dirichlet distribution

  • fit_distribution (bool, default=False) – If True, fits a Dirichlet distribution to the generated samples using fit_dirichlet_minka from stats.py

  • store_samples (bool, default=True) – If True, includes the raw Dirichlet samples in the output

  • sample_axis (int, default=0) – Axis containing samples in the Dirichlet fitting (passed to fit_dirichlet_minka)

  • return_concentrations (bool, default=False) – If True, returns the original r parameter samples used as concentrations

  • backend (str, default="numpyro") –

    Statistical package to use for distributions when fit_distribution=True. Must be one of:
    • 'numpyro': returns numpyro.distributions.Dirichlet objects

    • 'scipy': returns scipy.stats distributions via numpyro_to_scipy conversion

  • verbose (bool, default=True) – If True, prints progress messages

Returns:

Dictionary containing normalized expression profiles. Keys depend on input arguments:

  • 'samples': Raw Dirichlet samples (if store_samples=True)

  • 'concentrations': Fitted concentration parameters (if fit_distribution=True)

  • 'mean_probabilities': Mean probabilities from fitted distribution (if fit_distribution=True)

  • 'distributions': Dirichlet distribution objects (if fit_distribution=True)

  • 'original_concentrations': Original r parameter samples (if return_concentrations=True)

For non-mixture models:
  • samples: shape (n_posterior_samples, n_genes, n_samples_dirichlet) or (n_posterior_samples, n_genes) if n_samples_dirichlet=1

  • concentrations: shape (n_genes,) - single fitted distribution

  • mean_probabilities: shape (n_genes,) - single fitted distribution

  • distributions: single Dirichlet distribution object

For mixture models:
  • samples: shape (n_posterior_samples, n_components, n_genes, n_samples_dirichlet) or (n_posterior_samples, n_components, n_genes) if n_samples_dirichlet=1

  • concentrations: shape (n_components, n_genes) - one fitted distribution per component

  • mean_probabilities: shape (n_components, n_genes) - one fitted distribution per component

  • distributions: list of n_components Dirichlet distribution objects

Return type:

Dict[str, Union[jnp.ndarray, object]]

Raises:

ValueError – If posterior samples have not been generated yet, or if the 'r' parameter is not found in the posterior samples

Examples

>>> # For a non-mixture model (posterior samples must exist first)
>>> _ = results.get_posterior_samples(n_samples=100)
>>> normalized = results.normalize_counts(
...     n_samples_dirichlet=100,
...     fit_distribution=True
... )
>>> print(normalized['mean_probabilities'].shape)  # (n_genes,)
>>> print(type(normalized['distributions']))  # Single Dirichlet distribution

>>> # For a mixture model (posterior samples must exist first)
>>> _ = results.get_posterior_samples(n_samples=100)
>>> normalized = results.normalize_counts(
...     n_samples_dirichlet=100,
...     fit_distribution=True
... )
>>> print(normalized['mean_probabilities'].shape)  # (n_components, n_genes)
>>> print(len(normalized['distributions']))  # n_components

obs: DataFrame | None = None
posterior_samples: Dict | None = None
predictive_samples: Dict | None = None
uns: Dict | None = None
var: DataFrame | None = None
params: Dict
loss_history: Array
n_cells: int
n_genes: int
model_type: str
model_config: ModelConfig
prior_params: Dict[str, Any]