Training an RHVAE on synthetic data

Author: Manuel Razo-Mejia

Affiliation: Department of Biology, Stanford University

Keywords: variational autoencoders, Riemannian geometry, deep learning

(c) This work is licensed under a Creative Commons Attribution License CC-BY 4.0. All code contained herein is licensed under an MIT license.

Riemannian Hamiltonian Variational Autoencoder for Phenotypic Space Reconstruction

This notebook implements a Riemannian Hamiltonian Variational Autoencoder (RHVAE) to analyze fitness profiles from evolutionary simulations, such as those generated in the Evolutionary Dynamics notebook. We’ll train the RHVAE to learn a low-dimensional representation (latent space) of these fitness profiles, with the goal of reconstructing the underlying phenotypic coordinates.

Introduction to RHVAEs

Variational Autoencoders (VAEs) are generative models that learn to encode high-dimensional data into a lower-dimensional latent space and then decode it back. A standard VAE consists of an encoder network that maps input data to a distribution in latent space, and a decoder network that reconstructs the input from samples from this distribution.
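To make the sampling step concrete, here is a minimal sketch of the reparameterization trick commonly used in VAEs. It assumes, as with the JointGaussianLogEncoder used below, that the encoder returns a mean μ and a log standard deviation log σ for each data point. The function name `reparameterize` is hypothetical and not part of AutoEncoderToolkit.jl; the sketch uses only Julia's standard library.

```julia
using Random, Statistics

# Reparameterization trick: z = μ + σ .* ε with ε ~ N(0, I), where σ = exp(logσ).
# Columns are data points, rows are latent dimensions.
function reparameterize(μ::AbstractMatrix, logσ::AbstractMatrix;
                        rng::AbstractRNG=Random.default_rng())
    ε = randn(rng, eltype(μ), size(μ))
    return μ .+ exp.(logσ) .* ε
end

# Toy example: 10_000 samples for one data point in a 2-D latent space
rng = Random.MersenneTwister(1)
μ = repeat([1.0f0, -2.0f0], 1, 10_000)
logσ = repeat(log.([0.5f0, 0.1f0]), 1, 10_000)
z = reparameterize(μ, logσ; rng=rng)

# Empirical mean and standard deviation per latent dimension
zmean = vec(mean(z; dims=2))
zstd = vec(std(z; dims=2))
```

Writing the sample as a deterministic function of μ, log σ, and independent noise ε is what lets gradients flow back into the encoder parameters during training.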

The Riemannian Hamiltonian VAE (RHVAE) extends this framework by modeling the latent space as a Riemannian manifold, learning not only the encoding/decoding functions but also the geometric structure of the latent space. This is achieved by:

  1. Learning a position-dependent metric tensor that captures how distances in latent space relate to distances in data space.
  2. Using Hamiltonian dynamics to move in latent space in a geometry-aware manner.
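The Hamiltonian dynamics in point 2 are integrated numerically with a leapfrog scheme, controlled by a step size ϵ and a number of steps K (the ϵ and K hyperparameters defined below). The sketch below is a simplified, Euclidean version with an identity mass matrix and a toy Gaussian potential. The RHVAE itself uses the position-dependent metric as the mass matrix, which requires a generalized (implicit) leapfrog integrator that we do not reproduce here. All function names are hypothetical, and the step size is larger than this notebook's ϵ = 10⁻³ so that the motion is visible.

```julia
using LinearAlgebra

# Euclidean leapfrog integrator for H(z, ρ) = U(z) + ½ ρᵀρ
function leapfrog(z, ρ, ∇U, ϵ, K)
    z, ρ = copy(z), copy(ρ)
    for _ in 1:K
        ρ = ρ .- (ϵ / 2) .* ∇U(z)   # half step in momentum
        z = z .+ ϵ .* ρ              # full step in position
        ρ = ρ .- (ϵ / 2) .* ∇U(z)   # half step in momentum
    end
    return z, ρ
end

# Toy potential U(z) = ½‖z‖², i.e. a standard Gaussian target
U(z) = 0.5 * dot(z, z)
∇U(z) = z
H(z, ρ) = U(z) + 0.5 * dot(ρ, ρ)

z0, ρ0 = [1.0, 0.0], [0.0, 1.0]
zK, ρK = leapfrog(z0, ρ0, ∇U, 0.1, 10)

# Change in the Hamiltonian after K steps
ΔH = abs(H(zK, ρK) - H(z0, ρ0))
```

Leapfrog approximately conserves the Hamiltonian, so ΔH stays small even after several steps; this is what makes it suitable for moving samples through latent space.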

The RHVAE is particularly useful for our application because:

  • It captures nonlinear relationships between fitness profiles and underlying phenotypes.
  • It provides meaningful distances in latent space via the learned metric tensor.
  • It offers a principled way to interpolate between points in latent space.

Setup Environment

First, let’s import the necessary packages for our implementation.

# Import project package
import Antibiotic
import Antibiotic.mh as mh # Metropolis-Hastings dynamics module

# Import AutoEncoderToolkit to train VAEs
import AutoEncoderToolkit as AET

# Import packages for manipulating results
import DimensionalData as DD
import DataFrames as DF
import Glob

# Import ML libraries
import Flux

# Import CUDA (if available) to train using GPU
# import CUDA

# Import library to save models
import JLD2

# Import libraries for data handling
import IterTools
import StatsBase
import Random

# Load Plotting packages
using CairoMakie
import ColorSchemes
import Colors

# Activate backend
CairoMakie.activate!()

# Set plotting style
Antibiotic.viz.theme_makie!()

# Set random seed for reproducibility
Random.seed!(42)

Defining Directories

We’ll set up the directories where we’ll load data from and save our model and results to.

# Defining directories...

# Define output directory
out_dir = "./sim_metropolis_dynamics"

# Define model state directory
state_dir = "$(out_dir)/rhvae_model_state"

# Create the model state directory if it does not exist yet
mkpath(state_dir)

Model Architecture and Hyperparameters

The RHVAE architecture consists of three main components:

  1. Encoder: Maps fitness profiles to a distribution in latent space.
  2. Decoder: Reconstructs fitness profiles from latent space coordinates.
  3. Metric Network: Learns the Riemannian metric tensor of the latent space.

Let’s start by defining the hyperparameters of our model:

# Defining hyperparameters...

# Define dimensionality of latent space
n_latent = 2
# Define number of neurons in hidden layers
n_neuron = 128

# Define RHVAE hyper-parameters
T = 0.8f0 # Temperature
λ = 1.0f-2 # Regularization parameter
n_centroids = 256 # Number of centroids

# Define loss function hyper-parameters
ϵ = Float32(1E-3) # Leapfrog step size
K = 10 # Number of leapfrog steps
βₒ = 0.3f0 # Initial inverse temperature for tempering

# Define RHVAE hyper-parameters in a named tuple
rhvae_kwargs = (K=K, ϵ=ϵ, βₒ=βₒ)

# Define training hyperparameters
n_epoch = 75 # Number of epochs
n_batch = 512 # Batch size
n_batch_loss = 512 # Batch size for loss computation
η = 1.0f-3 # Learning rate
split_frac = 0.85 # Train/val split

# Define ELBO prefactors
logp_prefactor = [10.0f0, 0.1f0, 0.1f0]
logq_prefactor = [0.1f0, 0.1f0, 0.1f0]

# Define loss function kwargs in a NamedTuple
loss_kwargs = (
    K=K,
    ϵ=ϵ,
    βₒ=βₒ,
    logp_prefactor=logp_prefactor,
    logq_prefactor=logq_prefactor,
)

# Define by how much to subsample the time series
n_sub = 10
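The βₒ hyperparameter sets the starting point of a tempering schedule for the Hamiltonian flow. As a hedged illustration, the sketch below implements the quadratic tempering scheme of the Hamiltonian VAE of Caterini, Doucet, and Sejdinovic (2018), in which the inverse temperature \(\beta_k\) grows from \(\beta_0\) to 1 over the K leapfrog steps via \(1/\sqrt{\beta_k} = (1 - 1/\sqrt{\beta_0})(k/K)^2 + 1/\sqrt{\beta_0}\). We assume, without having checked the source, that AutoEncoderToolkit.jl uses a schedule of this form; `quadratic_tempering_sketch` is a hypothetical name.

```julia
# Quadratic tempering: inverse temperature β_k rises from β₀ (k = 0) to 1 (k = K)
function quadratic_tempering_sketch(β₀, k, K)
    inv_sqrt_β = (1 - 1 / sqrt(β₀)) * (k / K)^2 + 1 / sqrt(β₀)
    return 1 / inv_sqrt_β^2
end

β₀, K = 0.3, 10
βs = [quadratic_tempering_sketch(β₀, k, K) for k in 0:K]

# After each leapfrog step the momentum is rescaled by √(β_{k-1} / β_k)
ρ_scales = [sqrt(βs[k] / βs[k + 1]) for k in 1:K]
```

Because every scale factor is below 1, the momentum is progressively damped as the inverse temperature approaches 1.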

Loading and Preprocessing Data

Now let’s load the fitness profiles from our evolutionary simulations and preprocess them for training. Since the simulations are very densely sampled over time, we will train the RHVAE on a subsampled time series, keeping every n_sub-th time point.

# Loading data into memory...

# Load fitnotype profiles
fitnotype_profiles = JLD2.load("$(out_dir)/sim_evo.jld2")["fitnotype_profiles"]

# Extract initial and final time points
t_init, t_final = collect(DD.dims(fitnotype_profiles, :time)[[1, end]])
# Subsample time series
fitnotype_profiles = fitnotype_profiles[time=DD.At(t_init:n_sub:t_final)]

# Define number of environments
n_env = length(DD.dims(fitnotype_profiles, :landscape))

# Extract fitness data, bringing the landscape (environment) dimension to the
# first dimension
fit_data = permutedims(fitnotype_profiles.fitness.data, (5, 1, 2, 3, 4, 6))
# Reshape the array to a Matrix (environments × profiles)
fit_data = reshape(fit_data, size(fit_data, 1), :)

# Log-transform the fitness values
fit_mat = log.(fit_data)

# Fit model to standardize data to mean zero and standard deviation 1 on each
# environment 
dt = StatsBase.fit(StatsBase.ZScoreTransform, fit_mat, dims=2)

# Standardize the data to have mean 0 and standard deviation 1
fit_std = StatsBase.transform(dt, fit_mat)

# Split indexes of data into training and validation
train_idx, val_idx = Flux.splitobs(
    1:size(fit_std, 2), at=split_frac, shuffle=true
)

# Extract train and validation data
train_data = fit_std[:, train_idx]
val_data = fit_std[:, val_idx]

Our data preprocessing involves several key steps:

  1. Log transformation: We apply a logarithmic transformation to the fitness values to make their distribution more amenable to the model.
  2. Standardization: We standardize the log-fitness values per environment to have mean 0 and standard deviation 1. This is common practice when training neural networks, as it helps with numerical stability and convergence.
  3. Train-validation split: We split the data into training (85%) and validation (15%) sets. The validation set is used to monitor the model’s performance during training.
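The standardization step can be written out explicitly. The sketch below reproduces, with the standard library only, the per-row operation that `StatsBase.ZScoreTransform` fit with `dims=2` performs on an environments × profiles matrix. It also shows why the fitted statistics are kept: new data (here a hypothetical profile `x_new`) must be transformed with the previously fitted mean and standard deviation, as the notebook does later when standardizing the full fitness array. The helper `zscore_rows` and the toy matrix are illustrative assumptions.

```julia
using Statistics

# Standardize each row (environment) of a features × samples matrix to mean 0
# and standard deviation 1; return the statistics so they can be reused.
function zscore_rows(X::AbstractMatrix)
    μ = mean(X; dims=2)
    σ = std(X; dims=2)   # corrected (n - 1) estimator
    return (X .- μ) ./ σ, μ, σ
end

# Toy log-fitness matrix: 3 environments × 5 genotypes
X = log.([1.0 2.0 4.0 8.0 16.0;
          0.5 0.5 1.0 1.0 2.0;
          3.0 1.0 2.0 5.0 4.0])
Xstd, μ, σ = zscore_rows(X)

# Standardize a new profile with the fitted statistics, not its own
x_new = log.([2.0, 1.0, 3.0])
x_new_std = (x_new .- vec(μ)) ./ vec(σ)
```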

Defining the RHVAE Model Architecture

We now define the RHVAE model architecture with its three core components: encoder, decoder, and metric network. For this, we use the AutoEncoderToolkit.jl package (Razo-Mejia 2024).

The first step is to select the centroids used to “anchor” the metric tensor learned by the RHVAE. We do this by using the centroids_kmedoids function from the AutoEncoderToolkit.jl package. These centroids will be part of the definition of the RHVAE model itself.

# Selecting centroids via k-medoids...

# Select centroids via k-medoids
centroids_data = AET.utils.centroids_kmedoids(fit_std, n_centroids)
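Unlike k-means centroids, k-medoids are actual data points, so every centroid is a real (standardized) fitness profile. We do not reproduce the internals of `AET.utils.centroids_kmedoids`; as a hedged illustration of the idea, the sketch below implements a simple Voronoi-iteration k-medoids with the standard library. `kmedoids_indices` is a hypothetical helper, and its dense pairwise-distance matrix makes it suitable only for small datasets.

```julia
using LinearAlgebra, Random

# Voronoi-iteration k-medoids on a features × samples matrix X.
# Returns the column indices of the k medoids (each medoid is a data point).
function kmedoids_indices(X::AbstractMatrix, k::Integer;
                          maxiter::Integer=100, rng=Random.default_rng())
    n = size(X, 2)
    # Dense pairwise Euclidean distances (O(n²) memory: small n only)
    D = [norm(X[:, i] - X[:, j]) for i in 1:n, j in 1:n]
    # Start from k distinct random samples
    medoids = randperm(rng, n)[1:k]
    for _ in 1:maxiter
        # Assign each sample to its nearest medoid (index into `medoids`)
        assign = [argmin(D[medoids, j]) for j in 1:n]
        new_medoids = copy(medoids)
        for c in 1:k
            members = findall(==(c), assign)
            isempty(members) && continue   # keep old medoid for an empty cluster
            # New medoid: member with the smallest total distance to its cluster
            costs = [sum(D[m, members]) for m in members]
            new_medoids[c] = members[argmin(costs)]
        end
        new_medoids == medoids && break    # converged
        medoids = new_medoids
    end
    return medoids
end

# Toy data: three well-separated 2-D blobs of 30 points each
rng = Random.MersenneTwister(0)
X = hcat(randn(rng, 2, 30),
         randn(rng, 2, 30) .+ [8.0, 0.0],
         randn(rng, 2, 30) .+ [0.0, 8.0])
idx = kmedoids_indices(X, 3; rng=rng)
centroids = X[:, idx]
```

`X[:, idx]` then holds the k selected profiles, playing the same role as `centroids_data` above.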

Now, we can define the RHVAE architecture. We start by defining the encoder, which maps the fitness profiles to a distribution in latent space. Since AutoEncoderToolkit.jl is based on Flux.jl, we will build all of the components of the RHVAE using Flux.jl components.

# Define JointGaussianLogEncoder...

# Define encoder chain
encoder_chain = Flux.Chain(
    # First layer
    Flux.Dense(n_env => n_neuron, Flux.identity),
    # Second layer
    Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
    # Third layer
    Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
    # Fourth layer
    Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
)

# Define layers for µ and log(σ)
µ_layer = Flux.Dense(n_neuron => n_latent, Flux.identity)
logσ_layer = Flux.Dense(n_neuron => n_latent, Flux.identity)

# build encoder
encoder = AET.JointGaussianLogEncoder(encoder_chain, µ_layer, logσ_layer)

Now, we define the decoder, which maps the latent space coordinates to the fitness profiles.

# Define SimpleGaussianDecoder...

# Initialize decoder
decoder = AET.SimpleGaussianDecoder(
    Flux.Chain(
        # First layer
        Flux.Dense(n_latent => n_neuron, Flux.identity),
        # Second Layer
        Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
        # Third layer
        Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
        # Fourth layer
        Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
        # Output layer
        Flux.Dense(n_neuron => n_env, Flux.identity)
    )
)

Finally, we define the metric network, which learns the Riemannian metric tensor of the latent space.

# Define MetricChain...

# Define mlp chain
mlp_chain = Flux.Chain(
    # First layer
    Flux.Dense(n_env => n_neuron, Flux.identity),
    # Second layer
    Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
    # Third layer
    Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
    # Fourth layer
    Flux.Dense(n_neuron => n_neuron, Flux.leakyrelu),
)

# Define layers for the diagonal and strictly lower triangular entries of the
# lower-triangular factor used to build the metric tensor
diag = Flux.Dense(n_neuron => n_latent, Flux.identity)
lower = Flux.Dense(
    n_neuron => n_latent * (n_latent - 1) ÷ 2, Flux.identity
)

# Build metric chain
metric_chain = AET.RHVAEs.MetricChain(mlp_chain, diag, lower)

For the final step, we bring all the components together to define the RHVAE model itself.

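# Initialize rhvae
rhvae = AET.RHVAEs.RHVAE(
    encoder * decoder,
    metric_chain,
    centroids_data,
    T,
    λ
)

# Save the (untrained) model architecture; trained parameters are stored
# separately as per-epoch checkpoints and loaded into this model later
JLD2.jldsave("$(out_dir)/rhvae_model.jld2", model=rhvae)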

The RHVAE architecture comprises:

  1. Encoder:
    • Takes fitness profiles (dimension = number of environments) as input
    • Processes through 4 hidden layers with 128 neurons each
    • Outputs the mean (μ) and log standard deviation (log σ) of a diagonal Gaussian distribution in latent space
  2. Decoder:
    • Takes latent space coordinates (dimension = 2) as input
    • Processes through 4 hidden layers with 128 neurons each
    • Outputs reconstructed fitness profiles
  3. Metric Network:
    • Learns the parameters of the Riemannian metric tensor
    • Takes fitness profiles (dimension = number of environments) as input
    • Outputs the diagonal and strictly lower triangular entries of a lower-triangular matrix from which the metric tensor is built

We also select centroids using k-medoids clustering to approximate the data manifold. Because the metric tensor is interpolated from these centroids rather than from every training point, this keeps its computation tractable.
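The role of the centroids, T, and λ becomes clearer from the form of the inverse metric. To our understanding of the RHVAE of Chadebec et al., which AutoEncoderToolkit.jl implements, the inverse metric is a kernel-weighted sum over the latent centroids \(\underline{c}_i\): \(\underline{\underline{G}}^{-1}(\underline{z}) = \sum_i \underline{\underline{L}}_i \underline{\underline{L}}_i^T \exp(-\lVert \underline{z} - \underline{c}_i \rVert^2 / T^2) + \lambda \underline{\underline{I}}\), where each lower-triangular \(\underline{\underline{L}}_i\) comes from the metric network. The sketch below evaluates this formula for toy centroids and factors; `G_inv_sketch` is a hypothetical stand-in, not the library's `AET.RHVAEs.G_inv`.

```julia
using LinearAlgebra

# Inverse metric at latent point z as a kernel-weighted sum over centroids:
#   G⁻¹(z) = Σᵢ Lᵢ Lᵢᵀ exp(-‖z - cᵢ‖² / T²) + λ I
# cs: d × N matrix of latent centroids; Ls: N lower-triangular d × d factors.
function G_inv_sketch(z::AbstractVector, cs::AbstractMatrix,
                      Ls::AbstractVector{<:AbstractMatrix}, T::Real, λ::Real)
    d = length(z)
    Ginv = λ * Matrix{Float64}(I, d, d)           # regularization term
    for i in axes(cs, 2)
        w = exp(-sum(abs2, z .- cs[:, i]) / T^2)  # Gaussian kernel weight
        Ginv .+= w .* (Ls[i] * Ls[i]')            # centroid contribution
    end
    return Ginv
end

# Toy setup: two centroids in a 2-D latent space, with this notebook's T and λ
T, λ = 0.8, 1e-2
cs = [0.0 1.0;
      0.0 1.0]
Ls = [[2.0 0.0; 0.5 1.5], [1.0 0.0; -0.3 2.0]]

Ginv_near = G_inv_sketch([0.0, 0.0], cs, Ls, T, λ)    # on top of a centroid
Ginv_far = G_inv_sketch([10.0, 10.0], cs, Ls, T, λ)   # far from all centroids
```

Far from every centroid the kernel weights vanish and \(\underline{\underline{G}}^{-1} \to \lambda \underline{\underline{I}}\), so \(\underline{\underline{G}} \to \underline{\underline{I}}/\lambda\) and the metric volume approaches \((1/\lambda)^{d/2}\); with λ = 0.01 and d = 2, that is 100. This is why the metric volume is large away from the data.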

Training the RHVAE

Now, let’s train our RHVAE model on the fitness profiles:

println("Checking previous model states...")

# List previous model parameters
model_states = sort(Glob.glob("beta-rhvae_epoch*.jld2", state_dir))

# Check if model states exist
if length(model_states) > 0
    # Load model state
    model_state = JLD2.load(model_states[end])["model_state"]
    # Input parameters to model
    Flux.loadmodel!(rhvae, model_state)
    # Update metric parameters
    AET.RHVAEs.update_metric!(rhvae)
    # Extract epoch number
    epoch_init = parse(
        Int, match(r"epoch(\d+)", model_states[end]).captures[1]
    ) + 1
else
    epoch_init = 1
end # if

println("Initial epoch: $epoch_init")

println("Uploading model to GPU...")

# Check if CUDA has been imported and a functional GPU is available
if isdefined(Main, :CUDA) && CUDA.functional()
    # Upload model to GPU
    rhvae = Flux.gpu(rhvae)
    # Upload data to GPU
    train_data = Flux.gpu(train_data)
    val_data = Flux.gpu(val_data)
end

# Explicit setup of optimizer
opt_rhvae = Flux.Train.setup(
    Flux.Optimisers.Adam(η),
    rhvae
)

println("\nTraining RHVAE...\n")

# Loop through number of epochs
for epoch in epoch_init:n_epoch
    # Shuffle data indexes
    idx_shuffle = Random.shuffle(1:size(train_data, 2))
    # Split indexes into batches
    idx_batches = IterTools.partition(idx_shuffle, n_batch)
    # Loop through batches
    for (i, idx_tuple) in enumerate(idx_batches)
        println("Epoch: $(epoch) | Batch: $(i) / $(length(idx_batches))")
        # Extract indexes
        idx_batch = collect(idx_tuple)
        # Train RHVAE on the current batch
        loss_batch = AET.RHVAEs.train!(
            rhvae, train_data[:, idx_batch], opt_rhvae;
            loss_kwargs=loss_kwargs, verbose=false, loss_return=true
        )
        println("Loss: $(loss_batch)")
    end # for idx_batches

    # Sample train data
    train_sample = train_data[
        :,
        StatsBase.sample(1:size(train_data, 2), n_batch_loss, replace=false)
    ]
    # Sample val data
    val_sample = val_data

    println("Computing loss in training and validation data...")
    loss_train = AET.RHVAEs.loss(rhvae, train_sample; loss_kwargs...)
    loss_val = AET.RHVAEs.loss(rhvae, val_sample; loss_kwargs...)

    # Forward pass sample through model
    println("Computing MSE in training and validation data...")
    out_train = rhvae(train_sample; rhvae_kwargs...).μ
    mse_train = Flux.mse(train_sample, out_train)
    out_val = rhvae(val_sample; rhvae_kwargs...).μ
    mse_val = Flux.mse(val_sample, out_val)

    println(
        "\n Epoch: $(epoch) / $(n_epoch)\n " *
        "   - loss_train: $(loss_train)\n" *
        "   - loss_val: $(loss_val)\n" *
        "   - mse_train: $(mse_train)\n" *
        "   - mse_val: $(mse_val)\n"
    )

    # Save checkpoint
    JLD2.jldsave(
        "$(state_dir)/beta-rhvae_epoch$(lpad(epoch, 5, "0")).jld2",
        model_state=Flux.state(rhvae) |> Flux.cpu,
        loss_train=loss_train,
        loss_val=loss_val,
        mse_train=mse_train,
        mse_val=mse_val,
        train_idx=train_idx,
        val_idx=val_idx,
    )
end # for n_epoch

The training process involves:

  1. Batched Training: We process the data in batches of size 512.
  2. Loss Computation: We use a custom loss function for RHVAEs that combines:
    • Reconstruction loss (how well the model reconstructs the input)
    • KL divergence (a regularization term)
    • Hamiltonian loss (related to the metric tensor)
  3. Optimization: We use the Adam optimizer with a learning rate of 0.001.
  4. Checkpointing: We save the model state after each epoch, allowing us to resume training if needed.
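The batching and checkpointing steps above can be sketched with the standard library alone. The helpers below are hypothetical: `minibatch_indices` shuffles and partitions sample indices with Base's `Iterators.partition`, which keeps a final, shorter batch (as we understand it, the `IterTools.partition` used in the training loop yields only full-size batches, so the few leftover samples of each epoch are skipped). `checkpoint_name` and `next_epoch` mirror the zero-padded file naming and the regular expression used to resume training.

```julia
using Random

# Shuffle the indices 1:n and split them into mini-batches of size batch_size;
# the last batch holds the remainder.
minibatch_indices(n, batch_size; rng=Random.default_rng()) =
    collect(Iterators.partition(randperm(rng, n), batch_size))

# Zero-padded checkpoint file name for a given epoch
checkpoint_name(epoch) = "beta-rhvae_epoch$(lpad(epoch, 5, "0")).jld2"

# Epoch at which to resume, parsed from the latest checkpoint's name
next_epoch(fname) = parse(Int, match(r"epoch(\d+)", fname).captures[1]) + 1

batches = minibatch_indices(1_300, 512; rng=Random.MersenneTwister(3))
files = sort([checkpoint_name(e) for e in (10, 2, 100)])
resume_at = next_epoch(files[end])
```

Zero-padding with `lpad` makes lexicographic order agree with numeric order, which is why sorting the file names is enough to find the latest checkpoint.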

The loss includes several components weighted by the logp_prefactor and logq_prefactor parameters, which balance reconstruction quality against regularization.

Analyzing Training Results

After training, we can analyze the training results to understand how well our model is learning. First, let’s load the loss and mean squared error (MSE) training curves into a dataframe.

# Loading trained model...

# Find model file
model_file = first(Glob.glob("$(out_dir)/rhvae_model*.jld2"))
# List epoch parameters
model_states = Glob.glob("$(state_dir)/*.jld2")

# Initialize dataframe to store files metadata
df_meta = DF.DataFrame()

# Loop over files
for f in model_states
    # Extract epoch number from file name using regular expression
    epoch = parse(Int, match(r"epoch(\d+)", f).captures[1])
    # Load model_state file
    f_load = JLD2.load(f)
    # Extract values
    loss_train = f_load["loss_train"]
    loss_val = f_load["loss_val"]
    mse_train = f_load["mse_train"]
    mse_val = f_load["mse_val"]
    # Generate temporary dataframe to store metadata
    df_tmp = DF.DataFrame(
        :epoch => epoch,
        :loss_train => loss_train,
        :loss_val => loss_val,
        :mse_train => mse_train,
        :mse_val => mse_val,
        :model_file => model_file,
        :model_state => f,
    )
    # Append temporary dataframe to main dataframe
    global df_meta = DF.vcat(df_meta, df_tmp)
end # for f in model_states

# Sort by epoch so that the last row corresponds to the latest checkpoint
sort!(df_meta, :epoch)

Now, we can plot the training loss and mean squared error (MSE) training curves.

# Plotting training loss...

# Initialize figure
fig = Figure(size=(600, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="epoch",
    ylabel="loss",
)

# Plot training loss
lines!(
    ax,
    df_meta.epoch,
    df_meta.loss_train,
    label="train",
)
# Plot validation loss
lines!(
    ax,
    df_meta.epoch,
    df_meta.loss_val,
    label="validation",
)

# Add legend
axislegend(ax, position=:rt)

# Add axis
ax = Axis(
    fig[1, 2],
    xlabel="epoch",
    ylabel="mean squared error",
)

# Plot training MSE
lines!(
    ax,
    df_meta.epoch,
    df_meta.mse_train,
    label="train",
)
# Plot validation MSE
lines!(
    ax,
    df_meta.epoch,
    df_meta.mse_val,
    label="validation",
)

# Add legend
axislegend(ax, position=:rt)

fig

This code loads all the training checkpoints and plots the training curves, showing:

  1. The total loss over time for both the training and validation sets.
  2. The mean squared error (MSE) over time for both sets.

These plots help us assess whether the model is learning effectively and whether it is overfitting or underfitting the data. Since the training and validation curves closely track each other, there is no sign that the model is overfitting the training data.

Mapping Data to Latent Space

Now that we have a trained model, we can use it to map our fitness profiles to the latent space. Again, we will take advantage of the DimArray type to preserve the metadata of the fitness profiles when mapping them to the latent space.

# Mapping data to latent space...

# Standardize the data to have mean 0 and standard deviation 1
log_fitnotype_std = DD.DimArray(
    mapslices(slice -> StatsBase.transform(dt, slice),
        log.(fitnotype_profiles.fitness.data),
        dims=[5]),
    fitnotype_profiles.fitness.dims,
)

# Load model
rhvae = JLD2.load(model_file)["model"]
# Load latest model state
Flux.loadmodel!(rhvae, JLD2.load(df_meta.model_state[end])["model_state"])
# Update metric parameters
AET.RHVAEs.update_metric!(rhvae)

# Define latent space dimensions
latent = DD.Dim{:latent}([:latent1, :latent2])

# Map data to latent space
dd_latent = DD.DimArray(
    dropdims(
        mapslices(slice -> rhvae.vae.encoder(slice).μ,
            log_fitnotype_std.data,
            dims=[5]);
        dims=1
    ),
    (log_fitnotype_std.dims[2:4]..., latent, log_fitnotype_std.dims[6]),
)

This process involves:

  1. Applying the same preprocessing steps (log transformation and standardization with the fitted ZScoreTransform) to the full fitness array.
  2. Loading our trained RHVAE model and its latest checkpoint.
  3. Using the encoder part of the RHVAE to map each fitness profile to the mean of its encoding in the 2D latent space.
  4. Storing the results in a dimensional array that preserves the metadata (lineage, time, etc.).
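The reshaping in this step is easy to get wrong, so here is a toy version of the same `mapslices`/`dropdims` pattern. It assumes, as the code above implies, that dimension 5 holds the environments and that dimension 1 is a singleton; the array sizes and the linear map `W` standing in for the encoder mean are hypothetical.

```julia
# Toy array with the same layout: singleton dim 1, environments along dim 5
n_env, n_latent = 5, 2
X = randn(1, 3, 4, 7, n_env, 6)

# Hypothetical stand-in for the encoder mean: a fixed linear map n_env → n_latent
W = randn(n_latent, n_env)
encode_mean(x) = W * x

# mapslices applies encode_mean to every length-n_env fiber along dim 5 and
# replaces that dimension by the length-n_latent output: size (1, 3, 4, 7, 2, 6)
Z = mapslices(encode_mean, X; dims=5)
# dropdims removes the singleton first dimension: size (3, 4, 7, 2, 6), so the
# latent coordinates land in dimension 4, matching (dims[2:4]..., latent, dims[6])
Z = dropdims(Z; dims=1)
```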

We now have a low-dimensional representation of our fitness profiles that we can visualize and analyze.

Visualizing the Latent Space

Let’s visualize the latent space to see how our model has organized the fitness profiles:

# Plotting latent space coordinates...

# Initialize figure
fig = Figure(size=(300, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="latent dimension 1",
    ylabel="latent dimension 2",
    aspect=AxisAspect(1)
)

# Plot latent space
scatter!(
    ax,
    vec(dd_latent[latent=DD.At(:latent1)]),
    vec(dd_latent[latent=DD.At(:latent2)]),
    markersize=5,
)

fig

This basic plot shows all points in the latent space. Each point represents the encoding of a fitness profile. Next, let’s colorize the points by lineage to gain more insight:

# Plotting latent space coordinates colored by lineage...

# Initialize figure
fig = Figure(size=(300, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="latent dimension 1",
    ylabel="latent dimension 2",
    aspect=AxisAspect(1)
)

# Define color palette (cycled if there are more lineages than colors)
colors = ColorSchemes.seaborn_colorblind

# Loop over lineages
for (i, lin) in enumerate(DD.dims(dd_latent, :lineage))
    # Plot latent space
    scatter!(
        ax,
        vec(dd_latent[latent=DD.At(:latent1), lineage=lin]),
        vec(dd_latent[latent=DD.At(:latent2), lineage=lin]),
        markersize=5,
        color=(colors[mod1(i, length(colors))], 0.25),
    )
end # for lin

fig

This plot shows the latent space colored by lineage. Each lineage is assigned a distinct color, allowing us to see how different lineages are distributed in the latent space.

Visualizing the Latent Space Curvature

One of the main advantages of the RHVAE is that it learns a Riemannian metric on the latent space, which allows us to relate distances in the latent space with distances in the original data space, despite the nonlinear transformations.

The metric tensor, \(\underline{\underline{G}}(\underline{z})\), is a position-dependent, positive-definite matrix that we can evaluate at any point \(\underline{z}\) in the latent space. For a 2D latent space, this is a 2×2 matrix. One way to visualize the local curvature of the latent space is to compute the so-called metric volume, defined as the square root of the determinant of the metric tensor, \(\sqrt{\det(\underline{\underline{G}}(\underline{z}))}\). AutoEncoderToolkit.jl provides a function to compute the inverse metric tensor, \(\underline{\underline{G}}^{-1}(\underline{z})\), from the model. Since \(\det(\underline{\underline{G}}) = 1/\det(\underline{\underline{G}}^{-1})\), we can use it directly to compute the log metric volume, \(\log\sqrt{\det(\underline{\underline{G}}(\underline{z}))} = -\frac{1}{2}\log\det(\underline{\underline{G}}^{-1}(\underline{z}))\), at each point in the latent space.
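As a small check of the log-volume computation and of the grid layout used below, this sketch evaluates \(-\frac{1}{2}\log\det(\underline{\underline{G}}^{-1})\) on a grid for a hypothetical, analytic stand-in for the inverse metric (`toy_G_inv`, not the trained model). Working with the log determinant avoids inverting \(\underline{\underline{G}}^{-1}\), and reshaping the column-ordered results into an `n_x × n_y` matrix puts entry `[i, j]` at `(xs[i], ys[j])`, the layout Makie's `heatmap!(xs, ys, M)` expects.

```julia
using LinearAlgebra

# Hypothetical inverse metric: large near the origin, decaying to λI far away
λ = 1e-2
toy_G_inv(z) = exp(-sum(abs2, z)) * [2.0 0.3; 0.3 1.0] + λ * I

# Grid of latent points; the comprehension runs over x fastest (column-major)
xs = range(-3, 3; length=50)
ys = range(-2, 2; length=40)
z_mat = reduce(hcat, [[x, y] for x in xs, y in ys])   # 2 × 2000

# log√det G(z) = -½ log det G⁻¹(z), evaluated column by column
logvol = [-0.5 * logdet(toy_G_inv(z)) for z in eachcol(z_mat)]

# logvol_grid[i, j] corresponds to the point (xs[i], ys[j])
logvol_grid = reshape(logvol, length(xs), length(ys))
```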

Let’s define the ranges of the latent space dimensions and evaluate the metric tensor at each point in the latent space.

# Computing metric tensor...

# Define number of points per axis
n_points = 250

# Extract latent space ranges
latent1_range = range(
    minimum(dd_latent[latent=DD.At(:latent1)]) - 1.5,
    maximum(dd_latent[latent=DD.At(:latent1)]) + 1.5,
    length=n_points
)
latent2_range = range(
    minimum(dd_latent[latent=DD.At(:latent2)]) - 1.5,
    maximum(dd_latent[latent=DD.At(:latent2)]) + 1.5,
    length=n_points
)
# Define latent points to evaluate
z_mat = reduce(hcat, [[x, y] for x in latent1_range, y in latent2_range])

# Compute inverse metric tensor
Ginv = AET.RHVAEs.G_inv(z_mat, rhvae)

# Compute the log metric volume, log√det(G) = -½ log det(G⁻¹), on the grid
logdetG = reshape(
    -1 / 2 * AET.utils.slogdet(Ginv), n_points, n_points
)

Now, let’s plot the log metric volume both as a heatmap and as a surface, overlaying the encoded data points on the surface.

# Plotting latent space metric...

# Initialize figure
fig = Figure(size=(700, 300))

# Add axis
ax1 = Axis(
    fig[1, 1],
    xlabel="latent dimension 1",
    ylabel="latent dimension 2",
    aspect=AxisAspect(1)
)
ax2 = Axis(
    fig[1, 2],
    xlabel="latent dimension 1",
    ylabel="latent dimension 2",
    aspect=AxisAspect(1)
)

# Plot heatmap of the log metric volume
hm = heatmap!(
    ax1,
    latent1_range,
    latent2_range,
    logdetG,
    colormap=Reverse(to_colormap(ColorSchemes.PuBu)),
)

surface!(
    ax2,
    latent1_range,
    latent2_range,
    logdetG,
    colormap=Reverse(to_colormap(ColorSchemes.PuBu)),
    shading=NoShading,
    rasterize=true,
)

# Plot latent space
scatter!(
    ax2,
    vec(dd_latent[latent=DD.At(:latent1)]),
    vec(dd_latent[latent=DD.At(:latent2)]),
    markersize=4,
    color=(:white, 0.3),
    rasterize=true,
)

# Find axis limits from minimum and maximum of latent points
xlims!.(
    [ax2, ax1],
    minimum(dd_latent[latent=DD.At(:latent1)]) - 1.5,
    maximum(dd_latent[latent=DD.At(:latent1)]) + 1.5
)
ylims!.(
    [ax2, ax1],
    minimum(dd_latent[latent=DD.At(:latent2)]) - 1.5,
    maximum(dd_latent[latent=DD.At(:latent2)]) + 1.5
)

# Add colorbar
Colorbar(fig[1, 3], hm, label="log√det(G̲̲)")

fig

This figure shows the log metric volume as a heatmap (left) and as a surface with the encoded data overlaid (right). The colorbar shows the log metric volume, which we use as a measure of the local curvature of the latent space: darker regions have a smaller metric volume, i.e., the latent space is flatter there. One neat feature of the RHVAE is that the learned metric tensor “cages” the data cloud, clearly delineating the regions in which the model can make meaningful predictions (neural networks generally extrapolate poorly outside the range of their training data). We also see that, within the data cloud, regions of high curvature coincide with the four regions of low genotype-phenotype fitness we originally defined in the Evolutionary Dynamics notebook, suggesting that the model captures the underlying phenotypic space.

Conclusion

In this notebook, we’ve trained a Riemannian Hamiltonian Variational Autoencoder (RHVAE) on synthetic data generated from evolutionary simulations. We’ve used the RHVAE to learn a low-dimensional representation of the fitness profiles and visualized the latent space to gain insights into the model’s organization of the fitness profiles.

References

Razo-Mejia, Manuel. 2024. “AutoEncoderToolkit.jl: A Julia Package for Training (Variational) Autoencoders.” Journal of Open Source Software, July. https://doi.org/10.21105/joss.06794.