utils
Utility modules for SCRIBE.
This package contains various utility classes and functions used throughout the SCRIBE codebase for parameter collection, data processing, and other common operations.
- class scribe.utils.ParameterCollector[source]
Bases: object
Utility class for collecting and mapping optional parameters.
This class provides static methods to collect non-None parameters and map them to the appropriate ModelConfig attribute names based on the model parameterization and constraint settings.
Examples
>>> # Collect only non-None parameters
>>> params = ParameterCollector.collect_non_none(
...     r_prior=(1.0, 1.0),
...     p_prior=None,
...     gate_prior=(2.0, 0.5)
... )
>>> print(params)
{'r_prior': (1.0, 1.0), 'gate_prior': (2.0, 0.5)}
>>> # Collect and map prior parameters for standard parameterization
>>> prior_config = ParameterCollector.collect_and_map_priors(
...     unconstrained=False,
...     parameterization="standard",
...     r_prior=(1.0, 1.0),
...     p_prior=(2.0, 0.5)
... )
>>> print(prior_config)
{'r_param_prior': (1.0, 1.0), 'p_param_prior': (2.0, 0.5)}
>>> # Collect VAE parameters
>>> vae_config = ParameterCollector.collect_vae_params(
...     vae_latent_dim=5,
...     vae_hidden_dims=[256, 128],
...     vae_activation="gelu"
... )
>>> print(vae_config)
{'vae_latent_dim': 5, 'vae_hidden_dims': [256, 128], 'vae_activation': 'gelu'}
- static collect_non_none(**kwargs)[source]
Return only non-None keyword arguments.
This is a simple utility to filter out None values from a dictionary of keyword arguments, useful for collecting only the parameters that were explicitly provided by the user.
- Parameters:
**kwargs – Keyword arguments to filter
- Returns:
Dictionary containing only non-None values
- Return type:
Dict[str, Any]
Examples
>>> params = ParameterCollector.collect_non_none(
...     a=1, b=None, c="hello", d=None
... )
>>> print(params)
{'a': 1, 'c': 'hello'}
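The filtering above amounts to a one-line dictionary comprehension. The following is a minimal sketch of the pattern (the helper name is illustrative; SCRIBE's actual implementation may differ):

```python
from typing import Any, Dict

def collect_non_none_sketch(**kwargs: Any) -> Dict[str, Any]:
    """Return only the keyword arguments whose value is not None."""
    return {key: value for key, value in kwargs.items() if value is not None}

# collect_non_none_sketch(a=1, b=None, c="hello")  -> {'a': 1, 'c': 'hello'}
```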
- static collect_and_map_priors(unconstrained, parameterization, r_prior=None, p_prior=None, gate_prior=None, p_capture_prior=None, mixing_prior=None, mu_prior=None, phi_prior=None, phi_capture_prior=None)[source]
Collect and map prior parameters to ModelConfig attribute names.
This method collects all non-None prior parameters and maps them to the correct ModelConfig attribute names based on whether the model uses unconstrained parameterization and the specific parameterization type (standard, linked, odds_ratio).
- Parameters:
unconstrained (bool) – Whether the model uses unconstrained parameterization
parameterization (str) – Model parameterization type (“standard”, “linked”, “odds_ratio”)
r_prior (Optional[tuple], default=None) – Prior parameters for dispersion parameter (r)
p_prior (Optional[tuple], default=None) – Prior parameters for success probability (p)
gate_prior (Optional[tuple], default=None) – Prior parameters for zero-inflation gate
p_capture_prior (Optional[tuple], default=None) – Prior parameters for variable capture probability
mixing_prior (Optional[Any], default=None) – Prior parameters for mixture components
mu_prior (Optional[tuple], default=None) – Prior parameters for mean parameter (used in linked/odds_ratio)
phi_prior (Optional[tuple], default=None) – Prior parameters for odds ratio parameter (used in odds_ratio)
phi_capture_prior (Optional[tuple], default=None) – Prior parameters for variable capture odds ratio
- Returns:
Dictionary mapping ModelConfig attribute names to prior values
- Return type:
Dict[str, Any]
Examples
>>> # Standard parameterization (constrained)
>>> config = ParameterCollector.collect_and_map_priors(
...     unconstrained=False,
...     parameterization="standard",
...     r_prior=(1.0, 1.0),
...     p_prior=(2.0, 0.5)
... )
>>> print(config)
{'r_param_prior': (1.0, 1.0), 'p_param_prior': (2.0, 0.5)}
>>> # Unconstrained standard parameterization
>>> config = ParameterCollector.collect_and_map_priors(
...     unconstrained=True,
...     parameterization="standard",
...     r_prior=(0.0, 1.0),
...     p_prior=(0.0, 1.0)
... )
>>> print(config)
{'r_unconstrained_prior': (0.0, 1.0), 'p_unconstrained_prior': (0.0, 1.0)}
>>> # Linked parameterization with mu parameter
>>> config = ParameterCollector.collect_and_map_priors(
...     unconstrained=False,
...     parameterization="linked",
...     p_prior=(1.0, 1.0),
...     mu_prior=(0.0, 1.0)
... )
>>> print(config)
{'p_param_prior': (1.0, 1.0), 'mu_param_prior': (0.0, 1.0)}
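The examples above suggest a simple naming rule for the mapping. The sketch below encodes that rule as inferred from the examples; the helper name is hypothetical and this is not SCRIBE's actual implementation, which also handles the parameterization-specific parameters:

```python
def map_prior_attr_name(base: str, unconstrained: bool) -> str:
    """Map a short prior name (e.g. "r", "p", "mu") to a ModelConfig
    attribute name, following the pattern visible in the examples above:
    constrained priors become "<base>_param_prior", unconstrained priors
    become "<base>_unconstrained_prior"."""
    suffix = "_unconstrained_prior" if unconstrained else "_param_prior"
    return base + suffix

# map_prior_attr_name("r", False)  -> "r_param_prior"
# map_prior_attr_name("p", True)   -> "p_unconstrained_prior"
```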
- static collect_vae_params(vae_latent_dim=3, vae_hidden_dims=None, vae_activation=None, vae_input_transformation=None, vae_vcp_hidden_dims=None, vae_vcp_activation=None, vae_prior_type='standard', vae_prior_num_layers=None, vae_prior_hidden_dims=None, vae_prior_activation=None, vae_prior_mask_type='alternating', vae_standardize=False)[source]
Collect VAE-specific parameters for ModelConfig.
This method collects all VAE-related parameters and returns them as a dictionary ready to be merged into ModelConfig kwargs. Only non-None values are included in the returned dictionary.
- Parameters:
vae_latent_dim (int, default=3) – Dimension of the VAE latent space
vae_hidden_dims (Optional[List[int]], default=None) – List of hidden layer dimensions for the VAE encoder/decoder
vae_activation (Optional[str], default=None) – Activation function name for VAE layers
vae_input_transformation (Optional[str], default=None) – Input transformation for VAE
vae_vcp_hidden_dims (Optional[List[int]], default=None) – Hidden layer dimensions for VCP encoder (variable capture models)
vae_vcp_activation (Optional[str], default=None) – Activation function for VCP encoder
vae_prior_type (str, default="standard") – Type of VAE prior (“standard” or “decoupled”)
vae_prior_num_layers (Optional[int], default=None) – Number of coupling layers for decoupled prior
vae_prior_hidden_dims (Optional[List[int]], default=None) – Hidden layer dimensions for decoupled prior coupling layers
vae_prior_activation (Optional[str], default=None) – Activation function for decoupled prior coupling layers
vae_prior_mask_type (str, default="alternating") – Mask type for decoupled prior coupling layers
vae_standardize (bool, default=False) – Whether to standardize input data for VAE models
- Returns:
Dictionary of VAE parameters ready for ModelConfig
- Return type:
Dict[str, Any]
Examples
>>> # Basic VAE configuration
>>> vae_config = ParameterCollector.collect_vae_params(
...     vae_latent_dim=5,
...     vae_hidden_dims=[256, 128],
...     vae_activation="gelu"
... )
>>> print(vae_config)
{'vae_latent_dim': 5, 'vae_hidden_dims': [256, 128], 'vae_activation': 'gelu'}
>>> # VAE with decoupled prior
>>> vae_config = ParameterCollector.collect_vae_params(
...     vae_latent_dim=3,
...     vae_prior_type="decoupled",
...     vae_prior_num_layers=3,
...     vae_prior_hidden_dims=[64, 64]
... )
>>> print(vae_config)
{'vae_latent_dim': 3, 'vae_prior_type': 'decoupled', 'vae_prior_num_layers': 3, 'vae_prior_hidden_dims': [64, 64]}
- scribe.utils.numpyro_to_scipy(distribution)[source]
Get the corresponding scipy.stats distribution for a numpyro.distributions.Distribution.
- Parameters:
distribution (numpyro.distributions.Distribution) – The numpyro distribution to convert
- Returns:
The corresponding scipy.stats distribution
- Return type:
scipy.stats distribution
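This function has no example above. The sketch below illustrates the kind of name-based dispatch such a converter might use; the mapping table and helper name are assumptions for illustration, not SCRIBE's actual implementation, which may cover more distributions and also translate parameter conventions (e.g. rate vs. scale) between the two libraries.

```python
import scipy.stats

# Illustrative mapping from numpyro distribution class names to their
# scipy.stats counterparts (assumed subset; the real table may differ).
_NUMPYRO_TO_SCIPY = {
    "Normal": scipy.stats.norm,
    "Gamma": scipy.stats.gamma,
    "Beta": scipy.stats.beta,
    "Poisson": scipy.stats.poisson,
}

def numpyro_to_scipy_sketch(distribution):
    """Look up the scipy.stats family matching a numpyro distribution
    by its class name."""
    name = type(distribution).__name__
    try:
        return _NUMPYRO_TO_SCIPY[name]
    except KeyError:
        raise NotImplementedError(f"No scipy.stats equivalent for {name}")
```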
- scribe.utils.use_cpu()[source]
Context manager to temporarily force JAX computations to run on CPU.
This is useful when you want to ensure specific computations run on the CPU rather than a GPU/TPU, for example when GPU memory is exhausted.
- Returns:
Yields control back to the context block
- Return type:
None
Example
>>> # Force posterior sampling to run on CPU
>>> with use_cpu():
...     results.get_ppc_samples(n_samples=100)
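Such a context manager can be built on jax.default_device, which temporarily changes the device JAX places computations on. A minimal sketch, assuming a JAX installation with at least one CPU device (the real scribe.utils.use_cpu may differ):

```python
from contextlib import contextmanager

@contextmanager
def use_cpu_sketch():
    """Run the enclosed block with JAX's default device set to CPU."""
    import jax  # imported lazily so the module loads without JAX

    # Select the first CPU device and make it the default for the block;
    # the previous default is restored automatically on exit.
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        yield
```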