svi
Stochastic Variational Inference (SVI) module for single-cell RNA sequencing data analysis.
This module implements stochastic variational inference for SCRIBE models using NumPyro's SVI.
- class scribe.svi.SVIInferenceEngine
Bases: object
Handles SVI inference execution.
- static run_inference(model_config, count_data, n_cells, n_genes, optimizer=numpyro.optim.Adam(step_size=0.001), loss=numpyro.infer.TraceMeanField_ELBO(), n_steps=100000, batch_size=None, seed=42, stable_update=True)
Execute SVI inference.
- Parameters:
model_config (ModelConfig) – Model configuration object
count_data (jnp.ndarray) – Processed count data (cells as rows)
n_cells (int) – Number of cells
n_genes (int) – Number of genes
optimizer (numpyro.optim.optimizers, default=Adam(step_size=0.001)) – Optimizer for variational inference
loss (numpyro.infer.elbo, default=TraceMeanField_ELBO()) – Loss function for variational inference
n_steps (int, default=100_000) – Number of optimization steps
batch_size (Optional[int], default=None) – Mini-batch size. If None, uses full dataset.
seed (int, default=42) – Random seed for reproducibility
stable_update (bool, default=True) – Whether to use numerically stable parameter updates
- Returns:
Results from SVI run containing optimized parameters and loss history
- Return type:
numpyro.infer.svi.SVIRunResult
- class scribe.svi.SVIResultsFactory
Bases: object
Factory for creating SVI results objects.
- static create_results(svi_results, model_config, adata, count_data, n_cells, n_genes, model_type, n_components, prior_params)
Package SVI results into ScribeSVIResults object.
- Parameters:
svi_results (Any) – Raw SVI results from numpyro
model_config (ModelConfig) – Model configuration object
adata (Optional[AnnData]) – Original AnnData object (if provided)
count_data (jnp.ndarray) – Processed count data
n_cells (int) – Number of cells
n_genes (int) – Number of genes
model_type (str) – Type of model
n_components (Optional[int]) – Number of mixture components
prior_params (Dict[str, Any]) – Dictionary of prior parameters
- Returns:
Packaged results object
- Return type:
ScribeSVIResults
- class scribe.svi.ScribeSVIResults(params, loss_history, n_cells, n_genes, model_type, model_config, prior_params, obs=None, var=None, uns=None, n_obs=None, n_vars=None, posterior_samples=None, predictive_samples=None, n_components=None)
Bases: object
Base class for SCRIBE variational inference results.
This class stores the results from SCRIBE’s variational inference procedure, including model parameters, loss history, dataset dimensions, and model configuration. It can optionally store metadata from an AnnData object and posterior/predictive samples.
- Parameters:
params (Dict)
loss_history (Array)
n_cells (int)
n_genes (int)
model_type (str)
model_config (ModelConfig)
prior_params (Dict)
obs (DataFrame | None)
var (DataFrame | None)
uns (Dict | None)
n_obs (int | None)
n_vars (int | None)
posterior_samples (Dict | None)
predictive_samples (Dict | None)
n_components (int | None)
- params
Dictionary of inferred model parameters from SCRIBE
- Type:
Dict
- loss_history
Array of loss values (the negative ELBO) recorded during training
- Type:
jnp.ndarray
- model_config
Configuration object specifying model architecture and priors
- Type:
ModelConfig
- obs
Cell-level metadata from adata.obs, if provided
- Type:
Optional[pd.DataFrame]
- var
Gene-level metadata from adata.var, if provided
- Type:
Optional[pd.DataFrame]
- uns
Unstructured metadata from adata.uns, if provided
- Type:
Optional[Dict]
- posterior_samples
Samples of parameters from the posterior distribution, if generated
- Type:
Optional[Dict]
- predictive_samples
Predictive samples generated from the model, if generated
- Type:
Optional[Dict]
- assignment_entropy_map(counts, return_by='gene', batch_size=None, cells_axis=0, temperature=None, use_mean=True, verbose=True, dtype=jnp.float32)
Compute the entropy of component assignments for each cell or gene, evaluated at the MAP parameter estimates.
This method calculates the entropy of the posterior component assignment probabilities for each cell or gene, providing a measure of assignment uncertainty. Higher entropy values indicate more uncertainty in the component assignments, while lower values indicate more confident assignments.
- The entropy is calculated as:
H = -∑(p_i * log(p_i))
where p_i are the normalized probabilities for each component.
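As a hedged illustration of this formula, the hypothetical helper below (not SCRIBE's implementation) normalizes per-component log-likelihoods with a log-softmax and then evaluates H. It assumes the log-likelihoods are already available as an (n_obs, n_components) array.

```python
import numpy as np


def assignment_entropy(component_log_liks):
    """Entropy of component assignments (hypothetical helper).

    component_log_liks: array of shape (n_obs, n_components) holding the
    log-likelihood of each observation under each mixture component.
    Returns an array of shape (n_obs,).
    """
    ll = np.asarray(component_log_liks, dtype=float)
    # Normalize in log space (log-softmax) so p_i sums to 1 without overflow
    log_p = ll - np.logaddexp.reduce(ll, axis=-1, keepdims=True)
    p = np.exp(log_p)
    # H = -sum_i p_i * log(p_i), with the convention 0 * log(0) = 0
    terms = np.where(p > 0, p * log_p, 0.0)
    return -terms.sum(axis=-1)


ll = np.array([
    [0.0, 0.0, 0.0, 0.0],              # uniform: maximal uncertainty
    [0.0, -1000.0, -1000.0, -1000.0],  # one component dominates
])
print(assignment_entropy(ll))  # approximately [log(4), 0]
```

Working in log space matters here: exponentiating raw log-likelihoods of realistic magnitude would underflow to zero before normalization.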
- Parameters:
counts (jnp.ndarray) – The count matrix with shape (n_cells, n_genes).
return_by (str, default='gene') – Whether to return the entropy by cell or gene.
batch_size (Optional[int], default=None) – Size of mini-batches for likelihood computation
cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.
temperature (Optional[float], default=None) – If provided, applies temperature scaling to the log-likelihoods before computing entropy.
use_mean (bool, default=True) – If True, replaces undefined MAP values (NaN) with posterior means
verbose (bool, default=True) – If True, prints a warning if NaNs were replaced with means
dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations
- Returns:
The component-assignment entropy evaluated at the MAP. Shape: (n_cells,) if return_by='cell', or (n_genes,) if return_by='gene'.
- Return type:
jnp.ndarray
- Raises:
ValueError – If the model is not a mixture model, or if posterior samples have not been generated yet
- cell_type_probabilities(counts, batch_size=None, cells_axis=0, ignore_nans=False, dtype=jnp.float32, fit_distribution=True, temperature=None, weights=None, weight_type=None, verbose=True)
Compute probabilistic cell type assignments and fit Dirichlet distributions to characterize assignment uncertainty.
- For each cell, this method:
1. Computes component-specific log-likelihoods using posterior samples
2. Converts these to probability distributions over cell types
3. Fits a Dirichlet distribution to characterize the uncertainty in these assignments
- Parameters:
counts (jnp.ndarray) – Count data to evaluate assignments for
batch_size (Optional[int], default=None) – Size of mini-batches for likelihood computation
cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.
ignore_nans (bool, default=False) – If True, removes any samples that contain NaNs.
dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations
fit_distribution (bool, default=True) – If True, fits a Dirichlet distribution to the assignment probabilities
temperature (Optional[float], default=None) – If provided, apply temperature scaling to log probabilities
weights (Optional[jnp.ndarray], default=None) – Array used to weight genes when computing log likelihoods
weight_type (Optional[str], default=None) –
- How to apply weights. Must be one of:
'multiplicative': multiply log probabilities by weights
'additive': add weights to log probabilities
verbose (bool, default=True) – If True, prints progress messages
- Returns:
- Dictionary containing:
'concentration': Dirichlet concentration parameters for each cell. Shape: (n_cells, n_components). Only returned if fit_distribution is True.
'mean_probabilities': Mean assignment probabilities for each cell. Shape: (n_cells, n_components). Only returned if fit_distribution is True.
'sample_probabilities': Assignment probabilities for each posterior sample. Shape: (n_samples, n_cells, n_components)
- Return type:
Dict[str, jnp.ndarray]
- Raises:
ValueError – If the model is not a mixture model, or if posterior samples have not been generated yet
Note
In practice, the log-likelihood differences between cell types are usually very large, so the assignment probabilities are almost always numerically 0 or 1. The resulting distributions therefore carry little information about assignment uncertainty; the method is included mainly for completeness.
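A small NumPy sketch shows why the probabilities saturate, and what the temperature argument can do about it. It assumes temperature scaling divides the log-likelihoods by T before normalizing, as described under mixture_component_entropy; the helper assignment_probabilities and the numbers are hypothetical.

```python
import numpy as np


def assignment_probabilities(component_log_liks, temperature=None):
    """Softmax over the component axis (hypothetical helper).

    component_log_liks: (..., n_components), e.g. (n_samples, n_cells, n_components).
    If a temperature T is given, log-likelihoods are divided by T first.
    """
    ll = np.asarray(component_log_liks, dtype=float)
    if temperature is not None:
        ll = ll / temperature
    # Subtract the per-row maximum before exponentiating for numerical stability
    ll = ll - ll.max(axis=-1, keepdims=True)
    p = np.exp(ll)
    return p / p.sum(axis=-1, keepdims=True)


# One posterior sample, one cell, two components whose total log-likelihoods
# (summed over many genes) differ by 200 nats
ll = np.array([[[-5000.0, -5200.0]]])

print(assignment_probabilities(ll))                     # approximately [[[1.0, 1.4e-87]]]
print(assignment_probabilities(ll, temperature=100.0))  # approximately [[[0.881, 0.119]]]
```

A difference of 200 nats gives a probability ratio of exp(200), so the second component is numerically irrelevant; dividing by T = 100 reduces the gap to 2 nats. Averaging such per-sample probabilities over the first axis (p.mean(axis=0)) gives a per-cell mean over posterior samples.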
- cell_type_probabilities_map(counts, batch_size=None, cells_axis=0, dtype=jnp.float32, temperature=None, weights=None, weight_type=None, use_mean=False, verbose=True)
Compute probabilistic cell type assignments using MAP estimates of parameters.
- For each cell, this method:
1. Computes component-specific log-likelihoods using MAP parameter estimates
2. Converts these to probability distributions over cell types
- Parameters:
counts (jnp.ndarray) – Count data to evaluate assignments for
batch_size (Optional[int], default=None) – Size of mini-batches for likelihood computation
cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.
dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations
temperature (Optional[float], default=None) – If provided, apply temperature scaling to log probabilities
weights (Optional[jnp.ndarray], default=None) – Array used to weight genes when computing log likelihoods
weight_type (Optional[str], default=None) –
- How to apply weights. Must be one of:
'multiplicative': multiply log probabilities by weights
'additive': add weights to log probabilities
use_mean (bool, default=False) – If True, replaces undefined MAP values (NaN) with posterior means
verbose (bool, default=True) – If True, prints progress messages
- Returns:
- Dictionary containing:
'probabilities': Assignment probabilities for each cell. Shape: (n_cells, n_components)
- Return type:
Dict[str, jnp.ndarray]
- Raises:
ValueError – If the model is not a mixture model
- classmethod from_anndata(adata, params, loss_history, model_config, **kwargs)
Create a ScribeSVIResults object from an AnnData object.
- get_component(component_index)
Create a view of the results selecting a specific mixture component.
This method returns a new ScribeSVIResults object that contains parameter values for the specified component, allowing for further gene-based indexing. Only applicable to mixture models.
- Parameters:
component_index (int) – Index of the component to select
- Returns:
A new ScribeSVIResults object with parameters for the selected component
- Return type:
ScribeSVIResults
- Raises:
ValueError – If the model is not a mixture model
- get_distributions(backend='numpyro', split=False)
Get the variational distributions for all parameters.
This method delegates to the model-specific get_posterior_distributions function associated with the parameterization.
- Parameters:
backend (str, default="numpyro") – Statistical package to use for distributions.
- Must be one of:
"scipy": Returns scipy.stats distributions
"numpyro": Returns numpyro.distributions
split (bool, default=False) – If True, returns lists of individual distributions for multidimensional parameters instead of batch distributions.
- Returns:
Dictionary mapping parameter names to their distributions.
- Return type:
Dict[str, Any]
- Raises:
ValueError – If backend is not supported.
- get_map(use_mean=False, canonical=True, verbose=True)
Get the maximum a posteriori (MAP) estimates from the variational posterior.
- Parameters:
use_mean (bool, default=False) – If True, replaces undefined MAP values (NaN) with posterior means
canonical (bool, default=True) – If True, includes canonical parameters (p, r) computed from other parameters for linked, odds_ratio, and unconstrained parameterizations
verbose (bool, default=True) – If True, prints a warning if NaNs were replaced with means
- Returns:
Dictionary of MAP estimates for each parameter
- Return type:
Dict[str, jnp.ndarray]
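One way a MAP value can be undefined: suppose, as an assumption for illustration, that a parameter such as p has a Beta(α, β) variational distribution (the actual families depend on the parameterization). Its interior mode (α − 1)/(α + β − 2) exists only when α > 1 and β > 1. The hypothetical helper below returns NaN otherwise and, mirroring use_mean=True, substitutes the mean α/(α + β).

```python
import numpy as np


def beta_map(alpha, beta, use_mean=False, verbose=True):
    """MAP (mode) of Beta(alpha, beta), elementwise (hypothetical helper).

    The interior mode (alpha - 1) / (alpha + beta - 2) exists only when
    alpha > 1 and beta > 1; elsewhere the result is NaN unless use_mean=True,
    in which case the mean alpha / (alpha + beta) is substituted.
    """
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    defined = (alpha > 1) & (beta > 1)
    # Use a dummy denominator where the mode is undefined to avoid 0/0 warnings
    denom = np.where(defined, alpha + beta - 2, 1.0)
    mode = np.where(defined, (alpha - 1) / denom, np.nan)
    if use_mean and np.isnan(mode).any():
        mode = np.where(np.isnan(mode), alpha / (alpha + beta), mode)
        if verbose:
            print("Replaced undefined MAP values with posterior means")
    return mode


print(beta_map([2.0, 0.5], [2.0, 3.0]))  # [0.5 nan]
print(beta_map([2.0, 0.5], [2.0, 3.0], use_mean=True, verbose=False))  # approximately [0.5, 0.1429]
```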
- get_posterior_samples(rng_key=random.PRNGKey(42), n_samples=100, store_samples=True)
Sample parameters from the variational posterior distribution.
- get_ppc_samples(rng_key=random.PRNGKey(42), n_samples=100, batch_size=None, store_samples=True)
Generate posterior predictive check samples.
- get_predictive_samples(rng_key=random.PRNGKey(42), batch_size=None, store_samples=True)
Generate predictive samples using posterior parameter samples.
- log_likelihood(counts, batch_size=None, return_by='cell', cells_axis=0, ignore_nans=False, split_components=False, weights=None, weight_type=None, dtype=jnp.float32)
Compute log likelihood of data under posterior samples.
- Parameters:
counts (jnp.ndarray) – Count data to evaluate likelihood on
batch_size (Optional[int], default=None) – Size of mini-batches used for likelihood computation
return_by (str, default='cell') –
- Specifies how to return the log probabilities. Must be one of:
'cell': returns log probabilities summed over genes
'gene': returns log probabilities summed over cells
cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.
ignore_nans (bool, default=False) – If True, removes any samples that contain NaNs.
split_components (bool, default=False) – If True, returns log likelihoods for each mixture component separately. Only applicable for mixture models.
weights (Optional[jnp.ndarray], default=None) – Array used to weight the log likelihoods (for mixture models).
weight_type (Optional[str], default=None) –
- How to apply weights. Must be one of:
'multiplicative': multiply log probabilities by weights
'additive': add weights to log probabilities
dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations
- Returns:
Array of log likelihoods. Shape depends on model type, return_by and split_components parameters.
- For standard models:
'cell': shape (n_samples, n_cells)
'gene': shape (n_samples, n_genes)
- For mixture models with split_components=False:
'cell': shape (n_samples, n_cells)
'gene': shape (n_samples, n_genes)
- For mixture models with split_components=True:
'cell': shape (n_samples, n_cells, n_components)
'gene': shape (n_samples, n_genes, n_components)
- Return type:
jnp.ndarray
- Raises:
ValueError – If posterior samples have not been generated yet
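To make the return_by shapes and the two weight_type options concrete, here is a NumPy sketch on simulated per-gene log-likelihoods. It assumes the weights have one entry per gene (as described for cell_type_probabilities) and are applied before summing over genes; per_gene_ll and weights are hypothetical arrays, not SCRIBE output.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_cells, n_genes = 4, 3, 5

# Simulated per-gene log-likelihoods, shape (n_samples, n_cells, n_genes)
per_gene_ll = -rng.gamma(2.0, 1.0, size=(n_samples, n_cells, n_genes))

# return_by='cell': sum over genes; return_by='gene': sum over cells
by_cell = per_gene_ll.sum(axis=2)  # (n_samples, n_cells)
by_gene = per_gene_ll.sum(axis=1)  # (n_samples, n_genes)

# Hypothetical gene weights, broadcast along the gene axis before summing
weights = np.array([1.0, 0.5, 0.0, 2.0, 1.0])
weighted_mult = (per_gene_ll * weights).sum(axis=2)  # weight_type='multiplicative'
weighted_add = (per_gene_ll + weights).sum(axis=2)   # weight_type='additive'

print(by_cell.shape, by_gene.shape)  # (4, 3) (4, 5)
```

With weights that depend only on the gene, the additive option shifts every cell's total by the same constant, weights.sum(), whereas a multiplicative weight of 0 removes that gene's contribution entirely.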
- log_likelihood_map(counts, batch_size=None, gene_batch_size=None, return_by='cell', cells_axis=0, split_components=False, weights=None, weight_type=None, use_mean=True, verbose=True, dtype=jnp.float32)
Compute log likelihood of data using MAP parameter estimates.
- Parameters:
counts (jnp.ndarray) – Count data to evaluate likelihood on
batch_size (Optional[int], default=None) – Size of mini-batches used for likelihood computation
gene_batch_size (Optional[int], default=None) – Size of mini-batches used for likelihood computation by gene
return_by (str, default='cell') –
- Specifies how to return the log probabilities. Must be one of:
'cell': returns log probabilities summed over genes
'gene': returns log probabilities summed over cells
cells_axis (int, default=0) – Axis along which cells are arranged. 0 means cells are rows.
split_components (bool, default=False) – If True, returns log likelihoods for each mixture component separately. Only applicable for mixture models.
weights (Optional[jnp.ndarray], default=None) – Array used to weight the log likelihoods (for mixture models).
weight_type (Optional[str], default=None) –
- How to apply weights. Must be one of:
'multiplicative': multiply log probabilities by weights
'additive': add weights to log probabilities
use_mean (bool, default=True) – If True, replaces undefined MAP values (NaN) with posterior means
verbose (bool, default=True) – If True, prints a warning if NaNs were replaced with means
dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations
- Returns:
Array of log likelihoods. Shape depends on model type, return_by and split_components parameters.
- Return type:
jnp.ndarray
- mixture_component_entropy(counts, return_by='gene', batch_size=None, cells_axis=0, ignore_nans=False, temperature=None, dtype=jnp.float32)
Compute the entropy of mixture component assignment probabilities.
This method calculates the Shannon entropy of the posterior component assignment probabilities for each observation (cell or gene), providing a quantitative measure of assignment uncertainty in mixture models.
The entropy quantifies how uncertain the model is about which component each observation belongs to:
Low entropy (≈0): High confidence in component assignment
High entropy (≈log(n_components)): High uncertainty in assignment
Maximum entropy: Uniform assignment probabilities across all components
- The entropy is calculated as:
H = -∑(p_i * log(p_i))
where p_i are the posterior probabilities of assignment to component i.
- Parameters:
counts (jnp.ndarray) – Input count data to evaluate component assignments for. Shape should be (n_cells, n_genes) if cells_axis=0, or (n_genes, n_cells) if cells_axis=1.
return_by (str, default='gene') –
- Specifies how to compute and return the entropy. Must be one of:
'cell': Compute entropy of component assignments for each cell
'gene': Compute entropy of component assignments for each gene
batch_size (Optional[int], default=None) – If provided, processes the data in batches of this size to reduce memory usage. Useful for large datasets.
cells_axis (int, default=0) –
- Specifies which axis in the input counts contains the cells:
0: cells are rows (shape: n_cells × n_genes)
1: cells are columns (shape: n_genes × n_cells)
ignore_nans (bool, default=False) – If True, excludes any samples containing NaN values from the entropy calculation.
temperature (Optional[float], default=None) – If provided, applies temperature scaling to the log-likelihoods before computing entropy. Temperature scaling modifies the sharpness of probability distributions by dividing log probabilities by a temperature parameter T.
dtype (jnp.dtype, default=jnp.float32) – Data type for numerical precision in computations.
- Returns:
- Array of entropy values. Shape depends on return_by:
If return_by='cell': shape is (n_samples, n_cells)
If return_by='gene': shape is (n_samples, n_genes)
Higher values indicate more uncertainty in component assignments.
- Return type:
jnp.ndarray
- Raises:
ValueError – If the model is not a mixture model or if posterior samples haven’t been generated.
Notes
This method requires posterior samples to be available. Call get_posterior_samples() first if they haven’t been generated.
The entropy is computed separately for each posterior parameter sample, so the spread across samples reflects uncertainty in the model parameters.
For a mixture with K components, the maximum possible entropy is log(K).
Entropy values can be used to identify observations that are difficult to classify or that may belong to multiple components.
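The log(K) bound follows from Jensen's inequality for the concave logarithm. Writing S = {i : p_i > 0} for the components with nonzero probability (a standard derivation, not specific to SCRIBE):

```latex
H(p) = \sum_{i \in S} p_i \log\frac{1}{p_i}
     \le \log\Big(\sum_{i \in S} p_i \cdot \frac{1}{p_i}\Big)
     = \log |S| \le \log K
```

Equality holds throughout exactly when p_i = 1/K for every component, i.e. when the assignment probabilities are uniform.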
- normalize_counts(rng_key=random.PRNGKey(42), n_samples_dirichlet=1, fit_distribution=False, store_samples=True, sample_axis=0, return_concentrations=False, backend='numpyro', verbose=True)
Normalize counts using posterior samples of the r parameter.
This method takes posterior samples of the dispersion parameter (r) and uses them as concentration parameters for Dirichlet distributions to generate normalized expression profiles. For mixture models, normalization is performed per component, resulting in an extra dimension in the output.
Following the Dirichlet-multinomial model derivation, the r parameters act as the concentration parameters of a Dirichlet distribution whose samples are normalized expression profiles.
The method generates Dirichlet samples using all posterior samples of r, then fits a single Dirichlet distribution to all these samples (or one per component for mixture models).
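A NumPy sketch of this idea, under stated assumptions: r_samples is simulated rather than drawn from a SCRIBE posterior, one Dirichlet draw is taken per posterior sample (as with n_samples_dirichlet=1), and the fit uses a simple method-of-moments estimate in place of fit_dirichlet_minka (Minka's fixed-point method), which is what the method itself uses.

```python
import numpy as np

rng = np.random.default_rng(42)
n_posterior_samples, n_genes = 200, 4

# Hypothetical posterior samples of r, one row per sample (all entries > 0)
scales = np.array([0.02, 0.04, 0.1, 0.2])
r_samples = rng.gamma(shape=50.0, scale=scales, size=(n_posterior_samples, n_genes))

# One Dirichlet draw per posterior sample of r;
# each row is a normalized expression profile summing to 1
profiles = np.stack([rng.dirichlet(r) for r in r_samples])

# Method-of-moments Dirichlet fit (simpler stand-in for Minka's fixed point):
# for Dir(a), E[x_i] = a_i / a0 and Var[x_i] = m_i (1 - m_i) / (a0 + 1)
mean = profiles.mean(axis=0)
a0 = np.mean(mean * (1 - mean) / profiles.var(axis=0) - 1)
concentrations = a0 * mean

print(profiles.shape)  # (200, 4)
print(concentrations / concentrations.sum())  # equals `mean` up to rounding
```

The mean of a fitted Dirichlet is its concentration vector divided by its sum, which is how a mean_probabilities output relates to concentrations.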
- Parameters:
rng_key (random.PRNGKey, default=random.PRNGKey(42)) – JAX random number generator key
n_samples_dirichlet (int, default=1) – Number of samples to draw from each Dirichlet distribution
fit_distribution (bool, default=False) – If True, fits a Dirichlet distribution to the generated samples using fit_dirichlet_minka from stats.py
store_samples (bool, default=True) – If True, includes the raw Dirichlet samples in the output
sample_axis (int, default=0) – Axis containing samples in the Dirichlet fitting (passed to fit_dirichlet_minka)
return_concentrations (bool, default=False) – If True, returns the original r parameter samples used as concentrations
backend (str, default="numpyro") – Statistical package to use for distributions when fit_distribution=True.
- Must be one of:
"numpyro": Returns numpyro.distributions.Dirichlet objects
"scipy": Returns scipy.stats distributions via numpyro_to_scipy conversion
verbose (bool, default=True) – If True, prints progress messages
- Returns:
Dictionary containing normalized expression profiles. Keys depend on input arguments:
'samples': Raw Dirichlet samples (if store_samples=True)
'concentrations': Fitted concentration parameters (if fit_distribution=True)
'mean_probabilities': Mean probabilities from fitted distribution (if fit_distribution=True)
'distributions': Dirichlet distribution objects (if fit_distribution=True)
'original_concentrations': Original r parameter samples (if return_concentrations=True)
- For non-mixture models:
samples: shape (n_posterior_samples, n_genes, n_samples_dirichlet) or (n_posterior_samples, n_genes) if n_samples_dirichlet=1
concentrations: shape (n_genes,) - single fitted distribution
mean_probabilities: shape (n_genes,) - single fitted distribution
distributions: single Dirichlet distribution object
- For mixture models:
samples: shape (n_posterior_samples, n_components, n_genes, n_samples_dirichlet) or (n_posterior_samples, n_components, n_genes) if n_samples_dirichlet=1
concentrations: shape (n_components, n_genes) - one fitted distribution per component
mean_probabilities: shape (n_components, n_genes) - one fitted distribution per component
distributions: list of n_components Dirichlet distribution objects
- Return type:
Dict[str, Any]
- Raises:
ValueError – If posterior samples have not been generated yet, or if ‘r’ parameter is not found in posterior samples
Examples
>>> # For a non-mixture model
>>> normalized = results.normalize_counts(
...     n_samples_dirichlet=100,
...     fit_distribution=True
... )
>>> print(normalized['mean_probabilities'].shape)  # (n_genes,)
>>> print(type(normalized['distributions']))  # Single Dirichlet distribution
>>> # For a mixture model
>>> normalized = results.normalize_counts(
...     n_samples_dirichlet=100,
...     fit_distribution=True
... )
>>> print(normalized['mean_probabilities'].shape)  # (n_components, n_genes)
>>> print(len(normalized['distributions']))  # n_components
- model_config: ModelConfig