sampling
Sampling utilities for SCRIBE.
- scribe.sampling.sample_variational_posterior(guide, params, model, model_args, rng_key=random.PRNGKey(42), n_samples=100, return_sites=None)
Sample parameters from the variational posterior distribution.
- Parameters:
guide (Callable) – Guide function
params (Dict) – Dictionary containing optimized variational parameters
model (Callable) – Model function
model_args (Dict) – Dictionary containing model arguments. For standard models, these are the number of cells and genes; for mixture models, they also include the number of components.
rng_key (random.PRNGKey) – JAX random number generator key (default: random.PRNGKey(42))
n_samples (int, optional) – Number of posterior samples to generate (default: 100)
return_sites (Optional[Union[str, List[str]]], optional) – Sites to return from the model. If None, returns all sites.
- Returns:
Dictionary containing samples from the variational posterior
- Return type:
Dict
- scribe.sampling.generate_predictive_samples(model, posterior_samples, model_args, rng_key, batch_size=None)
Generate predictive samples using posterior parameter samples.
- Parameters:
model (Callable) – Model function
posterior_samples (Dict) – Dictionary containing samples from the variational posterior
model_args (Dict) – Dictionary containing model arguments. For standard models, these are the number of cells and genes; for mixture models, they also include the number of components.
rng_key (random.PRNGKey) – JAX random number generator key
batch_size (int, optional) – Batch size for generating samples. If None, the full dataset is used.
- Returns:
Array of predictive samples
- Return type:
jnp.ndarray
- scribe.sampling.generate_ppc_samples(model, guide, params, model_args, rng_key, n_samples=100, batch_size=None)
Generate posterior predictive check samples.
- Parameters:
model (Callable) – Model function
guide (Callable) – Guide function
params (Dict) – Dictionary containing optimized variational parameters
model_args (Dict) – Dictionary containing model arguments. For standard models, these are the number of cells and genes; for mixture models, they also include the number of components.
rng_key (random.PRNGKey) – JAX random number generator key
n_samples (int, optional) – Number of posterior samples to generate (default: 100)
batch_size (int, optional) – Batch size for generating samples. If None, the full dataset is used.
- Returns:
Dictionary containing:
  - 'parameter_samples': samples from the variational posterior
  - 'predictive_samples': samples from the predictive distribution
- Return type:
Dict
- scribe.sampling.generate_prior_predictive_samples(model, model_args, rng_key, n_samples=100, batch_size=None)
Generate prior predictive samples using the model.
- Parameters:
model (Callable) – Model function
model_args (Dict) – Dictionary containing model arguments. For standard models, these are the number of cells and genes; for mixture models, they also include the number of components.
rng_key (random.PRNGKey) – JAX random number generator key
n_samples (int, optional) – Number of prior predictive samples to generate (default: 100)
batch_size (int, optional) – Batch size for generating samples. If None, the full dataset is used.
- Returns:
Array of prior predictive samples
- Return type:
jnp.ndarray