"""
Results classes for SCRIBE inference.
"""
from typing import Dict, Optional, Union, Callable, Tuple, Any
from dataclasses import dataclass, replace
import warnings
import jax.numpy as jnp
import jax.scipy as jsp
from jax.nn import sigmoid, softmax
import pandas as pd
import numpyro.distributions as dist
from jax import random, jit, vmap
import numpy as np
import scipy.stats as stats
from ..sampling import (
sample_variational_posterior,
generate_predictive_samples,
)
from ..stats import fit_dirichlet_minka
from ..models.model_config import ModelConfig
# Import multipledispatch functions from stats
from ..stats import hellinger, jensen_shannon
from numpyro.distributions.kl import kl_divergence
from ..utils import numpyro_to_scipy
from ..core.normalization import normalize_counts_from_posterior
try:
from anndata import AnnData
except ImportError:
AnnData = None
# ------------------------------------------------------------------------------
# Base class for inference results
# ------------------------------------------------------------------------------
@dataclass
class ScribeSVIResults:
"""
Base class for SCRIBE variational inference results.
This class stores the results from SCRIBE's variational inference procedure,
including model parameters, loss history, dataset dimensions, and model
configuration. It can optionally store metadata from an AnnData object and
posterior/predictive samples.
Attributes
----------
params : Dict
Dictionary of inferred model parameters from SCRIBE
loss_history : jnp.ndarray
Array containing the ELBO loss values during training
n_cells : int
Number of cells in the dataset
n_genes : int
Number of genes in the dataset
model_type : str
Type of model used for inference
model_config : ModelConfig
Configuration object specifying model architecture and priors
prior_params : Dict[str, Any]
Dictionary of prior parameter values used during inference
obs : Optional[pd.DataFrame]
Cell-level metadata from adata.obs, if provided
var : Optional[pd.DataFrame]
Gene-level metadata from adata.var, if provided
uns : Optional[Dict]
Unstructured metadata from adata.uns, if provided
n_obs : Optional[int]
Number of observations (cells), if provided
n_vars : Optional[int]
Number of variables (genes), if provided
posterior_samples : Optional[Dict]
Samples of parameters from the posterior distribution, if generated
predictive_samples : Optional[Dict]
Predictive samples generated from the model, if generated
n_components : Optional[int]
Number of mixture components, if using a mixture model
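Examples
--------
A minimal access pattern, shown as a sketch (assuming `results` is the
object returned by SCRIBE's SVI inference; attribute and method names
follow the descriptions above):

>>> results.loss_history[-1]        # final ELBO value
>>> map_est = results.get_map()     # MAP estimates of model parameters
>>> subset = results[:100]          # view restricted to the first 100 genes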
"""
# Core inference results
params: Dict
loss_history: jnp.ndarray
n_cells: int
n_genes: int
model_type: str
model_config: ModelConfig
prior_params: Dict[str, Any]
# Standard metadata from AnnData object
obs: Optional[pd.DataFrame] = None
var: Optional[pd.DataFrame] = None
uns: Optional[Dict] = None
n_obs: Optional[int] = None
n_vars: Optional[int] = None
# Optional results
posterior_samples: Optional[Dict] = None
predictive_samples: Optional[Dict] = None
n_components: Optional[int] = None
# --------------------------------------------------------------------------
def __post_init__(self):
"""Validate model configuration and parameters."""
# Set n_components from model_config if not explicitly provided
if (
self.n_components is None
and self.model_config.n_components is not None
):
self.n_components = self.model_config.n_components
self._validate_model_config()
# --------------------------------------------------------------------------
def _validate_model_config(self):
"""Validate model configuration matches model type."""
# Validate base model
if self.model_config.base_model != self.model_type:
raise ValueError(
f"Model type '{self.model_type}' does not match config "
f"base model '{self.model_config.base_model}'"
)
# Validate n_components consistency
if self.n_components is not None:
if not self.model_type.endswith("_mix"):
raise ValueError(
f"Model type '{self.model_type}' is not a mixture model "
f"but n_components={self.n_components} was specified"
)
if self.model_config.n_components != self.n_components:
raise ValueError(
f"n_components mismatch: {self.n_components} vs "
f"{self.model_config.n_components} in model_config"
)
# Validate required distributions based on model type and unconstrained
# flag
unconstrained = getattr(self.model_config, "unconstrained", False)
# ZINB models require gate priors
if "zinb" in self.model_type:
if unconstrained:
# Unconstrained uses gate_unconstrained_prior
if self.model_config.gate_unconstrained_prior is None:
raise ValueError(
"ZINB models with unconstrained=True require "
"gate_unconstrained_prior"
)
else:
# Constrained uses gate_param_prior
if self.model_config.gate_param_prior is None:
raise ValueError("ZINB models require gate_param_prior")
else:
# Non-ZINB models should not have gate priors
if unconstrained:
if self.model_config.gate_unconstrained_prior is not None:
raise ValueError(
"Non-ZINB models should not have "
"gate_unconstrained_prior"
)
else:
if self.model_config.gate_param_prior is not None:
raise ValueError(
"Non-ZINB models should not have gate_param_prior"
)
# VCP models require capture probability priors
if "vcp" in self.model_type:
if unconstrained:
# Unconstrained uses p_capture_unconstrained_prior
if self.model_config.p_capture_unconstrained_prior is None:
raise ValueError(
"VCP models with unconstrained=True require "
"p_capture_unconstrained_prior"
)
else:
# Constrained uses appropriate prior based on parameterization
if self.model_config.parameterization in ["standard", "linked"]:
if self.model_config.p_capture_param_prior is None:
raise ValueError(
"VCP models require p_capture_param_prior"
)
elif self.model_config.parameterization == "odds_ratio":
if self.model_config.phi_capture_param_prior is None:
raise ValueError(
"VCP models with odds_ratio parameterization "
"require phi_capture_param_prior"
)
else:
# Non-VCP models should not have capture probability priors
if unconstrained:
if self.model_config.p_capture_unconstrained_prior is not None:
raise ValueError(
"Non-VCP models should not have "
"p_capture_unconstrained_prior"
)
else:
if self.model_config.parameterization in ["standard", "linked"]:
if self.model_config.p_capture_param_prior is not None:
raise ValueError(
"Non-VCP models should not have p_capture_param_prior"
)
elif self.model_config.parameterization == "odds_ratio":
if self.model_config.phi_capture_param_prior is not None:
raise ValueError(
"Non-VCP models should not have "
"phi_capture_param_prior"
)
# --------------------------------------------------------------------------
# Create ScribeSVIResults from AnnData object
# --------------------------------------------------------------------------
@classmethod
def from_anndata(
cls,
adata: Any,
params: Dict,
loss_history: jnp.ndarray,
model_config: ModelConfig,
**kwargs,
):
"""Create ScribeSVIResults from AnnData object."""
return cls(
params=params,
loss_history=loss_history,
n_cells=adata.n_obs,
n_genes=adata.n_vars,
model_config=model_config,
obs=adata.obs.copy(),
var=adata.var.copy(),
uns=adata.uns.copy(),
n_obs=adata.n_obs,
n_vars=adata.n_vars,
**kwargs,
)
# --------------------------------------------------------------------------
# Get distributions using configs
# --------------------------------------------------------------------------
def get_distributions(
self,
backend: str = "numpyro",
split: bool = False,
) -> Dict[str, Any]:
"""
Get the variational distributions for all parameters.
This method delegates to the model-specific
`get_posterior_distributions` function associated with the model's
parameterization.
Parameters
----------
backend : str, default="numpyro"
Statistical package to use for distributions. Must be one of:
- "scipy": returns scipy.stats distributions
- "numpyro": returns numpyro.distributions
split : bool, default=False
If True, returns lists of individual distributions for
multidimensional parameters instead of batch distributions.
Returns
-------
Dict[str, Any]
Dictionary mapping parameter names to their distributions.
Raises
------
ValueError
If backend is not supported.
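Examples
--------
A minimal sketch (assuming `results` is a fitted ScribeSVIResults
instance):

>>> numpyro_dists = results.get_distributions(backend="numpyro")
>>> scipy_dists = results.get_distributions(backend="scipy", split=True)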
"""
if backend not in ["scipy", "numpyro"]:
raise ValueError(f"Invalid backend: {backend}")
# Define whether the model is unconstrained
unconstrained = getattr(self.model_config, "unconstrained", False)
# Define whether the model is low-rank
low_rank = self.model_config.guide_rank is not None
# Dynamically import the correct posterior distribution function
if unconstrained and not low_rank:
# For unconstrained variants, import the _unconstrained modules
if self.model_config.parameterization == "standard":
from ..models.standard_unconstrained import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "linked":
from ..models.linked_unconstrained import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "odds_ratio":
from ..models.odds_ratio_unconstrained import (
get_posterior_distributions as get_dist_fn,
)
else:
raise NotImplementedError(
f"get_distributions not implemented for unconstrained "
f"'{self.model_config.parameterization}'."
)
elif unconstrained and low_rank:
# For low-rank unconstrained variants, import the _low_rank_unconstrained modules
if self.model_config.parameterization == "standard":
from ..models.standard_low_rank_unconstrained import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "linked":
from ..models.linked_low_rank_unconstrained import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "odds_ratio":
from ..models.odds_ratio_low_rank_unconstrained import (
get_posterior_distributions as get_dist_fn,
)
else:
raise NotImplementedError(
f"get_distributions not implemented for unconstrained "
"low-rank variants of "
f"'{self.model_config.parameterization}'."
)
elif not unconstrained and not low_rank:
# For constrained variants, import the regular modules
if self.model_config.parameterization == "standard":
from ..models.standard import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "linked":
from ..models.linked import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "odds_ratio":
from ..models.odds_ratio import (
get_posterior_distributions as get_dist_fn,
)
else:
raise NotImplementedError(
f"get_distributions not implemented for "
f"'{self.model_config.parameterization}'."
)
elif not unconstrained and low_rank:
# For low-rank constrained variants, import the _low_rank modules
if self.model_config.parameterization == "standard":
from ..models.standard_low_rank import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "linked":
from ..models.linked_low_rank import (
get_posterior_distributions as get_dist_fn,
)
elif self.model_config.parameterization == "odds_ratio":
from ..models.odds_ratio_low_rank import (
get_posterior_distributions as get_dist_fn,
)
else:
raise NotImplementedError(
f"get_distributions not implemented for "
"low-rank variants of "
f"'{self.model_config.parameterization}'."
)
distributions = get_dist_fn(self.params, self.model_config, split=split)
if backend == "scipy":
# Handle conversion to scipy, accounting for split distributions
scipy_distributions = {}
for name, dist_obj in distributions.items():
if isinstance(dist_obj, list):
# Handle split distributions - convert each element
if all(isinstance(sublist, list) for sublist in dist_obj):
# Handle nested lists (2D case: components × genes)
scipy_distributions[name] = [
[numpyro_to_scipy(d) for d in sublist]
for sublist in dist_obj
]
else:
# Handle simple lists (1D case: genes or components)
scipy_distributions[name] = [
numpyro_to_scipy(d) for d in dist_obj
]
else:
# Handle single distribution
scipy_distributions[name] = numpyro_to_scipy(dist_obj)
return scipy_distributions
return distributions
# --------------------------------------------------------------------------
def get_map(
self,
use_mean: bool = False,
canonical: bool = True,
verbose: bool = True,
) -> Dict[str, jnp.ndarray]:
"""
Get the maximum a posteriori (MAP) estimates from the variational
posterior.
Parameters
----------
use_mean : bool, default=False
If True, replaces undefined MAP values (NaN) with posterior means
canonical : bool, default=True
If True, includes canonical parameters (p, r) computed from other parameters
for linked, odds_ratio, and unconstrained parameterizations
verbose : bool, default=True
If True, prints a warning if NaNs were replaced with means
Returns
-------
Dict[str, jnp.ndarray]
Dictionary of MAP estimates for each parameter
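Examples
--------
A minimal sketch (assuming `results` is a fitted ScribeSVIResults
instance; which keys appear depends on the parameterization):

>>> map_est = results.get_map(use_mean=True, canonical=True)
>>> map_est["r"].shape  # (n_genes,), or (n_components, n_genes) for mixtures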
"""
# Get distributions with NumPyro backend
distributions = self.get_distributions(backend="numpyro")
# Get estimate of map
map_estimates = {}
for param, dist_obj in distributions.items():
# Handle transformed distributions (dict with 'base' and 'transform')
# This is used for low-rank guides with transformations
if (
isinstance(dist_obj, dict)
and "base" in dist_obj
and "transform" in dist_obj
):
# For transformed distributions, MAP is transform(base.loc)
base_dist = dist_obj["base"]
transform = dist_obj["transform"]
if hasattr(base_dist, "loc"):
map_estimates[param] = transform(base_dist.loc)
else:
# Fallback to mean if loc not available
map_estimates[param] = transform(base_dist.mean)
# Handle multivariate distributions (like LowRankMultivariateNormal)
# For multivariate normals, mode = mean = loc
elif hasattr(dist_obj, "loc") and not hasattr(dist_obj, "mode"):
map_estimates[param] = dist_obj.loc
elif hasattr(dist_obj, "mode"):
map_estimates[param] = dist_obj.mode
else:
map_estimates[param] = dist_obj.mean
# Replace NaN values with means if requested
if use_mean:
# Initialize boolean to track if any NaNs were replaced
replaced_nans = False
# Check each parameter for NaNs and replace with means
for param, value in map_estimates.items():
# Check if any values are NaN
if jnp.any(jnp.isnan(value)):
replaced_nans = True
# Get mean value
mean_value = distributions[param].mean
# Replace NaN values with means
map_estimates[param] = jnp.where(
jnp.isnan(value), mean_value, value
)
# Print warning if NaNs were replaced
if replaced_nans and verbose:
warnings.warn(
"NaN values were replaced with means of the distributions",
UserWarning,
)
# Compute canonical parameters if requested
if canonical:
map_estimates = self._compute_canonical_parameters(
map_estimates, verbose=verbose
)
return map_estimates
# --------------------------------------------------------------------------
def _compute_canonical_parameters(
self, map_estimates: Dict, verbose: bool = True
) -> Dict:
"""
Compute canonical parameters (p, r) from other parameters for different
parameterizations.
Parameters
----------
map_estimates : Dict
Dictionary containing MAP estimates
verbose : bool, default=True
If True, prints information about parameter computation
Returns
-------
Dict
Updated dictionary with canonical parameters computed
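Notes
-----
The conversions applied (when the source parameters are present) are:

- linked: r = mu * (1 - p) / p
- odds_ratio: p = 1 / (1 + phi), r = mu * phi, and
p_capture = 1 / (1 + phi_capture)
- unconstrained: r = exp(r_unconstrained), p = sigmoid(p_unconstrained),
gate = sigmoid(gate_unconstrained), p_capture = sigmoid(p_capture_unconstrained)
- mixture models: mixing_weights = softmax(mixing_logits_unconstrained)
- VCP models: p_hat = p * p_capture / (1 - p * (1 - p_capture))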
"""
estimates = map_estimates.copy()
parameterization = self.model_config.parameterization
unconstrained = getattr(self.model_config, "unconstrained", False)
# Handle linked parameterization
if parameterization == "linked":
if "mu" in estimates and "p" in estimates and "r" not in estimates:
if verbose:
print(
"Computing r from mu and p for linked parameterization"
)
# r = mu * (1 - p) / p
p = estimates["p"]
if (
self.n_components is not None
and self.model_config.component_specific_params
):
# Mixture model: mu has shape (n_components, n_genes)
# p has shape (n_components,). Reshape for broadcasting.
p_reshaped = p[:, None]
else:
# Non-mixture or shared p: p is scalar, broadcasts.
p_reshaped = p
estimates["r"] = estimates["mu"] * (1 - p_reshaped) / p_reshaped
# Handle odds_ratio parameterization
elif parameterization == "odds_ratio":
# Convert phi to p if needed
if "phi" in estimates and "p" not in estimates:
if verbose:
print(
"Computing p from phi for odds_ratio parameterization"
)
estimates["p"] = 1.0 / (1.0 + estimates["phi"])
# Convert phi and mu to r if needed
if (
"phi" in estimates
and "mu" in estimates
and "r" not in estimates
):
if verbose:
print(
"Computing r from phi and mu for odds_ratio parameterization"
)
# Reshape phi to broadcast with mu based on mixture model
if (
self.n_components is not None
and self.model_config.component_specific_params
):
# Mixture model: mu has shape (n_components, n_genes)
phi_reshaped = estimates["phi"][:, None]
else:
# Non-mixture model: mu has shape (n_genes,)
phi_reshaped = estimates["phi"]
estimates["r"] = estimates["mu"] * phi_reshaped
# Handle VCP capture probability conversion
if "phi_capture" in estimates and "p_capture" not in estimates:
if verbose:
print(
"Computing p_capture from phi_capture for odds_ratio parameterization"
)
estimates["p_capture"] = 1.0 / (1.0 + estimates["phi_capture"])
# Handle unconstrained parameterization
if unconstrained:
# Convert r_unconstrained to r if needed
if "r_unconstrained" in estimates and "r" not in estimates:
if verbose:
print(
"Computing r from r_unconstrained for unconstrained parameterization"
)
estimates["r"] = jnp.exp(estimates["r_unconstrained"])
# Convert p_unconstrained to p if needed
if "p_unconstrained" in estimates and "p" not in estimates:
if verbose:
print(
"Computing p from p_unconstrained for unconstrained parameterization"
)
estimates["p"] = sigmoid(estimates["p_unconstrained"])
# Convert gate_unconstrained to gate if needed
if "gate_unconstrained" in estimates and "gate" not in estimates:
if verbose:
print(
"Computing gate from gate_unconstrained for unconstrained parameterization"
)
estimates["gate"] = sigmoid(estimates["gate_unconstrained"])
# Handle VCP capture probability conversion
if (
"p_capture_unconstrained" in estimates
and "p_capture" not in estimates
):
if verbose:
print(
"Computing p_capture from p_capture_unconstrained for unconstrained parameterization"
)
estimates["p_capture"] = sigmoid(
estimates["p_capture_unconstrained"]
)
# Handle mixing weights computation for mixture models
if (
"mixing_logits_unconstrained" in estimates
and "mixing_weights" not in estimates
):
# Compute mixing weights from mixing_logits_unconstrained using
# softmax
estimates["mixing_weights"] = softmax(
estimates["mixing_logits_unconstrained"], axis=-1
)
# Compute p_hat for NBVCP and ZINBVCP models if needed (applies to all parameterizations)
if (
"p" in estimates
and "p_capture" in estimates
and "p_hat" not in estimates
):
if verbose:
print("Computing p_hat from p and p_capture")
# Reshape p_capture for broadcasting
p_capture_reshaped = estimates["p_capture"][:, None]
# p_hat = p * p_capture / (1 - p * (1 - p_capture))
estimates["p_hat"] = (
estimates["p"]
* p_capture_reshaped
/ (1 - estimates["p"] * (1 - p_capture_reshaped))
).flatten()
return estimates
# --------------------------------------------------------------------------
# Indexing by genes
# --------------------------------------------------------------------------
@staticmethod
def _subset_gene_params(params, param_prefixes, index, n_components=None):
"""
Utility to subset all gene-specific parameters in params dict.
Parameters
----------
param_prefixes : list of (prefix, arg_constraints) tuples
Pairs of a parameter-name prefix (e.g., "r_", "mu_", "gate_") and the
argument names used to build the full parameter keys.
index : array-like
Boolean or integer index for genes.
n_components : Optional[int]
If not None, keep the component dimension.
"""
new_params = dict(params)
for prefix, arg_constraints in param_prefixes:
if arg_constraints is None:
continue
for param_name in arg_constraints:
key = f"{prefix}{param_name}"
if key in params:
if n_components is not None:
new_params[key] = params[key][..., index]
else:
new_params[key] = params[key][index]
return new_params
# --------------------------------------------------------------------------
def _subset_params(self, params: Dict, index) -> Dict:
"""
Create a new parameter dictionary for the given index using a dynamic,
shape-based approach.
"""
new_params = {}
original_n_genes = self.n_genes
for key, value in params.items():
# Find the axis that corresponds to the number of genes.
# This is safer than assuming the position of the gene axis.
try:
# Find the first occurrence of an axis with size `original_n_genes`.
gene_axis = value.shape.index(original_n_genes)
# Build a slicer tuple to index the correct axis.
slicer = [slice(None)] * value.ndim
slicer[gene_axis] = index
new_params[key] = value[tuple(slicer)]
except ValueError:
# This parameter is not gene-specific (no axis matches n_genes),
# so we keep it as is.
new_params[key] = value
return new_params
# --------------------------------------------------------------------------
def _subset_posterior_samples(self, samples: Dict, index) -> Dict:
"""
Create a new posterior samples dictionary for the given index.
"""
if samples is None:
return None
new_samples = {}
# Get the original number of genes before subsetting, which is stored
# in the instance variable self.n_genes.
original_n_genes = self.n_genes
for key, value in samples.items():
# The gene dimension is typically the last one in the posterior
# sample arrays. We check if the last dimension's size matches the
# original number of genes.
if value.ndim > 0 and value.shape[-1] == original_n_genes:
# This is a gene-specific parameter, so we subset it along the
# last axis.
new_samples[key] = value[..., index]
else:
# This is not a gene-specific parameter (e.g., global,
# cell-specific), so we keep it as is.
new_samples[key] = value
return new_samples
# --------------------------------------------------------------------------
def _subset_predictive_samples(
self, samples: jnp.ndarray, index
) -> jnp.ndarray:
"""Create a new predictive samples array for the given index."""
if samples is None:
return None
# For predictive samples, subset the gene dimension (last dimension)
return samples[..., index]
# --------------------------------------------------------------------------
def __getitem__(self, index):
"""
Enable indexing of ScribeSVIResults object.
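Supports boolean masks, integer indices, slices, and lists/arrays of
integer indices over genes. A minimal sketch (assuming `results` is a
ScribeSVIResults instance):

>>> one_gene = results[0]            # single gene
>>> first_ten = results[:10]         # slice of genes
>>> marked = results[[1, 5, 42]]     # explicit list of gene indices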
"""
# If index is a boolean mask, use it directly
if isinstance(index, (jnp.ndarray, np.ndarray)) and index.dtype == bool:
bool_index = index
# Handle integer indexing
elif isinstance(index, int):
# Initialize boolean index
bool_index = jnp.zeros(self.n_genes, dtype=bool)
# Set True for the given index
bool_index = bool_index.at[index].set(True)
# Handle slice indexing
elif isinstance(index, slice):
# Get indices from slice
indices = jnp.arange(self.n_genes)[index]
# Build boolean index marking genes selected by the slice
bool_index = jnp.isin(jnp.arange(self.n_genes), indices)
# Handle list/array indexing (by integer indices)
elif isinstance(index, (list, np.ndarray, jnp.ndarray)) and not (
isinstance(index, (jnp.ndarray, np.ndarray)) and index.dtype == bool
):
indices = jnp.array(index)
bool_index = jnp.isin(jnp.arange(self.n_genes), indices)
else:
raise TypeError(f"Unsupported index type: {type(index)}")
# Create new params dict with subset of parameters
new_params = self._subset_params(self.params, bool_index)
# Create new metadata if available
new_var = self.var.iloc[bool_index] if self.var is not None else None
# Create new posterior samples if available
new_posterior_samples = (
self._subset_posterior_samples(self.posterior_samples, bool_index)
if self.posterior_samples is not None
else None
)
# Create new predictive samples if available
new_predictive_samples = (
self._subset_predictive_samples(self.predictive_samples, bool_index)
if self.predictive_samples is not None
else None
)
# Create new instance with subset data
return self._create_subset(
index=bool_index,
new_params=new_params,
new_var=new_var,
new_posterior_samples=new_posterior_samples,
new_predictive_samples=new_predictive_samples,
)
# --------------------------------------------------------------------------
def _create_subset(
self,
index,
new_params: Dict,
new_var: Optional[pd.DataFrame],
new_posterior_samples: Optional[Dict],
new_predictive_samples: Optional[jnp.ndarray],
) -> "ScribeSVIResults":
"""Create a new instance with a subset of genes."""
return type(self)(
params=new_params,
loss_history=self.loss_history,
n_cells=self.n_cells,
n_genes=int(index.sum() if hasattr(index, "sum") else len(index)),
model_type=self.model_type,
model_config=self.model_config,
prior_params=self.prior_params,
obs=self.obs,
var=new_var,
uns=self.uns,
n_obs=self.n_obs,
n_vars=new_var.shape[0] if new_var is not None else None,
posterior_samples=new_posterior_samples,
predictive_samples=new_predictive_samples,
n_components=self.n_components,
)
# --------------------------------------------------------------------------
# Indexing by component
# --------------------------------------------------------------------------
def get_component(self, component_index):
"""
Create a view of the results selecting a specific mixture component.
This method returns a new ScribeSVIResults object that contains
parameter values for the specified component, allowing for further
gene-based indexing. Only applicable to mixture models.
Parameters
----------
component_index : int
Index of the component to select
Returns
-------
ScribeSVIResults
A new ScribeSVIResults object with parameters for the selected
component
Raises
------
ValueError
If the model is not a mixture model
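Examples
--------
A minimal sketch (assuming `results` holds a fitted mixture model):

>>> component_0 = results.get_component(0)
>>> component_0.n_components  # None: behaves as a non-mixture result
>>> component_0_genes = component_0[:50]  # further gene-based indexing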
"""
# Check if this is a mixture model
if self.n_components is None or self.n_components <= 1:
raise ValueError(
"Component view only applies to mixture models with multiple components"
)
# Check if component_index is valid
if component_index < 0 or component_index >= self.n_components:
raise ValueError(
f"Component index {component_index} out of range [0, {self.n_components-1}]"
)
# Create new params dict with component subset
new_params = dict(self.params)
# Handle all parameters based on their structure
self._subset_params_by_component(new_params, component_index)
# Create new posterior samples if available
new_posterior_samples = None
if self.posterior_samples is not None:
new_posterior_samples = self._subset_posterior_samples_by_component(
self.posterior_samples, component_index
)
# Create new predictive samples if available - this is more complex
# as we would need to condition on the component
new_predictive_samples = None
# Create new instance with component subset
return self._create_component_subset(
component_index=component_index,
new_params=new_params,
new_posterior_samples=new_posterior_samples,
new_predictive_samples=new_predictive_samples,
)
# --------------------------------------------------------------------------
def _subset_params_by_component(
self, new_params: Dict, component_index: int
):
"""
Handle subsetting of all parameters based on their structure.
This method intelligently handles parameters based on their dimensions
and naming conventions, regardless of parameterization.
"""
# Define parameter categories based on their structure
# Component-gene-specific parameters (shape: [n_components, n_genes])
# These parameters have both component and gene dimensions
component_gene_specific = [
# Standard parameterization
"r_loc",
"r_scale", # dispersion parameters
"gate_alpha",
"gate_beta", # zero-inflation parameters
# Standard unconstrained parameterization
"r_unconstrained_loc",
"r_unconstrained_scale",
"gate_unconstrained_loc",
"gate_unconstrained_scale",
# Linked parameterization
"mu_loc",
"mu_scale", # mean parameters
"gate_alpha",
"gate_beta", # zero-inflation parameters
# Odds ratio parameterization
"phi_alpha",
"phi_beta", # odds ratio parameters
"gate_alpha",
"gate_beta", # zero-inflation parameters
# Odds ratio unconstrained parameterization
"phi_unconstrained_loc",
"phi_unconstrained_scale",
"gate_unconstrained_loc",
"gate_unconstrained_scale",
# Low-rank guide parameters (standard constrained)
"log_r_loc",
"log_r_W",
"log_r_raw_diag",
# Low-rank guide parameters (standard unconstrained)
"r_unconstrained_W",
"r_unconstrained_raw_diag",
# Low-rank guide parameters (linked/odds_ratio constrained)
"log_mu_loc",
"log_mu_W",
"log_mu_raw_diag",
# Low-rank guide parameters (linked/odds_ratio unconstrained)
"mu_unconstrained_loc",
"mu_unconstrained_W",
"mu_unconstrained_raw_diag",
# Low-rank guide parameters (gate - unconstrained)
"gate_unconstrained_W",
"gate_unconstrained_raw_diag",
]
# Component-specific parameters (shape: [n_components])
# These parameters have only component dimension
component_specific = [
# Standard unconstrained parameterization
"p_unconstrained_loc",
"p_unconstrained_scale",
"mixing_logits_unconstrained_loc",
"mixing_logits_unconstrained_scale",
# Odds ratio unconstrained parameterization
"phi_unconstrained_loc",
"phi_unconstrained_scale",
"mixing_logits_unconstrained_loc",
"mixing_logits_unconstrained_scale",
]
# Cell-specific parameters (shape: [n_cells])
# These parameters are cell-specific and not component-specific
cell_specific = [
# Standard parameterization
"p_capture_alpha",
"p_capture_beta", # capture probability parameters
# Standard unconstrained parameterization
"p_capture_unconstrained_loc",
"p_capture_unconstrained_scale",
# Linked parameterization
"p_capture_alpha",
"p_capture_beta", # capture probability parameters
# Odds ratio parameterization
"phi_capture_alpha",
"phi_capture_beta", # capture odds ratio parameters
# Odds ratio unconstrained parameterization
"phi_capture_unconstrained_loc",
"phi_capture_unconstrained_scale",
]
# Parameters that can be either component-specific or shared depending on model config
# These need special handling based on component_specific_params setting
configurable_params = [
# Standard parameterization
"p_alpha",
"p_beta", # success probability parameters
# Linked parameterization
"p_alpha",
"p_beta", # success probability parameters
# Odds ratio parameterization
"phi_alpha",
"phi_beta", # odds ratio parameters
]
# Shared parameters (scalar or global)
# These parameters are shared across all components
shared_params = [
# Standard parameterization
"mixing_conc", # mixture concentrations
# Standard unconstrained parameterization
"mixing_logits_unconstrained_loc",
"mixing_logits_unconstrained_scale",
# Linked parameterization
"mixing_conc", # mixture concentrations
# Odds ratio parameterization
"mixing_conc", # mixture concentrations
# Odds ratio unconstrained parameterization
"mixing_logits_unconstrained_loc",
"mixing_logits_unconstrained_scale",
]
# Additional parameters that might be present but not categorized above
# These are typically scalar or global parameters
additional_params = [
# Any other parameters that don't fit the above categories
# This list can be expanded as needed
]
# Handle component-gene-specific parameters (shape: [n_components, n_genes])
for param_name in component_gene_specific:
if param_name in self.params:
param = self.params[param_name]
# Check if parameter has component dimension
if param.ndim > 1: # Has component dimension
new_params[param_name] = param[component_index]
else: # Scalar parameter, copy as-is
new_params[param_name] = param
# Handle component-specific parameters (shape: [n_components])
for param_name in component_specific:
if param_name in self.params:
param = self.params[param_name]
# Check if parameter has component dimension
if param.ndim > 0: # Has component dimension
new_params[param_name] = param[component_index]
else: # Scalar parameter, copy as-is
new_params[param_name] = param
# Handle cell-specific parameters (copy as-is, not component-specific)
for param_name in cell_specific:
if param_name in self.params:
new_params[param_name] = self.params[param_name]
# Handle configurable parameters (can be component-specific or shared)
for param_name in configurable_params:
if param_name in self.params:
param = self.params[param_name]
# Check if parameter has component dimension
if param.ndim > 0: # Has component dimension
new_params[param_name] = param[component_index]
else: # Scalar parameter, copy as-is
new_params[param_name] = param
# Handle shared parameters (copy as-is, used across all components)
for param_name in shared_params:
if param_name in self.params:
new_params[param_name] = self.params[param_name]
# Handle any additional parameters that might be present
for param_name in additional_params:
if param_name in self.params:
new_params[param_name] = self.params[param_name]
# Handle any remaining parameters not explicitly categorized
# This ensures we don't miss any parameters
for param_name in self.params:
if param_name not in new_params:
# For any uncategorized parameters, copy as-is
new_params[param_name] = self.params[param_name]
# --------------------------------------------------------------------------
def _subset_posterior_samples_by_component(
self, samples: Dict, component_index: int
) -> Dict:
"""
Create a new posterior samples dictionary for the given component index.
This method handles all parameter types based on their dimensions.
"""
if samples is None:
return None
new_posterior_samples = {}
# Define parameter categories for posterior samples
component_gene_specific_samples = [
# Standard parameterization
"r", # dispersion parameter
"gate", # zero-inflation parameter
# Standard unconstrained parameterization
"r_unconstrained", # dispersion parameter
"gate_unconstrained", # zero-inflation parameter
# Linked parameterization
"mu", # mean parameter
"gate", # zero-inflation parameter
# Odds ratio parameterization
"phi", # odds ratio parameter
"gate", # zero-inflation parameter
# Odds ratio unconstrained parameterization
"phi_unconstrained", # odds ratio parameter
"gate_unconstrained", # zero-inflation parameter
]
component_specific_samples = [
# Standard parameterization
"p", # success probability parameter
# Standard unconstrained parameterization
"p_unconstrained", # success probability parameter
"mixing_logits_unconstrained", # mixing logits
# Linked parameterization
"p", # success probability parameter
# Odds ratio parameterization
"phi", # odds ratio parameter
# Odds ratio unconstrained parameterization
"phi_unconstrained", # odds ratio parameter
"mixing_logits_unconstrained", # mixing logits
]
cell_specific_samples = [
# Standard parameterization
"p_capture", # capture probability parameter
# Standard unconstrained parameterization
"p_capture_unconstrained", # capture probability parameter
# Linked parameterization
"p_capture", # capture probability parameter
# Odds ratio parameterization
"phi_capture", # capture odds ratio parameter
# Odds ratio unconstrained parameterization
"phi_capture_unconstrained", # capture odds ratio parameter
]
# Shared parameters (scalar or global)
shared_samples = [
# Standard parameterization
"mixing_weights", # mixture weights
# Standard unconstrained parameterization
"mixing_logits_unconstrained", # mixing logits
# Linked parameterization
"mixing_weights", # mixture weights
# Odds ratio parameterization
"mixing_weights", # mixture weights
# Odds ratio unconstrained parameterization
"mixing_logits_unconstrained", # mixing logits
]
# Configurable parameters (can be component-specific or shared)
configurable_samples = [
# Standard parameterization
"p", # success probability parameter
# Linked parameterization
"p", # success probability parameter
# Odds ratio parameterization
"phi", # odds ratio parameter
]
# Additional parameters that might be present in posterior samples
# These are typically derived parameters or deterministic values
additional_samples = [
# Any other parameters that don't fit the above categories
# This list can be expanded as needed
]
# Handle component-gene-specific samples
# (shape: [n_samples, n_components, n_genes])
for param_name in component_gene_specific_samples:
if param_name in samples:
sample_value = samples[param_name]
if sample_value.ndim > 2: # Has component dimension
new_posterior_samples[param_name] = sample_value[
:, component_index, :
]
else: # Scalar parameter, copy as-is
new_posterior_samples[param_name] = sample_value
# Handle component-specific samples (shape: [n_samples, n_components])
for param_name in component_specific_samples:
if param_name in samples:
sample_value = samples[param_name]
if sample_value.ndim > 1: # Has component dimension
new_posterior_samples[param_name] = sample_value[
:, component_index
]
else: # Scalar parameter, copy as-is
new_posterior_samples[param_name] = sample_value
# Handle cell-specific samples (copy as-is, not component-specific)
for param_name in cell_specific_samples:
if param_name in samples:
new_posterior_samples[param_name] = samples[param_name]
# Handle shared samples (copy as-is, used across all components)
for param_name in shared_samples:
if param_name in samples:
new_posterior_samples[param_name] = samples[param_name]
# Handle configurable samples (can be component-specific or shared)
for param_name in configurable_samples:
if param_name in samples:
sample_value = samples[param_name]
if sample_value.ndim > 1: # Has component dimension
new_posterior_samples[param_name] = sample_value[
:, component_index
]
else: # Scalar parameter, copy as-is
new_posterior_samples[param_name] = sample_value
# Handle any additional samples that might be present
for param_name in additional_samples:
if param_name in samples:
new_posterior_samples[param_name] = samples[param_name]
# Handle any remaining samples not explicitly categorized
# This ensures we don't miss any parameters
for param_name in samples:
if param_name not in new_posterior_samples:
# For any uncategorized parameters, copy as-is
new_posterior_samples[param_name] = samples[param_name]
return new_posterior_samples
# --------------------------------------------------------------------------
def _create_component_subset(
self,
component_index,
new_params: Dict,
new_posterior_samples: Optional[Dict],
new_predictive_samples: Optional[jnp.ndarray],
) -> "ScribeSVIResults":
"""Create a new instance for a specific component."""
# Create a non-mixture model type
base_model = self.model_type.replace("_mix", "")
# Create a modified model config with n_components=None to indicate
# this is now a non-mixture result after component selection
new_model_config = replace(
self.model_config,
base_model=base_model,
n_components=None,
)
return type(self)(
params=new_params,
loss_history=self.loss_history,
n_cells=self.n_cells,
n_genes=self.n_genes,
model_type=base_model, # Remove _mix suffix
model_config=new_model_config,
prior_params=self.prior_params,
obs=self.obs,
var=self.var,
uns=self.uns,
n_obs=self.n_obs,
n_vars=self.n_vars,
posterior_samples=new_posterior_samples,
predictive_samples=new_predictive_samples,
n_components=None, # No longer a mixture model
)
# --------------------------------------------------------------------------
# Get model and guide functions
# --------------------------------------------------------------------------
def _model_and_guide(self) -> Tuple[Callable, Optional[Callable]]:
"""Get the model and guide functions based on model type."""
from ..models.model_registry import get_model_and_guide
parameterization = self.model_config.parameterization or ""
inference_method = self.model_config.inference_method or ""
prior_type = self.model_config.vae_prior_type or ""
unconstrained = getattr(self.model_config, "unconstrained", False)
guide_rank = self.model_config.guide_rank
return get_model_and_guide(
self.model_type,
parameterization,
inference_method,
prior_type,
unconstrained=unconstrained,
guide_rank=guide_rank,
)
# --------------------------------------------------------------------------
# Get parameterization
# --------------------------------------------------------------------------
def _parameterization(self) -> str:
"""Get the parameterization type."""
return self.model_config.parameterization or ""
# --------------------------------------------------------------------------
# Get if unconstrained
# --------------------------------------------------------------------------
def _unconstrained(self) -> bool:
"""Get if the parameterization is unconstrained."""
return getattr(self.model_config, "unconstrained", False)
# --------------------------------------------------------------------------
# Get log likelihood function
# --------------------------------------------------------------------------
def _log_likelihood_fn(self) -> Callable:
"""Get the log likelihood function for this model type."""
from ..models.model_registry import get_log_likelihood_fn
return get_log_likelihood_fn(self.model_type)
# --------------------------------------------------------------------------
# Posterior sampling methods
# --------------------------------------------------------------------------
def get_posterior_samples(
self,
rng_key: random.PRNGKey = random.PRNGKey(42),
n_samples: int = 100,
store_samples: bool = True,
) -> Dict:
"""Sample parameters from the variational posterior distribution."""
# Get the guide function
model, guide = self._model_and_guide()
if guide is None:
raise ValueError(
f"Could not find a guide for model '{self.model_type}'."
)
# Prepare base model arguments
model_args = {
"n_cells": self.n_cells,
"n_genes": self.n_genes,
"model_config": self.model_config,
}
# Sample from posterior
posterior_samples = sample_variational_posterior(
guide,
self.params,
model,
model_args,
rng_key=rng_key,
n_samples=n_samples,
)
# Store samples if requested
if store_samples:
self.posterior_samples = posterior_samples
return posterior_samples
# --------------------------------------------------------------------------
def get_predictive_samples(
self,
rng_key: random.PRNGKey = random.PRNGKey(42),
batch_size: Optional[int] = None,
store_samples: bool = True,
) -> jnp.ndarray:
"""Generate predictive samples using posterior parameter samples."""
from ..models.model_registry import get_model_and_guide
# For predictive sampling, we need the *constrained* model, which has the
# 'counts' sample site. The posterior samples from the unconstrained guide
# can be used with the constrained model.
model, _ = get_model_and_guide(
self.model_type,
self.model_config.parameterization,
self.model_config.inference_method,
self.model_config.vae_prior_type,
unconstrained=False, # Explicitly get the constrained model
guide_rank=None, # Not relevant for the model
)
# Prepare base model arguments
model_args = {
"n_cells": self.n_cells,
"n_genes": self.n_genes,
"model_config": self.model_config,
}
# Check if posterior samples exist
if self.posterior_samples is None:
raise ValueError(
"No posterior samples found. Call get_posterior_samples() first."
)
# Generate predictive samples
predictive_samples = generate_predictive_samples(
model,
self.posterior_samples,
model_args,
rng_key=rng_key,
batch_size=batch_size,
)
# Store samples if requested
if store_samples:
self.predictive_samples = predictive_samples
return predictive_samples
# --------------------------------------------------------------------------
def get_ppc_samples(
self,
rng_key: random.PRNGKey = random.PRNGKey(42),
n_samples: int = 100,
batch_size: Optional[int] = None,
store_samples: bool = True,
) -> Dict:
"""Generate posterior predictive check samples."""
# Check if we need to resample parameters
need_params = self.posterior_samples is None
# Generate posterior samples if needed
if need_params:
# Sample parameters and generate predictive samples
self.get_posterior_samples(
rng_key=rng_key,
n_samples=n_samples,
store_samples=store_samples,
)
# Generate predictive samples using existing parameters
_, key_pred = random.split(rng_key)
self.get_predictive_samples(
rng_key=key_pred,
batch_size=batch_size,
store_samples=store_samples,
)
return {
"parameter_samples": self.posterior_samples,
"predictive_samples": self.predictive_samples,
}
# --------------------------------------------------------------------------
# Compute log likelihood methods
# --------------------------------------------------------------------------
def log_likelihood(
self,
counts: jnp.ndarray,
batch_size: Optional[int] = None,
return_by: str = "cell",
cells_axis: int = 0,
ignore_nans: bool = False,
split_components: bool = False,
weights: Optional[jnp.ndarray] = None,
weight_type: Optional[str] = None,
dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
"""
Compute log likelihood of data under posterior samples.
Parameters
----------
counts : jnp.ndarray
Count data to evaluate likelihood on
batch_size : Optional[int], default=None
Size of mini-batches used for likelihood computation
return_by : str, default='cell'
Specifies how to return the log probabilities. Must be one of:
- 'cell': returns log probabilities summed over genes
- 'gene': returns log probabilities summed over cells
cells_axis : int, default=0
Axis along which cells are arranged. 0 means cells are rows.
ignore_nans : bool, default=False
If True, removes any samples that contain NaNs.
split_components : bool, default=False
If True, returns log likelihoods for each mixture component
separately. Only applicable for mixture models.
weights : Optional[jnp.ndarray], default=None
Array used to weight the log likelihoods (for mixture models).
weight_type : Optional[str], default=None
How to apply weights. Must be one of:
- 'multiplicative': multiply log probabilities by weights
- 'additive': add weights to log probabilities
dtype : jnp.dtype, default=jnp.float32
Data type for numerical precision in computations
Returns
-------
jnp.ndarray
Array of log likelihoods. Shape depends on model type, return_by and
split_components parameters. For standard models:
- 'cell': shape (n_samples, n_cells)
- 'gene': shape (n_samples, n_genes)
For mixture models with split_components=False:
- 'cell': shape (n_samples, n_cells)
- 'gene': shape (n_samples, n_genes)
For mixture models with split_components=True:
- 'cell': shape (n_samples, n_cells, n_components)
- 'gene': shape (n_samples, n_genes, n_components)
Raises
------
ValueError
If posterior samples have not been generated yet
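Examples
--------
A minimal sketch (assuming `results` is a fitted instance and `counts`
is an (n_cells, n_genes) count matrix):

>>> _ = results.get_posterior_samples(n_samples=100)
>>> ll_cell = results.log_likelihood(counts, return_by="cell")
>>> ll_gene = results.log_likelihood(counts, return_by="gene")
>>> ll_cell.shape  # (n_samples, n_cells)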
"""
# Check if posterior samples exist
if self.posterior_samples is None:
raise ValueError(
"No posterior samples found. Call get_posterior_samples() first."
)
# Convert posterior samples to canonical form
self._convert_to_canonical()
# Get parameter samples
parameter_samples = self.posterior_samples
# Get number of samples from first parameter
n_samples = parameter_samples[next(iter(parameter_samples))].shape[0]
# Get likelihood function
likelihood_fn = self._log_likelihood_fn()
# Determine if this is a mixture model
is_mixture = self.n_components is not None and self.n_components > 1
# Define function to compute likelihood for a single sample
@jit
def compute_sample_lik(i):
# Extract parameters for this sample
params_i = {k: v[i] for k, v in parameter_samples.items()}
# For mixture models we need to pass split_components and weights
if is_mixture:
return likelihood_fn(
counts,
params_i,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
split_components=split_components,
weights=weights,
weight_type=weight_type,
dtype=dtype,
)
else:
return likelihood_fn(
counts,
params_i,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
dtype=dtype,
)
# Use vmap for parallel computation (more memory intensive)
log_liks = vmap(compute_sample_lik)(jnp.arange(n_samples))
# Handle NaNs if requested
if ignore_nans:
# Check for NaNs appropriately based on dimensions
if is_mixture and split_components:
# Handle case with component dimension
valid_samples = ~jnp.any(
jnp.any(jnp.isnan(log_liks), axis=-1), axis=-1
)
else:
# Standard case
valid_samples = ~jnp.any(jnp.isnan(log_liks), axis=-1)
# Filter out samples with NaNs
if jnp.any(~valid_samples):
print(
f" - Fraction of samples removed: {1 - jnp.mean(valid_samples)}"
)
return log_liks[valid_samples]
return log_liks
# --------------------------------------------------------------------------
def log_likelihood_map(
self,
counts: jnp.ndarray,
batch_size: Optional[int] = None,
gene_batch_size: Optional[int] = None,
return_by: str = "cell",
cells_axis: int = 0,
split_components: bool = False,
weights: Optional[jnp.ndarray] = None,
weight_type: Optional[str] = None,
use_mean: bool = True,
verbose: bool = True,
dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
"""
Compute log likelihood of data using MAP parameter estimates.
Parameters
----------
counts : jnp.ndarray
Count data to evaluate likelihood on
batch_size : Optional[int], default=None
Size of mini-batches used for likelihood computation
gene_batch_size : Optional[int], default=None
Size of mini-batches used for likelihood computation by gene
return_by : str, default='cell'
Specifies how to return the log probabilities. Must be one of:
- 'cell': returns log probabilities summed over genes
- 'gene': returns log probabilities summed over cells
cells_axis : int, default=0
Axis along which cells are arranged. 0 means cells are rows.
split_components : bool, default=False
If True, returns log likelihoods for each mixture component separately.
Only applicable for mixture models.
weights : Optional[jnp.ndarray], default=None
Array used to weight the log likelihoods (for mixture models).
weight_type : Optional[str], default=None
How to apply weights. Must be one of:
- 'multiplicative': multiply log probabilities by weights
- 'additive': add weights to log probabilities
use_mean : bool, default=True
If True, replaces undefined MAP values (NaN) with posterior means
verbose : bool, default=True
If True, prints a warning if NaNs were replaced with means
dtype : jnp.dtype, default=jnp.float32
Data type for numerical precision in computations
Returns
-------
jnp.ndarray
Array of log likelihoods. Shape depends on model type, return_by and
split_components parameters.
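Examples
--------
A minimal sketch (assuming `results` is a fitted instance and `counts`
is an (n_cells, n_genes) count matrix):

>>> ll_map = results.log_likelihood_map(counts, return_by="gene")
>>> ll_map_batched = results.log_likelihood_map(
...     counts, return_by="gene", gene_batch_size=1000
... )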
"""
# Get the log likelihood function
likelihood_fn = self._log_likelihood_fn()
# Determine if this is a mixture model
is_mixture = self.n_components is not None and self.n_components > 1
# Get the MAP estimates with canonical parameters included
map_estimates = self.get_map(
use_mean=use_mean, canonical=True, verbose=verbose
)
# If computing by gene and gene_batch_size is provided, use batched computation
if return_by == "gene" and gene_batch_size is not None:
# Determine output shape
if (
is_mixture
and split_components
and self.n_components is not None
):
result_shape = (self.n_genes, self.n_components)
else:
result_shape = (self.n_genes,)
# Initialize result array
log_liks = np.zeros(result_shape, dtype=dtype)
# Process genes in batches
for i in range(0, self.n_genes, gene_batch_size):
if verbose and i > 0:
print(
f"Processing genes {i}-{min(i+gene_batch_size, self.n_genes)} of {self.n_genes}"
)
# Get gene indices for this batch
end_idx = min(i + gene_batch_size, self.n_genes)
gene_indices = list(range(i, end_idx))
# Get subset of results for these genes
results_subset = self[gene_indices]
# Get the MAP estimates for this subset (with canonical parameters)
subset_map_estimates = results_subset.get_map(
use_mean=use_mean, canonical=True, verbose=False
)
# Get subset of counts for these genes
if cells_axis == 0:
counts_subset = counts[:, gene_indices]
else:
counts_subset = counts[gene_indices, :]
# Get subset of weights if provided
weights_subset = None
if weights is not None:
if weights.ndim == 1: # Shape: (n_genes,)
weights_subset = weights[gene_indices]
else:
weights_subset = weights
# Compute log likelihood for this gene batch
if is_mixture:
batch_log_liks = likelihood_fn(
counts_subset,
subset_map_estimates,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
split_components=split_components,
weights=weights_subset,
weight_type=weight_type,
dtype=dtype,
)
else:
batch_log_liks = likelihood_fn(
counts_subset,
subset_map_estimates,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
dtype=dtype,
)
# Store results
log_liks[i:end_idx] = np.array(batch_log_liks)
# Convert to JAX array for consistency
return jnp.array(log_liks)
# Standard computation (no gene batching)
else:
# Compute log-likelihood for mixture model
if is_mixture:
log_liks = likelihood_fn(
counts,
map_estimates,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
split_components=split_components,
weights=weights,
weight_type=weight_type,
dtype=dtype,
)
# Compute log-likelihood for non-mixture model
else:
log_liks = likelihood_fn(
counts,
map_estimates,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
dtype=dtype,
)
return log_liks
# --------------------------------------------------------------------------
# Compute entropy of component assignments
# --------------------------------------------------------------------------
def mixture_component_entropy(
self,
counts: jnp.ndarray,
return_by: str = "gene",
batch_size: Optional[int] = None,
cells_axis: int = 0,
ignore_nans: bool = False,
temperature: Optional[float] = None,
dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
"""
Compute the entropy of mixture component assignment probabilities.
This method calculates the Shannon entropy of the posterior component
assignment probabilities for each observation (cell or gene), providing
a quantitative measure of assignment uncertainty in mixture models.
The entropy quantifies how uncertain the model is about which component
each observation belongs to:
- Low entropy (≈0): High confidence in component assignment
- High entropy (≈log(n_components)): High uncertainty in assignment
- Maximum entropy: Uniform assignment probabilities across all
components
The entropy is calculated as:
H = -∑(p_i * log(p_i))
where p_i are the posterior probabilities of assignment to component i.
Parameters
----------
counts : jnp.ndarray
Input count data to evaluate component assignments for. Shape should
be (n_cells, n_genes) if cells_axis=0, or (n_genes, n_cells) if
cells_axis=1.
return_by : str, default='gene'
Specifies how to compute and return the entropy. Must be one of:
- 'cell': Compute entropy of component assignments for each cell
- 'gene': Compute entropy of component assignments for each gene
batch_size : Optional[int], default=None
If provided, processes the data in batches of this size to reduce
memory usage. Useful for large datasets.
cells_axis : int, default=0
Specifies which axis in the input counts contains the cells:
- 0: cells are rows (shape: n_cells × n_genes)
- 1: cells are columns (shape: n_genes × n_cells)
ignore_nans : bool, default=False
If True, excludes any samples containing NaN values from the entropy
calculation.
temperature : Optional[float], default=None
If provided, applies temperature scaling to the log-likelihoods
before computing entropy. Temperature scaling modifies the sharpness
of probability distributions by dividing log probabilities by a
temperature parameter T.
dtype : jnp.dtype, default=jnp.float32
Data type for numerical precision in computations.
Returns
-------
jnp.ndarray
Array of entropy values. Shape depends on return_by:
- If return_by='cell': shape is (n_samples, n_cells)
- If return_by='gene': shape is (n_samples, n_genes)
Higher values indicate more uncertainty in component assignments.
Raises
------
ValueError
If the model is not a mixture model or if posterior samples haven't
been generated.
Notes
-----
- This method requires posterior samples to be available. Call
get_posterior_samples() first if they haven't been generated.
- The entropy is computed using the full posterior predictive
distribution, accounting for uncertainty in the model parameters.
- For a mixture with K components, the maximum possible entropy is
log(K).
- Entropy values can be used to identify observations that are difficult
to classify or that may belong to multiple components.
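Examples
--------
A minimal sketch (assuming `results` holds a fitted mixture model and
`counts` is an (n_cells, n_genes) count matrix):

>>> _ = results.get_posterior_samples(n_samples=100)
>>> entropy = results.mixture_component_entropy(counts, return_by="cell")
>>> entropy.shape  # (n_samples, n_cells)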
"""
# Check if this is a mixture model
if self.n_components is None or self.n_components <= 1:
raise ValueError(
"Mixture component entropy calculation only applies to mixture "
"models with multiple components"
)
# Check if posterior samples exist
if self.posterior_samples is None:
raise ValueError(
"No posterior samples found. Call get_posterior_samples() first."
)
# Convert posterior samples to canonical form
self._convert_to_canonical()
print("Computing log-likelihoods...")
# Compute log-likelihoods for each component
log_liks = self.log_likelihood(
counts,
batch_size=batch_size,
cells_axis=cells_axis,
return_by=return_by,
ignore_nans=ignore_nans,
dtype=dtype,
split_components=True, # Ensure we get per-component likelihoods
)
# Apply temperature scaling if requested
if temperature is not None:
from ..core.cell_type_assignment import temperature_scaling
log_liks = temperature_scaling(log_liks, temperature, dtype=dtype)
print("Converting log-likelihoods to probabilities...")
# Convert log-likelihoods to probabilities
probs = softmax(log_liks, axis=-1)
print("Computing entropy...")
# Compute entropy: -∑(p_i * log(p_i))
# Add small epsilon to avoid log(0)
eps = jnp.finfo(dtype).eps
entropy = -jnp.sum(probs * jnp.log(probs + eps), axis=-1)
return entropy
# --------------------------------------------------------------------------
def assignment_entropy_map(
self,
counts: jnp.ndarray,
return_by: str = "gene",
batch_size: Optional[int] = None,
cells_axis: int = 0,
temperature: Optional[float] = None,
use_mean: bool = True,
verbose: bool = True,
dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
"""
Compute the entropy of component assignments for each cell or gene,
evaluated at the MAP parameter estimates.
This method calculates the entropy of the posterior component assignment
probabilities for each cell or gene, providing a measure of assignment
uncertainty. Higher entropy values indicate more uncertainty in the
component assignments, while lower values indicate more confident
assignments.
The entropy is calculated as:
H = -∑(p_i * log(p_i))
where p_i are the normalized probabilities for each component.
Parameters
----------
counts : jnp.ndarray
The count matrix with shape (n_cells, n_genes).
return_by : str, default='gene'
Whether to return the entropy by cell or gene.
batch_size : Optional[int], default=None
Size of mini-batches for likelihood computation
cells_axis : int, default=0
Axis along which cells are arranged. 0 means cells are rows.
temperature : Optional[float], default=None
If provided, applies temperature scaling to the log-likelihoods
before computing entropy.
use_mean : bool, default=True
If True, uses the mean of the posterior component probabilities
instead of the MAP.
verbose : bool, default=True
If True, prints a warning if NaNs were replaced with means
dtype : jnp.dtype, default=jnp.float32
Data type for numerical precision in computations
Returns
-------
jnp.ndarray
Array of component assignment entropies evaluated at the MAP. Shape is
(n_cells,) if return_by='cell', or (n_genes,) if return_by='gene'.
Raises
------
ValueError
- If the model is not a mixture model
- If posterior samples have not been generated yet
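Examples
--------
A minimal usage sketch (assuming `results` is a fitted mixture-model
ScribeSVIResults object and `counts` is an (n_cells, n_genes) array):
>>> entropy = results.assignment_entropy_map(counts, return_by='cell')
>>> print(entropy.shape)  # (n_cells,)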
"""
# Check if this is a mixture model
if self.n_components is None or self.n_components <= 1:
raise ValueError(
"Component entropy calculation only applies to mixture models "
"with multiple components"
)
# Compute log-likelihood at the MAP
log_liks = self.log_likelihood_map(
counts,
batch_size=batch_size,
cells_axis=cells_axis,
use_mean=use_mean,
verbose=verbose,
dtype=dtype,
return_by=return_by,
split_components=True,
)
# Apply temperature scaling if requested
if temperature is not None:
from ..core.cell_type_assignment import temperature_scaling
log_liks = temperature_scaling(log_liks, temperature, dtype=dtype)
# Normalize log-likelihoods over components using the log-sum-exp trick
log_sum_exp = jsp.special.logsumexp(log_liks, axis=-1, keepdims=True)
# Exponentiate the normalized log-likelihoods to obtain component
# probabilities (a numerically stable softmax over the last axis)
probs = jnp.exp(log_liks - log_sum_exp)
# Compute entropy: -∑(p_i * log(p_i))
# Add small epsilon to avoid log(0)
eps = jnp.finfo(dtype).eps
entropy = -jnp.sum(probs * jnp.log(probs + eps), axis=-1)
return entropy
# --------------------------------------------------------------------------
# Cell type assignment method for mixture models
# --------------------------------------------------------------------------
[docs]
def cell_type_probabilities(
self,
counts: jnp.ndarray,
batch_size: Optional[int] = None,
cells_axis: int = 0,
ignore_nans: bool = False,
dtype: jnp.dtype = jnp.float32,
fit_distribution: bool = True,
temperature: Optional[float] = None,
weights: Optional[jnp.ndarray] = None,
weight_type: Optional[str] = None,
verbose: bool = True,
) -> Dict[str, jnp.ndarray]:
"""
Compute probabilistic cell type assignments and fit Dirichlet
distributions to characterize assignment uncertainty.
For each cell, this method:
1. Computes component-specific log-likelihoods using posterior
samples
2. Converts these to probability distributions over cell types
3. Fits a Dirichlet distribution to characterize the uncertainty in
these assignments
Parameters
----------
counts : jnp.ndarray
Count data to evaluate assignments for
batch_size : Optional[int], default=None
Size of mini-batches for likelihood computation
cells_axis : int, default=0
Axis along which cells are arranged. 0 means cells are rows.
ignore_nans : bool, default=False
If True, removes any samples that contain NaNs.
dtype : jnp.dtype, default=jnp.float32
Data type for numerical precision in computations
fit_distribution : bool, default=True
If True, fits a Dirichlet distribution to the assignment
probabilities
temperature : Optional[float], default=None
If provided, apply temperature scaling to log probabilities
weights : Optional[jnp.ndarray], default=None
Array used to weight genes when computing log likelihoods
weight_type : Optional[str], default=None
How to apply weights. Must be one of:
- 'multiplicative': multiply log probabilities by weights
- 'additive': add weights to log probabilities
verbose : bool, default=True
If True, prints progress messages
Returns
-------
Dict[str, jnp.ndarray]
Dictionary containing:
- 'concentration': Dirichlet concentration parameters for each
cell. Shape: (n_cells, n_components). Only returned if
fit_distribution is True.
- 'mean_probabilities': Mean assignment probabilities for each
cell. Shape: (n_cells, n_components). Only returned if
fit_distribution is True.
- 'sample_probabilities': Assignment probabilities for each
posterior sample. Shape: (n_samples, n_cells, n_components)
Raises
------
ValueError
- If the model is not a mixture model
- If posterior samples have not been generated yet
Note
----
Because the log-likelihood differences between cell types are typically
very large, the resulting assignment probabilities are usually saturated
at 0 or 1. This computation is therefore of limited practical use, but it
is included for completeness.
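Examples
--------
A minimal usage sketch (assuming posterior samples were already drawn
via get_posterior_samples() and `counts` has shape (n_cells, n_genes)):
>>> assignments = results.cell_type_probabilities(
...     counts, fit_distribution=True
... )
>>> print(assignments['mean_probabilities'].shape)  # (n_cells, n_components)
>>> print(assignments['sample_probabilities'].shape)  # (n_samples, n_cells, n_components)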
"""
from ..core.cell_type_assignment import compute_cell_type_probabilities
return compute_cell_type_probabilities(
results=self,
counts=counts,
batch_size=batch_size,
cells_axis=cells_axis,
ignore_nans=ignore_nans,
dtype=dtype,
fit_distribution=fit_distribution,
temperature=temperature,
weights=weights,
weight_type=weight_type,
verbose=verbose,
)
# --------------------------------------------------------------------------
[docs]
def cell_type_probabilities_map(
self,
counts: jnp.ndarray,
batch_size: Optional[int] = None,
cells_axis: int = 0,
dtype: jnp.dtype = jnp.float32,
temperature: Optional[float] = None,
weights: Optional[jnp.ndarray] = None,
weight_type: Optional[str] = None,
use_mean: bool = False,
verbose: bool = True,
) -> Dict[str, jnp.ndarray]:
"""
Compute probabilistic cell type assignments using MAP estimates of
parameters.
For each cell, this method:
1. Computes component-specific log-likelihoods using MAP parameter
estimates
2. Converts these to probability distributions over cell types
Parameters
----------
counts : jnp.ndarray
Count data to evaluate assignments for
batch_size : Optional[int], default=None
Size of mini-batches for likelihood computation
cells_axis : int, default=0
Axis along which cells are arranged. 0 means cells are rows.
dtype : jnp.dtype, default=jnp.float32
Data type for numerical precision in computations
temperature : Optional[float], default=None
If provided, apply temperature scaling to log probabilities
weights : Optional[jnp.ndarray], default=None
Array used to weight genes when computing log likelihoods
weight_type : Optional[str], default=None
How to apply weights. Must be one of:
- 'multiplicative': multiply log probabilities by weights
- 'additive': add weights to log probabilities
use_mean : bool, default=False
If True, replaces undefined MAP values (NaN) with posterior means
verbose : bool, default=True
If True, prints progress messages
Returns
-------
Dict[str, jnp.ndarray]
Dictionary containing:
- 'probabilities': Assignment probabilities for each cell.
Shape: (n_cells, n_components)
Raises
------
ValueError
If the model is not a mixture model
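Examples
--------
A minimal usage sketch (assuming `results` holds a fitted mixture model
and `counts` has shape (n_cells, n_genes)):
>>> assignments = results.cell_type_probabilities_map(counts, use_mean=True)
>>> print(assignments['probabilities'].shape)  # (n_cells, n_components)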
"""
from ..core.cell_type_assignment import (
compute_cell_type_probabilities_map,
)
return compute_cell_type_probabilities_map(
results=self,
counts=counts,
batch_size=batch_size,
cells_axis=cells_axis,
dtype=dtype,
temperature=temperature,
weights=weights,
weight_type=weight_type,
use_mean=use_mean,
verbose=verbose,
)
# --------------------------------------------------------------------------
# Count normalization methods
# --------------------------------------------------------------------------
[docs]
def normalize_counts(
self,
rng_key: random.PRNGKey = random.PRNGKey(42),
n_samples_dirichlet: int = 1,
fit_distribution: bool = False,
store_samples: bool = True,
sample_axis: int = 0,
return_concentrations: bool = False,
backend: str = "numpyro",
verbose: bool = True,
) -> Dict[str, Union[jnp.ndarray, object]]:
"""
Normalize counts using posterior samples of the r parameter.
This method takes posterior samples of the dispersion parameter (r) and
uses them as concentration parameters for Dirichlet distributions to
generate normalized expression profiles. For mixture models,
normalization is performed per component, resulting in an extra
dimension in the output.
Based on the insights from the Dirichlet-multinomial model derivation,
the r parameters represent the concentration parameters of a Dirichlet
distribution that can be used to generate normalized expression
profiles.
The method generates Dirichlet samples using all posterior samples of r,
then fits a single Dirichlet distribution to all these samples (or one
per component for mixture models).
Parameters
----------
rng_key : random.PRNGKey, default=random.PRNGKey(42)
JAX random number generator key
n_samples_dirichlet : int, default=1
Number of samples to draw from each Dirichlet distribution
fit_distribution : bool, default=False
If True, fits a Dirichlet distribution to the generated samples
using fit_dirichlet_minka from stats.py
store_samples : bool, default=True
If True, includes the raw Dirichlet samples in the output
sample_axis : int, default=0
Axis containing samples in the Dirichlet fitting (passed to
fit_dirichlet_minka)
return_concentrations : bool, default=False
If True, returns the original r parameter samples used as
concentrations
backend : str, default="numpyro"
Statistical package to use for distributions when
fit_distribution=True. Must be one of: - "numpyro": Returns
numpyro.distributions.Dirichlet objects - "scipy": Returns
scipy.stats distributions via numpyro_to_scipy conversion
verbose : bool, default=True
If True, prints progress messages
Returns
-------
Dict[str, Union[jnp.ndarray, object]]
Dictionary containing normalized expression profiles. Keys depend on
input arguments:
- 'samples': Raw Dirichlet samples (if store_samples=True)
- 'concentrations': Fitted concentration parameters (if
fit_distribution=True)
- 'mean_probabilities': Mean probabilities from fitted
distribution (if fit_distribution=True)
- 'distributions': Dirichlet distribution objects (if
fit_distribution=True)
- 'original_concentrations': Original r parameter samples (if
return_concentrations=True)
For non-mixture models:
- samples: shape (n_posterior_samples, n_genes,
n_samples_dirichlet) or (n_posterior_samples, n_genes) if
n_samples_dirichlet=1
- concentrations: shape (n_genes,) - single fitted distribution
- mean_probabilities: shape (n_genes,) - single fitted
distribution
- distributions: single Dirichlet distribution object
For mixture models:
- samples: shape (n_posterior_samples, n_components, n_genes,
n_samples_dirichlet) or (n_posterior_samples, n_components,
n_genes) if n_samples_dirichlet=1
- concentrations: shape (n_components, n_genes) - one fitted
distribution per component
- mean_probabilities: shape (n_components, n_genes) - one fitted
distribution per component
- distributions: list of n_components Dirichlet distribution
objects
Raises
------
ValueError
If posterior samples have not been generated yet, or if 'r'
parameter is not found in posterior samples
Examples
--------
>>> # For a non-mixture model
>>> normalized = results.normalize_counts(
... n_samples_dirichlet=100,
... fit_distribution=True
... )
>>> print(normalized['mean_probabilities'].shape) # (n_genes,)
>>> print(type(normalized['distributions'])) # Single Dirichlet distribution
>>> # For a mixture model
>>> normalized = results.normalize_counts(
... n_samples_dirichlet=100,
... fit_distribution=True
... )
>>> print(normalized['mean_probabilities'].shape) # (n_components, n_genes)
>>> print(len(normalized['distributions'])) # n_components
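>>> # For a non-mixture model, the raw Dirichlet samples can also be
>>> # inspected (a sketch following the shape conventions documented
>>> # above; store_samples=True by default)
>>> normalized = results.normalize_counts(n_samples_dirichlet=10)
>>> print(normalized['samples'].shape)  # (n_posterior_samples, n_genes, 10)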
"""
# Check if posterior samples exist
if self.posterior_samples is None:
raise ValueError(
"No posterior samples found. Call get_posterior_samples() first."
)
# Convert to canonical form to ensure r parameter is available
self._convert_to_canonical()
# Use the shared normalization function
return normalize_counts_from_posterior(
posterior_samples=self.posterior_samples,
n_components=self.n_components,
rng_key=rng_key,
n_samples_dirichlet=n_samples_dirichlet,
fit_distribution=fit_distribution,
store_samples=store_samples,
sample_axis=sample_axis,
return_concentrations=return_concentrations,
backend=backend,
verbose=verbose,
)
# --------------------------------------------------------------------------
# Parameter conversion method
# --------------------------------------------------------------------------
def _convert_to_canonical(self):
"""
[DEPRECATED] Convert posterior samples to canonical (p, r) form.
This method is deprecated and will be removed in a future version. The
posterior sampling process now automatically returns both constrained and
unconstrained parameters.
"""
warnings.warn(
"The '_convert_to_canonical' method is deprecated and will be removed. "
"Posterior samples are now automatically converted.",
DeprecationWarning,
stacklevel=2,
)
return self