sampling

Sampling utilities for SCRIBE.

scribe.sampling.sample_variational_posterior(guide, params, model, model_args, rng_key=Array([0, 42], dtype=uint32), n_samples=100, return_sites=None)

Sample parameters from the variational posterior distribution.

Parameters:
  • guide (Callable) – Guide function

  • params (Dict) – Dictionary containing optimized variational parameters

  • model (Callable) – Model function

  • model_args (Dict) – Dictionary of model arguments. Standard models need the number of cells and genes; mixture models also need the number of components.

  • rng_key (random.PRNGKey, optional) – JAX random number generator key (default: random.PRNGKey(42), shown in the signature as Array([0, 42], dtype=uint32))

  • n_samples (int, optional) – Number of posterior samples to generate (default: 100)

  • return_sites (Optional[Union[str, List[str]]], optional) – Sites to return from the model. If None, returns all sites.

Returns:

Dictionary containing samples from the variational posterior

Return type:

Dict
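As a rough illustration of what drawing from a variational posterior involves, the sketch below samples from a mean-field normal guide in plain Python. It does not call SCRIBE or NumPyro. The helper names normalize_return_sites and sample_mean_field, the (loc, scale) layout of params, the site names "p" and "r", and the integer seed standing in for rng_key are all assumptions made for this example.

```python
import random
from typing import Dict, List, Optional, Tuple, Union


def normalize_return_sites(
    return_sites: Optional[Union[str, List[str]]],
) -> Optional[List[str]]:
    """Turn a single site name into a one-element list; pass None through."""
    if return_sites is None:
        return None
    if isinstance(return_sites, str):
        return [return_sites]
    return list(return_sites)


def sample_mean_field(
    params: Dict[str, Tuple[float, float]],
    n_samples: int = 100,
    seed: int = 42,
    return_sites: Optional[Union[str, List[str]]] = None,
) -> Dict[str, List[float]]:
    """Draw n_samples independent normal samples for each requested site.

    params maps a (hypothetical) site name to its (loc, scale) pair.
    """
    rng = random.Random(seed)
    sites = normalize_return_sites(return_sites)
    # None means "return every site", mirroring the return_sites description.
    names = list(params) if sites is None else sites
    return {
        name: [rng.gauss(*params[name]) for _ in range(n_samples)]
        for name in names
    }


# Hypothetical optimized parameters for two sites.
toy_params = {"p": (0.0, 1.0), "r": (2.0, 0.5)}
all_sites = sample_mean_field(toy_params, n_samples=5)
only_r = sample_mean_field(toy_params, n_samples=5, return_sites="r")
```

Here all_sites holds five draws for each of the two sites, while only_r holds draws for the single requested site. A real guide would typically define richer distributions than independent normals.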

scribe.sampling.generate_predictive_samples(model, posterior_samples, model_args, rng_key, batch_size=None)

Generate predictive samples using posterior parameter samples.

Parameters:
  • model (Callable) – Model function

  • posterior_samples (Dict) – Dictionary containing samples from the variational posterior

  • model_args (Dict) – Dictionary of model arguments. Standard models need the number of cells and genes; mixture models also need the number of components.

  • rng_key (random.PRNGKey) – JAX random number generator key

  • batch_size (int, optional) – Batch size for generating samples. If None, uses full dataset.

Returns:

Array of predictive samples

Return type:

jnp.ndarray
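The batch_size behaviour can be pictured as splitting the cell axis into consecutive chunks, with None meaning a single chunk that spans the full dataset. This is a minimal sketch assuming batching runs over cells, which this reference does not state explicitly; batch_slices is a hypothetical helper and not part of SCRIBE.

```python
from typing import Iterator, Optional, Tuple


def batch_slices(
    n_cells: int, batch_size: Optional[int] = None
) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) index pairs that cover range(n_cells) in order."""
    if batch_size is None:
        # No batch size given: process every cell in one pass.
        batch_size = max(n_cells, 1)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, n_cells, batch_size):
        # The last batch may be shorter than batch_size.
        yield start, min(start + batch_size, n_cells)


batched = list(batch_slices(10, batch_size=4))
full = list(batch_slices(10))
```

With ten cells and a batch size of four, the final batch covers only the two remaining cells. Smaller batches trade speed for lower peak memory when predictive arrays are large.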

scribe.sampling.generate_ppc_samples(model, guide, params, model_args, rng_key, n_samples=100, batch_size=None)

Generate posterior predictive check samples.

Parameters:
  • model (Callable) – Model function

  • guide (Callable) – Guide function

  • params (Dict) – Dictionary containing optimized variational parameters

  • model_args (Dict) – Dictionary of model arguments. Standard models need the number of cells and genes; mixture models also need the number of components.

  • rng_key (random.PRNGKey) – JAX random number generator key

  • n_samples (int, optional) – Number of posterior samples to generate (default: 100)

  • batch_size (int, optional) – Batch size for generating samples. If None, uses full dataset.

Returns:

Dictionary containing:

  • 'parameter_samples' – Samples from the variational posterior

  • 'predictive_samples' – Samples from the predictive distribution

Return type:

Dict
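A posterior predictive check combines two stages: draw parameters from the fitted variational posterior, then simulate data for each draw. The toy sketch below shows that two-stage structure and the shape of the returned dictionary in plain Python. The Beta/Binomial model, the "alpha" and "beta" parameter names, the n_trials argument, and the function toy_ppc are all assumptions for illustration and do not reflect SCRIBE's actual models or internals.

```python
import random
from typing import Dict


def toy_ppc(
    params: Dict[str, float],
    n_trials: int,
    n_samples: int = 100,
    seed: int = 0,
) -> Dict[str, object]:
    """Two-stage sampling for a toy Beta/Binomial model."""
    rng = random.Random(seed)
    # Stage 1: sample the success probability from a Beta distribution
    # that plays the role of the variational posterior.
    p_samples = [
        rng.betavariate(params["alpha"], params["beta"])
        for _ in range(n_samples)
    ]
    # Stage 2: for each parameter draw, simulate one Binomial(n_trials, p)
    # count as a sum of Bernoulli trials.
    counts = [
        sum(rng.random() < p for _ in range(n_trials)) for p in p_samples
    ]
    return {
        "parameter_samples": {"p": p_samples},
        "predictive_samples": counts,
    }


ppc = toy_ppc({"alpha": 2.0, "beta": 5.0}, n_trials=20, n_samples=10)
```

Comparing the simulated counts in ppc["predictive_samples"] against observed data is what makes this a check: systematic disagreement suggests the model fits poorly.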

scribe.sampling.generate_prior_predictive_samples(model, model_args, rng_key, n_samples=100, batch_size=None)

Generate prior predictive samples using the model.

Parameters:
  • model (Callable) – Model function

  • model_args (Dict) – Dictionary of model arguments. Standard models need the number of cells and genes; mixture models also need the number of components.

  • rng_key (random.PRNGKey) – JAX random number generator key

  • n_samples (int, optional) – Number of prior predictive samples to generate (default: 100)

  • batch_size (int, optional) – Batch size for generating samples. If None, uses full dataset.

Returns:

Array of prior predictive samples

Return type:

jnp.ndarray
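Prior predictive sampling differs from the posterior case only in where the parameters come from: they are drawn from the prior, so no guide or fitted params are needed. The plain-Python sketch below illustrates this with a toy model. The model_args keys "n_cells" and "n_genes", the Beta(1, 1) prior, the fixed ten trials per count, and the samples-by-cells-by-genes output layout are assumptions for this example; the reference does not specify key names or output shape.

```python
import random
from typing import Dict, List


def toy_prior_predictive(
    model_args: Dict[str, int],
    n_samples: int = 100,
    seed: int = 0,
) -> List[List[List[int]]]:
    """Simulate count matrices from a toy model using prior draws only."""
    rng = random.Random(seed)
    n_cells = model_args["n_cells"]
    n_genes = model_args["n_genes"]
    samples = []
    for _ in range(n_samples):
        # One prior draw per gene; no data or fitted parameters are involved.
        p = [rng.betavariate(1.0, 1.0) for _ in range(n_genes)]
        # Simulate a cells-by-genes matrix of Binomial(10, p) counts.
        counts = [
            [sum(rng.random() < p[g] for _ in range(10)) for g in range(n_genes)]
            for _ in range(n_cells)
        ]
        samples.append(counts)
    return samples


# Hypothetical model_args for a standard (non-mixture) model.
prior_draws = toy_prior_predictive({"n_cells": 3, "n_genes": 2}, n_samples=4)
```

Inspecting draws like these before fitting is a common way to confirm that the prior produces counts on a plausible scale.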