Inference Methods
SCRIBE supports four inference backends that all share the same scribe.fit()
entry point. Choose the one that best fits your goals and computational budget.
Choosing an Inference Method
| Criterion | SVI | MCMC | VAE | Laplace |
|---|---|---|---|---|
| Speed | Fast (minutes) | Slow (hours) | Moderate (tens of minutes) | Moderate (tens of minutes) |
| Scalability | Excellent (mini-batching) | Limited (full data) | Excellent (mini-batching) | Good (mini-batching) |
| Posterior quality | Approximate | Asymptotically exact | Approximate (neural) | Approximate (Hessian) |
| Latent embeddings | No | No | Yes | No |
| Models supported | All | NB-family | All | PLN, NBLN, LNM, LNMVCP |
| Best for | Exploration and production | Gold-standard uncertainty | Representation learning | Correlation recovery, rigorous PPCs |
Default recommendation
Start with SVI for NB-family models. For PLN/NBLN/LNM/LNMVCP models, use Laplace. It avoids encoder collapse, produces rigorous per-cell posteriors from the Hessian, and does not suffer from aggregate-posterior drift. Switch to MCMC when you need exact posteriors for a publication. Use VAE when you need amortized scoring of new cells or low-dimensional embeddings.
For NBLN specifically, the recommended pipeline is SVI cascade + freeze + loadings shrinkage. See the NBLN cascade + freeze + shrinkage workflow section below.
Stochastic Variational Inference (SVI)
SVI finds the best approximation to the posterior within a chosen variational family using stochastic optimization. It is the default and most commonly used inference method.
Basic usage
```python
import scribe

# Default SVI inference (NBVCP model)
results = scribe.fit(adata)

# With custom parameters
results = scribe.fit(
    adata,
    zero_inflation=True,
    n_steps=100_000,
    batch_size=512,
    seed=0,
)
```
Key parameters
| Parameter | Default | Description |
|---|---|---|
| n_steps | 50,000 | Maximum optimization steps |
| batch_size | None (full batch) | Mini-batch size for stochastic optimization |
| optimizer_config | None | Custom optimizer specification (see below) |
| stable_update | True | Numerically stable parameter updates |
| restore_best | False | Track and restore the best variational parameters during training |
| early_stopping | None | Automatic convergence detection (see below) |
| seed | 42 | Random seed for reproducibility |
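When batch_size is set, each step sees only a subset of cells. The data term of the objective therefore has to be rescaled so that it remains an unbiased estimate of the full-data sum. The sketch below shows the standard N/B rescaling on a toy Poisson model. It illustrates the generic technique and is not SCRIBE's internal code. per_cell_loglik and minibatch_data_term are hypothetical helpers.

```python
import numpy as np

def per_cell_loglik(x, rate):
    # Hypothetical stand-in: Poisson log-likelihood per cell (up to the
    # log(x!) constant), summed over genes
    return (x * np.log(rate) - rate).sum(axis=1)

def minibatch_data_term(counts, rate, batch_size, rng):
    """Unbiased estimate of the full-data sum of per-cell log-likelihoods."""
    n_cells = counts.shape[0]
    idx = rng.choice(n_cells, size=batch_size, replace=False)
    # Scale the batch sum by N / B so its expectation equals the full sum
    return (n_cells / batch_size) * per_cell_loglik(counts[idx], rate).sum()

rng = np.random.default_rng(0)
counts = rng.poisson(3.0, size=(1_000, 20))
rate = np.full(20, 3.0)
full = per_cell_loglik(counts, rate).sum()
estimates = [minibatch_data_term(counts, rate, 100, rng) for _ in range(2_000)]
print(full, np.mean(estimates))  # the two values should be close
```

Each individual estimate is noisy, but their average matches the full-data value. Larger batches trade speed for lower gradient variance.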
Custom optimizer
By default SCRIBE uses Adam. Pass an optimizer_config dict to change the
optimizer or its learning rate.
Supported optimizers: "adam", "clipped_adam", "adagrad", "rmsprop",
"sgd", "momentum".
Best-params restoration
The restore_best flag tracks the lowest smoothed loss during training and
restores those parameters at the end, regardless of whether early stopping
is configured. This is especially useful for normalizing flow guides, where
the ELBO can fluctuate late in training:
```python
results = scribe.fit(
    adata,
    unconstrained=True,
    guide_flow="affine_coupling",
    restore_best=True,
    n_steps=100_000,
)
```
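Conceptually, restore_best keeps a moving average of the loss and snapshots the parameters whenever that average reaches a new minimum. The sketch below is minimal and assumes a trailing moving average over a fixed window; SCRIBE's exact smoothing and bookkeeping may differ. toy_step is a hypothetical stand-in for one optimizer step.

```python
import copy
from collections import deque

def train_with_restore_best(step_fn, params, n_steps, window=50):
    """Run n_steps updates; return the params seen at the lowest smoothed loss."""
    recent = deque(maxlen=window)  # trailing window of raw losses
    best_loss, best_params = float("inf"), copy.deepcopy(params)
    for _ in range(n_steps):
        params, loss = step_fn(params)
        recent.append(loss)
        smoothed = sum(recent) / len(recent)
        # Only compare once the window is full, so early noise is not rewarded
        if len(recent) == window and smoothed < best_loss:
            best_loss, best_params = smoothed, copy.deepcopy(params)
    return best_params, best_loss

def toy_step(params):
    """Hypothetical optimizer step whose loss bottoms out at t = 300, then worsens."""
    t = params["t"] + 1
    return {"t": t}, float((t - 300) ** 2)

best_params, best_loss = train_with_restore_best(toy_step, {"t": 0}, n_steps=1_000)
print(best_params)  # {'t': 324}
```

The moving average trails the raw loss, so the restored step lags the raw minimum (t = 300 here) by about half a window. The smoothing is what prevents a single lucky step from being selected.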
Guide families
The variational guide controls the flexibility of the posterior approximation. SCRIBE supports several families: mean-field (the default), low-rank, joint low-rank, normalizing flow, amortized, and VAE latent. Each offers a different trade-off between speed and the ability to capture correlations:
```python
# Low-rank guide for gene correlations
results = scribe.fit(adata, guide_rank=8)

# Joint low-rank across parameter groups
results = scribe.fit(
    adata, guide_rank=8, joint_params="biological",
)

# Normalizing flow guide for non-Gaussian posteriors
results = scribe.fit(
    adata, unconstrained=True,
    guide_flow="affine_coupling",
)

# Amortized capture for VCP models
results = scribe.fit(adata, variable_capture=True, amortize_capture=True)
```
Full guide: Variational guide families
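guide_rank sets k in a low-rank-plus-diagonal Gaussian. The sketch below is an illustration, not SCRIBE's guide code. It draws from N(mu, W Wᵀ + diag(d)), with W of shape (G, k), without ever forming the G × G covariance, and then checks the empirical covariance. The names mu, W, and d are assumptions chosen to match the notation of the Laplace section.

```python
import numpy as np

def sample_lowrank_gaussian(mu, W, d, n_samples, rng):
    """Draw from N(mu, W W^T + diag(d)) in O(n_samples * G * k) time."""
    G, k = W.shape
    eps_k = rng.standard_normal((n_samples, k))  # shared low-rank factor
    eps_g = rng.standard_normal((n_samples, G))  # independent per-gene noise
    return mu + eps_k @ W.T + eps_g * np.sqrt(d)

rng = np.random.default_rng(0)
G, k = 50, 8
mu = np.zeros(G)
W = 0.3 * rng.standard_normal((G, k))
d = np.full(G, 0.1)

draws = sample_lowrank_gaussian(mu, W, d, 100_000, rng)
emp_cov = np.cov(draws, rowvar=False)
err = np.abs(emp_cov - (W @ W.T + np.diag(d))).max()
print(err)  # small: only Monte Carlo error remains
```

Each draw costs O(Gk) rather than O(G²). This is what makes modest ranks such as guide_rank=8 affordable even when G is in the thousands.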
Early stopping
SVI supports automatic convergence detection to avoid wasting computation:
```python
results = scribe.fit(
    adata,
    n_steps=200_000,
    early_stopping={
        "patience": 500,
        "min_delta": 1.0,
        "smoothing_window": 50,
        "restore_best": True,
    },
)
```
| Early stopping parameter | Default | Description |
|---|---|---|
| patience | 500 | Steps without improvement before stopping |
| min_delta | 1.0 | Minimum loss improvement to count as progress |
| smoothing_window | 50 | Window size for moving-average loss |
| restore_best | True | Restore parameters from the best checkpoint |
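The sketch below shows one plausible way patience, min_delta, and smoothing_window interact. Progress is measured on a moving average of the loss, and the patience counter resets only when that average improves by more than min_delta. This is an assumption about the general pattern, not SCRIBE's exact stopping rule.

```python
from collections import deque

def early_stop_step(losses, patience=500, min_delta=1.0, smoothing_window=50):
    """Return the step at which training would stop, or None if it runs to the end."""
    recent = deque(maxlen=smoothing_window)
    best = float("inf")
    steps_without_improvement = 0
    for step, loss in enumerate(losses):
        recent.append(loss)
        smoothed = sum(recent) / len(recent)
        if smoothed < best - min_delta:
            # Progress: the smoothed loss dropped by more than min_delta
            best = smoothed
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                return step
    return None

# Loss decreases linearly for 1,000 steps, then plateaus at zero
losses = [1_000.0 - step for step in range(1_000)] + [0.0] * 5_000
stop = early_stop_step(losses)
print(stop)  # roughly patience steps after the smoothed loss stops improving
```

Stopping happens about patience steps after the last meaningful improvement. The remaining steps of the 6,000-step budget are never run.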
Results
scribe.fit() returns a ScribeSVIResults object. See the
Results Class page for the full API, including posterior
sampling, predictive checks, denoising, and normalization.
Markov Chain Monte Carlo (MCMC)
MCMC generates samples from the true posterior distribution using the No-U-Turn Sampler (NUTS). It provides the most accurate uncertainty quantification but is slower than SVI.
Basic usage
```python
import scribe

results = scribe.fit(
    adata,
    inference_method="mcmc",
    n_samples=2_000,
    n_warmup=1_000,
    n_chains=4,
)
```
Key parameters
| Parameter | Default | Description |
|---|---|---|
| inference_method | "svi" | Set to "mcmc" for MCMC inference |
| n_samples | 2,000 | Posterior samples per chain |
| n_warmup | 1,000 | Warmup (burn-in) samples |
| n_chains | 1 | Number of parallel chains |
Float64 precision
MCMC defaults to 64-bit floating point for numerical stability during Hamiltonian dynamics. This doubles memory usage compared to SVI but is important for reliable sampling.
Warm-starting from SVI
A common workflow is to run SVI first for exploration, then refine with MCMC using the SVI result as initialization. Starting the chains from the SVI solution typically shortens the warmup needed; the example uses n_warmup=500 rather than the default 1,000:
```python
import scribe

# Step 1: fast SVI exploration
svi_results = scribe.fit(adata, n_steps=50_000)

# Step 2: refine with MCMC, initialized from SVI
mcmc_results = scribe.fit(
    adata,
    inference_method="mcmc",
    svi_init=svi_results,
    n_samples=2_000,
    n_warmup=500,
)
```
The svi_init parameter handles cross-parameterization mapping automatically,
so you can initialize MCMC from an SVI result that used a different
parameterization.
Results
MCMC returns a ScribeMCMCResults object with the same analysis API as SVI
results (posterior sampling, predictive checks, denoising, etc.), plus
MCMC-specific diagnostics:
```python
# NUTS diagnostics
results.print_summary()

# Chain-grouped samples for convergence analysis
chain_samples = results.get_samples(group_by_chain=True)
```
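This sketch assumes get_samples(group_by_chain=True) returns arrays with a leading chain axis, as NumPyro's MCMC.get_samples does. Under that assumption, the classic Gelman–Rubin R-hat for a scalar parameter can be computed from an array of shape (n_chains, n_samples). print_summary() may already report this statistic. The sketch only shows what it measures: whether the chains agree with each other.

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Classic (non-split) Gelman-Rubin R-hat for an array of shape (n_chains, n_samples)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()    # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)         # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n  # pooled posterior variance estimate
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 2_000))           # chains agree
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain offset
print(gelman_rubin_rhat(mixed), gelman_rubin_rhat(stuck))
```

Values close to 1 indicate that the chains are sampling the same distribution. A chain stuck in a different region pushes R-hat well above 1.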
Variational Autoencoder (VAE)
The VAE backend uses neural networks (Flax NNX) for amortized variational inference. It learns a low-dimensional latent representation of each cell while simultaneously fitting the SCRIBE probabilistic model.
Basic usage
```python
import scribe

results = scribe.fit(
    adata,
    inference_method="vae",
    vae_latent_dim=10,
    n_steps=100_000,
    batch_size=256,
)
```
Key parameters
| Parameter | Default | Description |
|---|---|---|
| inference_method | "svi" | Set to "vae" for VAE inference |
| vae_latent_dim | 10 | Dimensionality of the latent space |
| vae_encoder_hidden_dims | None | Encoder hidden layer sizes (e.g., [512, 256]) |
| vae_decoder_hidden_dims | None | Decoder hidden layer sizes |
| vae_activation | None | Activation function ("relu", "gelu", "silu", etc.) |
| vae_input_transform | "log1p" | Input preprocessing ("log1p", "log", "sqrt", "identity") |
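vae_input_transform controls how raw counts are preprocessed before they reach the encoder. The sketch below gives plausible definitions for three of the options. It is an illustration rather than SCRIBE's code. The "log" option is omitted because its handling of zero counts is not documented here.

```python
import numpy as np

# Illustrative definitions only; SCRIBE's exact preprocessing may differ.
INPUT_TRANSFORMS = {
    "log1p": np.log1p,        # log(1 + x): finite at x = 0, compresses large counts
    "sqrt": np.sqrt,          # approximately variance-stabilizing for Poisson-like counts
    "identity": lambda x: x,  # raw counts passed through unchanged
}

def transform_counts(counts, name="log1p"):
    """Apply a named input transform to a (cells x genes) count matrix."""
    if name not in INPUT_TRANSFORMS:
        raise ValueError(f"unsupported transform: {name!r}")
    return INPUT_TRANSFORMS[name](np.asarray(counts, dtype=float))

counts = np.array([[0, 1, 10, 1_000]])
print(transform_counts(counts, "log1p"))
```

The compressing transforms keep a few very highly expressed genes from dominating the encoder's input scale.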
VAE variants
- Standard VAE: a single encoder-decoder pair with a standard normal prior.
- Decoupled Prior VAE (dpVAE): separate priors for different parameter groups, enabling more flexible modeling of parameter relationships.
Normalizing flow priors
For more expressive latent distributions, attach a normalizing flow to the VAE prior:
```python
results = scribe.fit(
    adata,
    inference_method="vae",
    vae_latent_dim=10,
    vae_flow_type="spline_coupling",
    vae_flow_num_layers=4,
    vae_flow_hidden_dims=[64, 64],
)
```
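Coupling flows include spline_coupling (used in the example) and affine_coupling, its simpler baseline. The minimal RealNVP-style affine coupling layer below builds intuition for both. Half of the dimensions pass through unchanged and parameterize a scale and shift for the other half. This keeps the transform exactly invertible with a cheap log-determinant. The small random "conditioner" stands in for a learned network. It is not SCRIBE's implementation.

```python
import numpy as np

class AffineCoupling:
    """y1 = x1;  y2 = x2 * exp(s(x1)) + t(x1)  (RealNVP-style)."""

    def __init__(self, dim, hidden, rng):
        self.d1 = dim // 2
        # Random one-hidden-layer conditioner standing in for a trained network
        self.W1 = 0.5 * rng.standard_normal((self.d1, hidden))
        self.W2 = 0.5 * rng.standard_normal((hidden, 2 * (dim - self.d1)))

    def _scale_shift(self, x1):
        h = np.tanh(x1 @ self.W1)
        s, t = np.split(h @ self.W2, 2, axis=-1)
        return np.tanh(s), t  # bounded log-scale for numerical stability

    def forward(self, x):
        x1, x2 = x[:, :self.d1], x[:, self.d1:]
        s, t = self._scale_shift(x1)
        y = np.concatenate([x1, x2 * np.exp(s) + t], axis=1)
        return y, s.sum(axis=1)  # log|det J| is the sum of the log-scales

    def inverse(self, y):
        y1, y2 = y[:, :self.d1], y[:, self.d1:]
        s, t = self._scale_shift(y1)  # y1 == x1, so s and t are recoverable
        return np.concatenate([y1, (y2 - t) * np.exp(-s)], axis=1)

rng = np.random.default_rng(0)
layer = AffineCoupling(dim=10, hidden=64, rng=rng)
x = rng.standard_normal((5, 10))
y, log_det = layer.forward(x)
print(np.allclose(layer.inverse(y), x))  # True: the layer is exactly invertible
```

Stacking several such layers (with the roles of the two halves alternating) yields an expressive yet tractable density. Spline couplings replace the affine map with a monotone spline for more flexibility.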
Available flow types: "affine_coupling" (fast baseline),
"spline_coupling" (expressive, recommended for production),
"maf" (fast density), "iaf" (fast sampling).
Latent space analysis
VAE results provide cell embeddings that can be used for visualization and clustering:
```python
# Cell embeddings in latent space
embeddings = results.get_latent_embeddings(data=adata.X, n_samples=100)

# Conditional posterior samples
latent_samples = results.get_latent_samples_conditioned_on_data(
    data=adata.X, n_samples=500,
)
```
Laplace Approximation
The Laplace inference path finds each cell's MAP (maximum a posteriori) latent via Newton iteration. It then approximates the per-cell posterior as a Gaussian centered at the MAP, whose covariance is the inverse of the negative Hessian of the log posterior at that point. The outer loop optimizes the global parameters (decoder weights \(\mu\), \(W\), \(d\)) via Adam on the Laplace-approximated ELBO. There is no encoder network; each cell's posterior is computed locally.
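A deliberately tiny model makes the per-cell step concrete: one cell, one gene, a Poisson count, and a Gaussian prior on its log-rate. Newton's method finds the MAP, and the posterior variance is the inverse of the negative Hessian there. This toy model is an illustrative assumption. SCRIBE's per-cell latents are vectors, and its solver exploits low-rank structure.

```python
import math

def laplace_poisson_lognormal(count, prior_mean, prior_var, n_steps=50, tol=1e-10):
    """Toy Laplace step: x ~ N(prior_mean, prior_var), count ~ Poisson(exp(x))."""
    def grad_hess(x):
        # log p(x | count) = count*x - exp(x) - (x - prior_mean)^2 / (2*prior_var) + const
        grad = count - math.exp(x) - (x - prior_mean) / prior_var
        hess = -math.exp(x) - 1.0 / prior_var  # always negative: the log posterior is concave
        return grad, hess

    x = prior_mean
    for _ in range(n_steps):
        grad, hess = grad_hess(x)
        step = -grad / hess  # Newton update toward the MAP
        x += step
        if abs(step) < tol:
            break
    _, hess = grad_hess(x)
    # Laplace approximation: Gaussian at the MAP with variance (-Hessian)^{-1}
    return x, -1.0 / hess

x_map, var = laplace_poisson_lognormal(count=12, prior_mean=0.0, prior_var=4.0)
print(f"MAP log-rate {x_map:.4f}, posterior sd {math.sqrt(var):.4f}")
```

The MAP sits slightly below log(12) because the prior pulls it toward zero. The posterior standard deviation shrinks as the count grows, since the likelihood curvature exp(x) dominates the prior curvature.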
Basic usage
```python
import scribe

# PLN with Laplace inference
results = scribe.fit(
    adata,
    model="pln",
    inference_method="laplace",
    latent_dim=16,
    n_steps=50_000,
    batch_size=256,
)

# NBLN with Laplace inference (cascade-frozen workflow below)
results = scribe.fit(
    adata,
    model="nbln",
    inference_method="laplace",
    latent_dim=16,
    n_steps=50_000,
)

# LNMVCP with Laplace inference
results = scribe.fit(
    adata,
    model="lnmvcp",
    inference_method="laplace",
    latent_dim=16,
    n_steps=50_000,
)
```
latent_dim vs vae_latent_dim
The latent_dim kwarg is the preferred name for the rank of the
low-rank loadings matrix \(\underline{\underline{W}} \in
\mathbb{R}^{G \times k}\). For backward compatibility, the legacy
vae_latent_dim kwarg is still accepted (it was the original
name when Laplace inference didn't yet exist). Passing both
raises ValueError.
Key parameters
| Parameter | Default | Description |
|---|---|---|
| inference_method | "svi" | Set to "laplace" for Laplace inference |
| model | "nbvcp" | Must be "pln", "nbln", "lnm", or "lnmvcp" for Laplace |
| n_steps | 50_000 | Outer optimization steps |
| batch_size | None | Mini-batch size for stochastic gradient estimation |
| latent_dim | None | Rank \(k\) of the loadings matrix \(W\) in the low-rank-plus-diagonal covariance \(\Sigma = WW^\top + \text{diag}(d)\). Legacy alias: vae_latent_dim. |
| laplace_config | None | Dict of Newton solver settings (see below) |
| informative_priors_from | None | Cascade source for NBLN (Phase 1, soft cascade); see the NBLN workflow below |
| informative_priors_freeze | ("r", "eta") | Parameters frozen at the cascade MAPs for NBLN (Phase 2). Accepts either the internal short names ("r", "mu", "eta") or their descriptive aliases ("dispersion", "expression"/"mean_expression", "capture_efficiency"); both forms work identically. |
| priors={"loadings": ...} | None | Loadings-matrix shrinkage strategy spec for PLN/NBLN (Phase 3) |
Laplace configuration
Fine-tune the inner Newton solver via laplace_config:
```python
results = scribe.fit(
    adata,
    model="lnmvcp",
    inference_method="laplace",
    laplace_config={
        "n_newton_steps": 15,      # more iterations for hard cells
        "damping": 1e-3,           # less Tikhonov damping than the 1e-2 default
        "newton_tolerance": 1e-3,  # relaxed from the 1e-4 default for production fits
        "convergence_action": "warn",
    },
)
```
| Config key | Default | Description |
|---|---|---|
| n_newton_steps | 5 | Newton iterations per cell per outer step |
| damping | 1e-2 | Tikhonov regularization added to the Hessian diagonal |
| newton_tolerance | 1e-4 | Gradient-norm threshold for declaring convergence |
| convergence_action | "warn" | Action when cells don't converge: "warn", "raise", or "ignore" |
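The four laplace_config keys map naturally onto a batched damped Newton loop. The sketch below runs such a loop on toy concave objectives, one per "cell". It then applies a convergence_action-style check to cells whose gradient norm is still above newton_tolerance. Several details are assumptions made for illustration: the function name damped_newton, the toy objective, where the damping term is added, and the use of RuntimeError for "raise".

```python
import warnings
import numpy as np

def damped_newton(grad_fn, hess_fn, x0, n_newton_steps=5, damping=1e-2,
                  newton_tolerance=1e-4, convergence_action="warn"):
    """Batched damped Newton ascent; x0 has shape (n_cells, dim), one row per cell."""
    x = np.array(x0, dtype=float)
    eye = np.eye(x.shape[1])
    for _ in range(n_newton_steps):
        g = grad_fn(x)                       # (n_cells, dim)
        neg_h = -hess_fn(x) + damping * eye  # Tikhonov term on the diagonal
        x = x + np.linalg.solve(neg_h, g[..., None])[..., 0]
    grad_norm = np.linalg.norm(grad_fn(x), axis=1)  # per-cell convergence check
    unconverged = grad_norm > newton_tolerance
    if unconverged.any() and convergence_action != "ignore":
        msg = f"{int(unconverged.sum())} cells above newton_tolerance={newton_tolerance:g}"
        if convergence_action == "raise":
            raise RuntimeError(msg)
        warnings.warn(msg)
    return x, grad_norm

# Toy per-cell objective: Poisson counts with a N(0, I) prior on log-rates
rng = np.random.default_rng(0)
counts = rng.poisson(5.0, size=(100, 8)).astype(float)
EYE = np.eye(8)

def toy_grad(x):
    return counts - np.exp(x) - x

def toy_hess(x):
    return -np.exp(x)[..., None] * EYE - EYE

x_map, grad_norm = damped_newton(toy_grad, toy_hess, np.log1p(counts))
print(grad_norm.max())
```

A larger damping shortens each step and makes the loop more conservative. A smaller damping behaves more like pure Newton. More n_newton_steps gives slow cells extra iterations before the convergence check runs.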
How it works
The training loop alternates two steps per outer iteration:
- Inner Newton on per-cell latents (holding globals fixed). For PLN: joint \((\underline{x}, \eta)\) Newton via Schur-complement back-substitution. For LNM/LNMVCP: composition Newton (\(\underline{z}\) or \(\underline{y}_\text{ALR}\)) plus scalar \(\eta\) Newton.
- Outer Adam step on the global parameters \((\mu, W, d)\) using the gradient of the Laplace ELBO, with the MAPs treated as stop_gradient constants.

Each Newton step costs \(O(Gk + k^3)\) per cell using nested Woodbury identities on the low-rank covariance; no \(G \times G\) matrices are ever formed.
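The low cost comes from never materializing Σ. The sketch below uses the Woodbury identity once to solve Σ⁻¹b for Σ = W Wᵀ + diag(d), requiring only a k × k system. SCRIBE nests such identities; this single-level version does not. The dense solve is included only to check the result.

```python
import numpy as np

def woodbury_solve(W, d, b):
    """Solve (W W^T + diag(d)) x = b without forming the G x G matrix.

    Uses (D + W W^T)^{-1} = D^{-1} - D^{-1} W (I_k + W^T D^{-1} W)^{-1} W^T D^{-1}.
    """
    Dinv_b = b / d                                      # D^{-1} b, O(G)
    Dinv_W = W / d[:, None]                             # D^{-1} W, O(G k)
    capacitance = np.eye(W.shape[1]) + W.T @ Dinv_W     # k x k, O(G k^2)
    inner = np.linalg.solve(capacitance, W.T @ Dinv_b)  # k x k solve, O(k^3)
    return Dinv_b - Dinv_W @ inner                      # O(G k)

rng = np.random.default_rng(0)
G, k = 500, 16
W = rng.standard_normal((G, k)) / np.sqrt(k)
d = rng.uniform(0.5, 1.5, size=G)
b = rng.standard_normal(G)

x = woodbury_solve(W, d, b)
dense = np.linalg.solve(W @ W.T + np.diag(d), b)  # O(G^3), for checking only
print(np.allclose(x, dense))  # True
```

In this single-level form, the capacitance matrix depends only on the global parameters W and d. In a batched setting it can therefore be formed once and shared across cells.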
When to use Laplace
| Use Laplace when... | Use SVI/VAE when... |
|---|---|
| You need rigorous per-cell posteriors from the Hessian | You need amortized inference for new cells |
| You suspect the encoder is collapsing on a per-cell latent | The encoder is well-calibrated |
| You want no aggregate-posterior drift | You need fast serving-time scoring |
| Your data has high cell-to-cell variability | The dataset is small enough that encoder collapse isn't a concern |
Progress-bar diagnostics
During training, the progress bar reports per-cell Newton convergence:
```text
LNM Laplace (learned + capture): 21%|██ |
init loss: -8.857e+07,
avg. loss [10001-10500]: -8.896e+07,
comp max/p99/med 1.38e+01/3.42e+00/2.51e-03;
η max/p99/med 1.79e-06/1.61e-06/4.92e-07
```
The comp and η fields report per-cell Newton gradient norms (max, 99th
percentile, median) for the composition and capture blocks, respectively.
- Healthy fit: median well below newton_tolerance, max trending down.
- Problem cells: median small but max large and bouncing. This indicates a few
pathological cells (typically low-count) that are slow to converge but don't
affect the bulk fit.
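The max/p99/med triplets are ordinary order statistics of the per-cell gradient norms. The sketch below shows how such a summary could be computed and formatted. The format imitates the progress bar, and newton_summary is a hypothetical helper. The toy data mimics the "problem cells" pattern: a tiny median with a large max.

```python
import numpy as np

def newton_summary(label, grad_norms):
    """Format max / 99th percentile / median of per-cell Newton gradient norms."""
    g = np.asarray(grad_norms, dtype=float)
    return (f"{label} max/p99/med "
            f"{g.max():.2e}/{np.percentile(g, 99):.2e}/{np.median(g):.2e}")

rng = np.random.default_rng(0)
# Most cells converge tightly; ten hypothetical low-count cells lag far behind
norms = np.concatenate([
    rng.exponential(1e-5, size=9_990),
    rng.uniform(1.0, 10.0, size=10),
])
print(newton_summary("comp", norms))
newton_tolerance = 1e-4
print(np.median(norms) < newton_tolerance, norms.max() > newton_tolerance)  # True True
```

The median and p99 reflect the bulk of cells, while the max is driven entirely by the few stragglers. This is why a large, bouncing max alone is not a reason to distrust the fit.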
Divergence handling
The engine has three layered defenses against explosive divergence in individual cells:
- Sherman–Morrison denominator floor: prevents catastrophic float32 cancellation in the y_alr Newton step.
- Per-cell NaN/Inf mask: divergent cells are masked out of the current step's gradient on the global parameters.
- Outer-loop divergence detector: aborts cleanly, with diagnostic context, if the loss becomes NaN or grows by more than 1000× from its initial value.
If a divergence abort fires, typical remedies are:
- Increase n_newton_steps to 20–30.
- Lower damping to 1e-3 or below.
- Pre-filter outlier cells (very low \(u_T\) or extreme compositional skew).
NBLN cascade + freeze + shrinkage workflow
NBLN has a per-cell rigid-translation gauge (\(C\) degrees of freedom, one per cell) that needs structural pinning to produce a well-identified fit. SCRIBE addresses this with a three-phase pipeline:
```python
import numpy as np
import scribe

# Step 0: fit the SVI cascade source (NBVCP-SVI on the same data)
svi_results = scribe.fit(
    adata, model="nbvcp", parameterization="mean_odds",
    priors={"capture_efficiency": (np.log(100_000), 0.5)},
    inference_method="svi", n_steps=50_000,
)

# Phases 1-3: NBLN-Laplace with soft cascade, cascade freeze, and loadings shrinkage
laplace_results = scribe.fit(
    adata, model="nbln", inference_method="laplace",
    # Phase 1: pass the cascade source. Empirical Gaussian priors on
    # r, mu, eta are derived from the SVI posterior and injected as
    # soft priors in the Laplace loss.
    informative_priors_from=svi_results,
    informative_priors_tau=1.0,
    # Phase 2: freeze r and eta at the cascade MAPs. This pins the per-cell
    # rigid-translation gauge structurally. Frozen params route
    # through cascade_source for PPC and distributions to preserve
    # full SVI guide fidelity.
    informative_priors_freeze=("r", "eta"),  # default
    # Phase 3: loadings shrinkage. Lets latent_dim be generous and
    # picks the effective rank adaptively. See the loadings-shrinkage
    # theory page for the strategy catalog and calibration workflow.
    priors={
        "capture_efficiency": (np.log(100_000), 0.5),
        "loadings": {
            "type": "horseshoe_columnwise",
            "tau_scale": 1.0,
        },
    },
    latent_dim=16,
    n_steps=20_000,
)

# Inspect effective rank + correlation structure
print(laplace_results.w_prior_diagnostics["effective_rank"])  # adaptive rank
diag = laplace_results.get_gauge_diagnostics()
print(diag["gauge_contamination_ratio"])  # should be < 0.05

# Gauge-invariant loadings for cross-gene correlation analysis
W_perp = laplace_results.get_W_compositional()
```
| Phase | Mechanism | Default |
|---|---|---|
| 1. Soft cascade | SVI posterior → empirical Gaussian priors → Laplace loss | Activated by informative_priors_from= |
| 2. Hard freeze | Selected cascade-derived params pinned at MAP, excluded from optimizer | ("r", "eta") |
| 3. Loadings shrinkage | Adaptive rank selection on the columns of \(W_\perp\) | None (opt-in) |
The three phases are orthogonal, so you can use any subset. The combined recipe above is the recommended production workflow for NBLN fits.
Theory: NB Log-Normal Model, Loadings-Matrix Shrinkage Priors
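The per-cell gauge can be illustrated with a toy model in which every gene's log-rate in a cell is a shared offset η plus a per-gene latent x. Raising all of x by a constant and lowering η by the same constant leaves every rate, and hence the likelihood, unchanged. Freezing η (Phase 2) removes that freedom. Centering across genes, by projecting out the all-ones direction, is another way to obtain shift-invariant quantities. The toy log-rate model and the centering projection are assumptions for illustration. This sketch is not claimed to reproduce get_W_compositional() exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
G, k = 6, 2
eta = 0.5                    # toy per-cell offset shared by all genes
x = rng.standard_normal(G)   # toy per-cell latent, one entry per gene

# Gauge move: raise every gene's latent by c and lower the offset by c
c = 1.7
rates = np.exp(eta + x)
rates_shifted = np.exp((eta - c) + (x + c))
print(np.allclose(rates, rates_shifted))  # True: the data cannot determine c

# Centering across genes removes the shared all-ones direction
P = np.eye(G) - np.ones((G, G)) / G
print(np.allclose(P @ x, P @ (x + c)))  # True: the centered latent is shift-invariant

# The same projection applied to loadings gives columns orthogonal to the ones vector
W = rng.standard_normal((G, k))
W_perp = P @ W
print(np.allclose(np.ones(G) @ W_perp, 0.0))  # True
```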
Results
scribe.fit() with inference_method="laplace" returns a
ScribeLaplaceResults object. See the Results Class page for
the full API, including MAP-only and Laplace-uncertainty posterior predictive
checks.
Combining Inference Methods
SVI then MCMC
The most common multi-method workflow is SVI for fast exploration followed by MCMC for publication-quality posteriors:
```mermaid
flowchart LR
    A["SVI (fast)"] -->|"svi_init="| B["MCMC (exact)"]
    A --> C["Explore results"]
    B --> D["Final analysis"]
```
SVI then DE / Model Comparison
SVI results feed directly into downstream analyses:
```mermaid
flowchart LR
    A["scribe.fit()"] --> B["Differential Expression"]
    A --> C["Model Comparison"]
    A --> D["Posterior Predictive Checks"]
    A --> E["Bayesian Denoising"]
```
See the Differential Expression and Model Comparison guides for details on these downstream analyses.