Hamiltonian Variational Autoencoder

The Hamiltonian Variational Autoencoder (HVAE) is a variant of the variational autoencoder (VAE) that uses Hamiltonian dynamics to improve sampling of the latent space representation. The HVAE combines ideas from Hamiltonian Monte Carlo, annealed importance sampling, and variational inference to obtain a richer latent space representation than a standard VAE.

For the implementation of the HVAE in AutoEncoderToolkit.jl, the HVAE struct inherits directly from the VAE struct and adds the necessary functions to compute the Hamiltonian dynamics steps as part of the training protocol. An HVAE object is created by simply passing a VAE object to the constructor. This way, we can use Julia's multiple dispatch to extend the functionality of the VAE object without having to redefine the entire structure.
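As a minimal sketch, assuming a vae::VAE has already been built elsewhere from any AbstractVariationalEncoder/AbstractVariationalDecoder pair, wrapping it takes a single line:

using AutoEncoderToolkit

# `vae` is assumed to be an existing VAE built elsewhere. Wrapping it in an HVAE
# adds the Hamiltonian dynamics machinery on top of the same encoder/decoder.
hvae = AutoEncoderToolkit.HVAEs.HVAE(vae)

# The original VAE is stored in the `vae` field, so all VAE functionality
# remains accessible through `hvae.vae`.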

Warning

HVAEs require the computation of nested gradients. This means that the AutoDiff framework must differentiate a function of an already AutoDiff-differentiated function. This is known to be problematic for Julia's AutoDiff backends. See the details below to understand how we circumvent this problem.

Reference

Caterini, A. L., Doucet, A. & Sejdinovic, D. Hamiltonian Variational Auto-Encoder. 11 (2018).

HVAE struct

AutoEncoderToolkit.HVAEs.HVAE - Type
struct HVAE{
    V<:VAE{<:AbstractVariationalEncoder,<:AbstractVariationalDecoder}
} <: AbstractVariationalAutoEncoder

Hamiltonian Variational Autoencoder (HVAE) model defined for Flux.jl.

Fields

  • vae::V: A Variational Autoencoder (VAE) model that forms the basis of the HVAE. V is a subtype of VAE with a specific AbstractVariationalEncoder and AbstractVariationalDecoder.

An HVAE is a type of Variational Autoencoder (VAE) that uses Hamiltonian Monte Carlo (HMC) to sample from the posterior distribution in the latent space. The VAE's encoder compresses the input into a low-dimensional probabilistic representation q(z|x). The VAE's decoder tries to reconstruct the original input from a sampled point in the latent space p(x|z).

The HMC sampling in the latent space allows the HVAE to better capture complex posterior distributions compared to a standard VAE, which assumes a simple Gaussian posterior. This can lead to more accurate reconstructions and better disentanglement of latent variables.

Forward pass

AutoEncoderToolkit.HVAEs.HVAE - Method
(hvae::HVAE{VAE{E,D}})(
    x::AbstractArray;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
            reconstruction_loglikelihood=reconstruction_loglikelihood,
            latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
    latent::Bool=false,
) where {E<:AbstractGaussianLogEncoder,D<:AbstractVariationalDecoder}

Run the Hamiltonian Variational Autoencoder (HVAE) on the given input.

Arguments

  • x::AbstractArray: The input to the HVAE. If Vector, it represents a single data point. If Array, the last dimension must contain each of the data points.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}=0.0001: The step size for the leapfrog steps in the HMC part of the HVAE. If it is a scalar, the same step size is used for all dimensions. If it is an array, each element corresponds to the step size for a specific dimension.
  • K::Int=3: The number of leapfrog steps to perform in the Hamiltonian Monte Carlo (HMC) part of the HVAE.
  • βₒ::Number=0.3f0: The initial inverse temperature for the tempering schedule.
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function. Default is a NamedTuple with reconstruction_loglikelihood and latent_logprior.
  • tempering_schedule::Function=quadratic_tempering: The function to compute the tempering schedule in the HVAE.
  • latent::Bool=false: If true, the function returns a NamedTuple containing the outputs of the encoder and decoder, and the final state of the phase space after the leapfrog and tempering steps. If false, the function only returns the output of the decoder.

Returns

If latent=true, the function returns a NamedTuple with the following fields:

  • encoder: The outputs of the encoder.
  • decoder: The output of the decoder.
  • phase_space: The final state of the phase space after the leapfrog and tempering steps.

If latent=false, the function only returns the output of the decoder.

Description

This function runs the HVAE on the given input. It first passes the input through the encoder to obtain the mean and log standard deviation of the latent space. It then uses the reparameterization trick to sample from the latent space. After that, it performs the leapfrog and tempering steps to refine the sample from the latent space. Finally, it passes the refined sample through the decoder to obtain the output.

Notes

Ensure that the dimensions of x match the input dimensions of the HVAE, and that the dimensions of ϵ match the dimensions of the latent space.
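The following is a brief usage sketch; the data array x, its dimensions, and the keyword values are placeholders chosen only for illustration:

# `x` is a matrix whose columns are individual data points; its first dimension
# must match the input dimension of the wrapped VAE.
x = randn(Float32, 28 * 28, 64)

# Reconstruction only: returns the decoder output.
x̂ = hvae(x)

# Full output: encoder outputs, decoder outputs, and the final phase-space state
# after the leapfrog and tempering steps.
result = hvae(x; K=5, ϵ=1f-4, βₒ=0.3f0, latent=true)
result.decoder      # decoder output at the refined latent sample
result.phase_space  # final latent and momentum variables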

Loss function

AutoEncoderToolkit.HVAEs.loss - Function
loss(
    hvae::HVAE,
    x::AbstractArray;
    K::Int=3,
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=reconstruction_loglikelihood,
        latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::Union{NamedTuple,Dict}=Dict(),
    reg_strength::Float32=1.0f0,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)

Compute the loss for a Hamiltonian Variational Autoencoder (HVAE).

Arguments

  • hvae::HVAE: The HVAE used to encode the input data and decode the latent space.
  • x::AbstractArray: Input data to the HVAE encoder. The last dimension is taken as having each of the samples in a batch.

Optional Keyword Arguments

  • K::Int: The number of HMC steps (default is 3).
  • ϵ::Union{<:Number,<:AbstractVector}: The step size for the leapfrog integrator (default is 0.0001).
  • βₒ::Number: The initial inverse temperature (default is 0.3).
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function.
  • tempering_schedule::Function: The tempering schedule function used in the HMC (default is quadratic_tempering).
  • reg_function::Union{Function, Nothing}=nothing: A function that computes the regularization term based on the VAE outputs. This function must take as input the VAE outputs and the keyword arguments provided in reg_kwargs.
  • reg_kwargs::Union{NamedTuple,Dict}=Dict(): Keyword arguments to pass to the regularization function.
  • reg_strength::Float32=1.0f0: The strength of the regularization term.
  • logp_prefactor::AbstractArray: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.
  • logq_prefactor::AbstractArray: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

Returns

  • The computed loss.
loss(
    hvae::HVAE,
    x_in::AbstractArray,
    x_out::AbstractArray;
    K::Int=3,
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=reconstruction_loglikelihood,
        latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::Union{NamedTuple,Dict}=Dict(),
    reg_strength::Float32=1.0f0,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)

Compute the loss for a Hamiltonian Variational Autoencoder (HVAE).

Arguments

  • hvae::HVAE: The HVAE used to encode the input data and decode the latent space.
  • x_in::AbstractArray: Input data to the HVAE encoder. The last dimension is taken as having each of the samples in a batch.
  • x_out::AbstractArray: The data against which the reconstruction is compared. If Array, the last dimension must contain each of the data points.

Optional Keyword Arguments

  • K::Int: The number of HMC steps (default is 3).
  • ϵ::Union{<:Number,<:AbstractVector}: The step size for the leapfrog integrator (default is 0.0001).
  • βₒ::Number: The initial inverse temperature (default is 0.3).
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function.
  • tempering_schedule::Function: The tempering schedule function used in the HMC (default is quadratic_tempering).
  • reg_function::Union{Function, Nothing}=nothing: A function that computes the regularization term based on the VAE outputs. This function must take as input the VAE outputs and the keyword arguments provided in reg_kwargs.
  • reg_kwargs::Union{NamedTuple,Dict}=Dict(): Keyword arguments to pass to the regularization function.
  • reg_strength::Float32=1.0f0: The strength of the regularization term.
  • logp_prefactor::AbstractArray: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.
  • logq_prefactor::AbstractArray: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

Returns

  • The computed loss.
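
The following sketch shows how the loss can be called directly, e.g., to monitor it outside of train!. The arrays x and x_noisy are placeholders; the two-argument form only illustrates the case where the reconstruction target differs from the encoder input:

# Single-input form: `x` is both the encoder input and the reconstruction target.
loss_value = AutoEncoderToolkit.HVAEs.loss(hvae, x; K=3, ϵ=1f-4, βₒ=0.3f0)

# Two-input form: the reconstruction of `x_noisy` is compared against `x`.
loss_value = AutoEncoderToolkit.HVAEs.loss(hvae, x_noisy, x)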

Training

AutoEncoderToolkit.HVAEs.train! - Function
train!(
    hvae::HVAE, 
    x::AbstractArray, 
    opt::NamedTuple; 
    loss_function::Function=loss, 
    loss_kwargs::Union{NamedTuple,Dict}=Dict(),
    verbose::Bool=false,
    loss_return::Bool=false,
)

Customized training function to update parameters of a Hamiltonian Variational Autoencoder given a specified loss function.

Arguments

  • hvae::HVAE: A struct containing the elements of a Hamiltonian Variational Autoencoder.
  • x::AbstractArray: Input data to the HVAE encoder. The last dimension is taken as having each of the samples in a batch.
  • opt::NamedTuple: State of the optimizer for updating parameters. Typically initialized using Flux.Train.setup.

Optional Keyword Arguments

  • loss_function::Function=loss: The loss function used for training. It should accept the HVAE model, data x, and keyword arguments in that order.
  • loss_kwargs::Dict=Dict(): Arguments for the loss function. These might include parameters like K, ϵ, βₒ, ∇U_kwargs, tempering_schedule, reg_function, reg_kwargs, and reg_strength, depending on the specific loss function in use.
  • verbose::Bool=false: Whether to print the loss at each iteration.
  • loss_return::Bool=false: Whether to return the loss at each iteration.

Description

Trains the HVAE by:

  1. Computing the gradient of the loss w.r.t the HVAE parameters.
  2. Updating the HVAE parameters using the optimizer.
train!(
    hvae::HVAE, 
    x_in::AbstractArray,
    x_out::AbstractArray,
    opt::NamedTuple; 
    loss_function::Function=loss, 
    loss_kwargs::Union{NamedTuple,Dict}=Dict(),
    verbose::Bool=false,
    loss_return::Bool=false,
)

Customized training function to update parameters of a Hamiltonian Variational Autoencoder given a specified loss function.

Arguments

  • hvae::HVAE: A struct containing the elements of a Hamiltonian Variational Autoencoder.
  • x_in::AbstractArray: Input data to the HVAE encoder. The last dimension is taken as having each of the samples in a batch.
  • x_out::AbstractArray: Target data to compute the reconstruction error. The last dimension is taken as having each of the samples in a batch.
  • opt::NamedTuple: State of the optimizer for updating parameters. Typically initialized using Flux.Train.setup.

Optional Keyword Arguments

  • loss_function::Function=loss: The loss function used for training. It should accept the HVAE model, data x_in and x_out, and keyword arguments in that order.
  • loss_kwargs::Dict=Dict(): Arguments for the loss function. These might include parameters like K, ϵ, βₒ, ∇U_kwargs, tempering_schedule, reg_function, reg_kwargs, and reg_strength, depending on the specific loss function in use.
  • verbose::Bool=false: Whether to print the loss at each iteration.
  • loss_return::Bool=false: Whether to return the loss at each iteration.

Description

Trains the HVAE by:

  1. Computing the gradient of the loss w.r.t the HVAE parameters.
  2. Updating the HVAE parameters using the optimizer.
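
Below is a minimal training-loop sketch under stated assumptions: the optimizer state is created with Flux.Train.setup, and n_epochs and data_loader are placeholders for the number of epochs and any iterable of mini-batches:

import Flux
using AutoEncoderToolkit

# Set up the optimizer state for all HVAE parameters.
opt_state = Flux.Train.setup(Flux.Adam(1f-3), hvae)

# Keyword arguments forwarded to the loss function at every step.
loss_kwargs = Dict(:K => 3, :ϵ => 1f-4, :βₒ => 0.3f0)

for epoch in 1:n_epochs
    for x_batch in data_loader
        # One gradient step on the current mini-batch.
        AutoEncoderToolkit.HVAEs.train!(hvae, x_batch, opt_state; loss_kwargs=loss_kwargs)
    end
end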

Computing the gradient of the potential energy

One of the crucial components in the training of the HVAE is the computation of the gradient of the potential energy $\nabla U$ with respect to the latent space representation. This gradient is used in the leapfrog steps of the Hamiltonian dynamics. When training the HVAE, we need to backpropagate through the leapfrog steps to update the parameters of the neural network. This requires computing a gradient of a function of the gradient of the potential energy, i.e., nested gradients. Zygote.jl, the main AutoDiff backend in Flux.jl, famously struggles with these types of computations. Specifically, Zygote.jl does not support Zygote-over-Zygote differentiation (differentiating with Zygote a function that was itself differentiated with Zygote) or Zygote-over-ForwardDiff differentiation (differentiating with Zygote a function that was differentiated with ForwardDiff).

With this, we are left with a couple of options to compute the gradient of the potential energy:

  • Use finite differences to approximate the gradient of the potential energy.
  • Use the relatively new TaylorDiff.jl AutoDiff backend to compute the gradient of the potential energy. This backend is composable with Zygote.jl, so we can, in principle, do Zygote over TaylorDiff differentiation.

The second option would be preferred, as the gradients computed with TaylorDiff are much more accurate than the ones computed with finite differences. However, there are two problems with this approach:

  1. The TaylorDiff nested gradient capability stopped working with Julia ≥ 1.10, as discussed in #70.
  2. Even for Julia < 1.10, we could not get TaylorDiff to work on CUDA devices. (PRs are welcome!)

With these limitations in mind, we have implemented the gradient of the potential energy using both finite differences and TaylorDiff. The user can choose which method to use by setting the adtype keyword argument in the ∇U_kwargs of the loss function to either :finite or :TaylorDiff. This means that for the train! function, the user can pass loss_kwargs that look like this:

# Define the autodiff backend to use
loss_kwargs = Dict(
    :∇U_kwargs => Dict(
        :adtype => :finite
    )
)
Note

Although verbose, the nested dictionaries help to keep everything organized. (PRs with better design ideas are welcome!)

The default for both CPU and GPU devices is :finite.
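
For completeness, here is a sketch of forwarding this choice to train!; the optimizer state opt_state is assumed to have been created as in the training example above:

# Use finite differences for ∇U inside the leapfrog steps.
loss_kwargs = Dict(
    :∇U_kwargs => Dict(
        :adtype => :finite
    )
)

AutoEncoderToolkit.HVAEs.train!(hvae, x, opt_state; loss_kwargs=loss_kwargs)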

AutoEncoderToolkit.HVAEs.∇potential_energy_finite - Function
∇potential_energy_finite(
    x::AbstractArray,
    z::AbstractVecOrMat,
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior,
    fdtype::Symbol=:central
)

Compute the gradient of the potential energy of a Hamiltonian Variational Autoencoder (HVAE) with respect to the latent variables z using a finite difference method. This function returns the gradient of the potential energy computed for given data x and latent variable z.

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • decoder::AbstractVariationalDecoder: A decoder that maps the latent variables to the data space.
  • decoder_output::NamedTuple: The output of the decoder.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input an array x representing the data, as second input a vector or matrix z representing the latent variable, and as third input a decoder. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector or matrix z representing the latent variable. Default is spherical_logprior.
  • fdtype::Symbol=:central: A symbol representing the type of finite difference method to use. Default is :central, but it can also be :forward.

Returns

  • gradient: The computed gradient of the potential energy for the given input x and latent variable z.
∇potential_energy_finite(
    x::AbstractArray,
    z::AbstractVecOrMat,
    hvae::HVAE;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior,
    fdtype::Symbol=:central
)

Compute the gradient of the potential energy of a Hamiltonian Variational Autoencoder (HVAE) with respect to the latent variables z using a finite difference method. This function returns the gradient of the potential energy computed for given data x and latent variable z.

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • hvae::HVAE: An HVAE model that contains a decoder which maps the latent variables to the data space.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input an array x representing the data, as second input a vector or matrix z representing the latent variable, and as third input a decoder. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector or matrix z representing the latent variable. Default is spherical_logprior.
  • fdtype::Symbol=:central: A symbol representing the type of finite difference method to use. Default is :central, but it can also be :forward.

Returns

  • gradient: The computed gradient of the potential energy for the given input x and latent variable z.
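
To make the finite-difference approach concrete, the following is a self-contained sketch of a central-difference gradient of a scalar potential U(z). It is not the package's internal implementation, only an illustration of the :central scheme; the function and parameter names are hypothetical:

# Central finite-difference gradient of a scalar function U at the point z.
function finite_diff_gradient(U::Function, z::AbstractVector; δ::Real=√(eps(Float32)))
    grad = similar(z)
    for i in eachindex(z)
        zp = copy(z); zp[i] += δ
        zm = copy(z); zm[i] -= δ
        # Central difference: (U(z + δeᵢ) - U(z - δeᵢ)) / 2δ
        grad[i] = (U(zp) - U(zm)) / (2δ)
    end
    return grad
end

# Example: the gradient of U(z) = ½‖z‖² is z itself.
finite_diff_gradient(z -> 0.5f0 * sum(abs2, z), Float32[1.0, -2.0])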
AutoEncoderToolkit.HVAEs.∇potential_energy_TaylorDiff - Function
∇potential_energy_TaylorDiff(
    x::AbstractArray,
    z::AbstractVecOrMat,
    hvae::HVAE;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior,
)

Compute the gradient of the potential energy of a Hamiltonian Variational Autoencoder (HVAE) with respect to the latent variables z using Taylor series differentiation. This function returns the gradient of the potential energy computed for given data x and latent variable z.

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • hvae::HVAE: An HVAE model that contains a decoder which maps the latent variables to the data space.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input an array x representing the data, as second input a vector or matrix z representing the latent variable, and as third input a decoder. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector or matrix z representing the latent variable. Default is spherical_logprior.

Returns

  • gradient: The computed gradient of the potential energy for the given input x and latent variable z.

Other Functions

AutoEncoderToolkit.HVAEs.potential_energy - Function
potential_energy(
    x::AbstractVector,
    z::AbstractVector,
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior
)

Compute the potential energy of a Hamiltonian Variational Autoencoder (HVAE). In the context of Hamiltonian Monte Carlo (HMC), the potential energy is defined as the negative log-posterior. This function computes the potential energy for given data x and latent variable z. It does this by computing the log-likelihood of x under the distribution defined by reconstruction_loglikelihood(x, z, decoder, decoder_output), and the log-prior of z under the latent_logprior distribution. The potential energy is then computed as:

    U(x, z) = -log p(x | z) - log p(z)

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • decoder::AbstractVariationalDecoder: A decoder that maps the latent variables to the data space.
  • decoder_output::NamedTuple: The output of the decoder.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input a vector x representing the data, as second input a vector z representing the latent variable, as third input a decoder, and as fourth input a NamedTuple representing the decoder output. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector z representing the latent variable. Default is spherical_logprior.

Returns

  • energy: The computed potential energy for the given input x and latent variable z.
potential_energy(
    x::AbstractArray,
    z::AbstractVecOrMat,
    hvae::HVAE;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior
)

Compute the potential energy of a Hamiltonian Variational Autoencoder (HVAE). In the context of Hamiltonian Monte Carlo (HMC), the potential energy is defined as the negative log-posterior. This function computes the potential energy for given data x and latent variable z. It does this by computing the log-likelihood of x under the distribution defined by reconstruction_loglikelihood(x, z, hvae.vae.decoder, decoder_output), and the log-prior of z under the latent_logprior distribution. The potential energy is then computed as:

    U(x, z) = -log p(x | z) - log p(z)

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • hvae::HVAE: A Hamiltonian Variational Autoencoder that contains the decoder.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input an array x representing the data, as second input a vector or matrix z representing the latent variable, as third input a decoder, and as fourth input a NamedTuple representing the decoder output. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector or matrix z representing the latent variable. Default is spherical_logprior.

Returns

  • energy: The computed potential energy for the given input x and latent variable z.
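
As a hedged numerical sketch of the formula above, independent of the package's internals, the potential energy for an isotropic Gaussian likelihood with unit variance and a standard-normal latent prior reduces (up to additive constants) to a sum of squared reconstruction errors plus the squared norm of z. The names below are hypothetical:

# U(x, z) = -log p(x | z) - log p(z) for a unit-variance Gaussian likelihood and
# a standard-normal prior, dropping additive constants.
function toy_potential_energy(x::AbstractVector, z::AbstractVector, decode::Function)
    x̂ = decode(z)                            # decoder mean μ(z)
    neg_loglike = 0.5f0 * sum(abs2, x .- x̂)  # -log p(x | z) up to a constant
    neg_logprior = 0.5f0 * sum(abs2, z)      # -log p(z) up to a constant
    return neg_loglike + neg_logprior
end

# Example with a linear "decoder" z ↦ W * z.
W = randn(Float32, 4, 2)
toy_potential_energy(randn(Float32, 4), randn(Float32, 2), z -> W * z)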
AutoEncoderToolkit.HVAEs.∇potential_energy - Function
∇potential_energy(
    x::AbstractArray,
    z::AbstractVecOrMat,
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior,
    adtype::Union{Symbol,Nothing}=nothing,
    adkwargs::Union{NamedTuple,Dict}=Dict(),
)

Compute the gradient of the potential energy of a Hamiltonian Variational Autoencoder (HVAE) with respect to the latent variables z using the specified automatic differentiation method. This function returns the gradient of the potential energy computed for given data x and latent variable z.

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • decoder::AbstractVariationalDecoder: A decoder that maps the latent variables to the data space.
  • decoder_output::NamedTuple: The output of the decoder.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input an array x representing the data, as second input a vector or matrix z representing the latent variable, and as third input a decoder. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector or matrix z representing the latent variable. Default is spherical_logprior.
  • adtype::Symbol=:finite: The type of automatic differentiation method to use. Must be :finite or :TaylorDiff. Default is :finite.
  • adkwargs::Union{NamedTuple,Dict}=Dict(): Additional keyword arguments to pass to the automatic differentiation method.

Returns

  • gradient: The computed gradient of the potential energy for the given input x and latent variable z.
∇potential_energy(
    x::AbstractArray,
    z::AbstractVecOrMat,
    hvae::HVAE;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    latent_logprior::Function=spherical_logprior,
    adtype::Union{Symbol,Nothing}=nothing,
    adkwargs::Union{NamedTuple,Dict}=Dict(),
)

Compute the gradient of the potential energy of a Hamiltonian Variational Autoencoder (HVAE) with respect to the latent variables z using the specified automatic differentiation method. This function returns the gradient of the potential energy computed for given data x and latent variable z.

Arguments

  • x::AbstractArray: An array representing the input data. The last dimension corresponds to different data points.
  • z::AbstractVecOrMat: A latent variable encoding of the input data. If a matrix, each column corresponds to a different data point.
  • hvae::HVAE: An HVAE model that contains a decoder which maps the latent variables to the data space.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function representing the log-likelihood function used by the decoder. The function must take as first input an array x representing the data, as second input a vector or matrix z representing the latent variable, and as third input a decoder. Default is decoder_loglikelihood.
  • latent_logprior::Function=spherical_logprior: A function representing the log-prior distribution used in the autoencoder. The function must take as single input a vector or matrix z representing the latent variable. Default is spherical_logprior.
  • adtype::Symbol=:finite: The type of automatic differentiation method to use. Must be :finite or :TaylorDiff. Default is :finite.
  • adkwargs::Union{NamedTuple,Dict}=Dict(): Additional keyword arguments to pass to the automatic differentiation method.

Returns

  • gradient: The computed gradient of the potential energy for the given input x and latent variable z.
AutoEncoderToolkit.HVAEs.leapfrog_step - Function
leapfrog_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=reconstruction_loglikelihood,
        latent_logprior=spherical_logprior,
    )
)

Perform a full step of the leapfrog integrator for Hamiltonian dynamics.

The leapfrog integrator is a numerical integration scheme used to simulate Hamiltonian dynamics. It consists of three steps:

  1. Half update of the momentum variable:

     ρ(t + ϵ/2) = ρ(t) - 0.5 * ϵ * ∇z_U(x, z(t)).
  2. Full update of the position variable:

     z(t + ϵ) = z(t) + ϵ * ρ(t + ϵ/2).
  3. Half update of the momentum variable:

     ρ(t + ϵ) = ρ(t + ϵ/2) - 0.5 * ϵ * ∇z_U(x, z(t + ϵ)).

This function performs these three steps in sequence.

Arguments

  • x::AbstractArray: The point in the data space. This does not necessarily need to be a vector. Array inputs are supported. The last dimension is assumed to have each of the data points.
  • z::AbstractVecOrMat: The point in the latent space. If matrix, each column represents a point in the latent space.
  • ρ::AbstractVecOrMat: The momentum. If matrix, each column represents a momentum vector.
  • decoder::AbstractVariationalDecoder: The decoder instance.
  • decoder_output::NamedTuple: The output of the decoder.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4): The step size. Default is 0.0001.
  • ∇U_kwargs::Union{Dict,NamedTuple}: The keyword arguments for ∇potential_energy. Default is a tuple with reconstruction_loglikelihood and latent_logprior.

Returns

A tuple (z̄, ρ̄, decoder_output_z̄) representing the updated position and momentum after performing the full leapfrog step as well as the decoder output of the updated position.

leapfrog_step(
    x::AbstractArray,
    z::AbstractVecOrMat,
    ρ::AbstractVecOrMat,
    hvae::HVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=reconstruction_loglikelihood,
        latent_logprior=spherical_logprior,
    )
)

Perform a full step of the leapfrog integrator for Hamiltonian dynamics.

The leapfrog integrator is a numerical integration scheme used to simulate Hamiltonian dynamics. It consists of three steps:

  1. Half update of the momentum variable:

     ρ(t + ϵ/2) = ρ(t) - 0.5 * ϵ * ∇z_U(x, z(t)).
  2. Full update of the position variable:

     z(t + ϵ) = z(t) + ϵ * ρ(t + ϵ/2).
  3. Half update of the momentum variable:

     ρ(t + ϵ) = ρ(t + ϵ/2) - 0.5 * ϵ * ∇z_U(x, z(t + ϵ)).

This function performs these three steps in sequence.

Arguments

  • x::AbstractArray: The point in the data space. This does not necessarily need to be a vector. Array inputs are supported. The last dimension is assumed to have each of the data points.
  • z::AbstractVecOrMat: The point in the latent space. If matrix, each column represents a point in the latent space.
  • ρ::AbstractVecOrMat: The momentum. If matrix, each column represents a momentum vector.
  • hvae::HVAE: An HVAE model that contains the decoder.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4): The step size. Default is 0.0001.
  • ∇U_kwargs::Union{Dict,NamedTuple}: The keyword arguments for ∇potential_energy. Default is a tuple with reconstruction_loglikelihood and latent_logprior.

Returns

A tuple (z̄, ρ̄, decoder_output_z̄) representing the updated position and momentum after performing the full leapfrog step as well as the decoder output of the updated position.
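
To illustrate the three sub-steps, here is a standalone leapfrog sketch for a generic gradient function ∇U(z); it uses a toy stand-in for ∇potential_energy and does not call the package:

# One leapfrog step for Hamiltonian dynamics with unit mass matrix.
function toy_leapfrog_step(z::AbstractVector, ρ::AbstractVector, ∇U::Function, ϵ::Real)
    ρ_half = ρ .- 0.5f0 .* ϵ .* ∇U(z)          # 1. half momentum update
    z_new = z .+ ϵ .* ρ_half                   # 2. full position update
    ρ_new = ρ_half .- 0.5f0 .* ϵ .* ∇U(z_new)  # 3. half momentum update
    return z_new, ρ_new
end

# Example: harmonic potential U(z) = ½‖z‖², whose gradient is ∇U(z) = z.
z, ρ = Float32[1.0, 0.0], Float32[0.0, 1.0]
z, ρ = toy_leapfrog_step(z, ρ, identity, 1f-1)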

AutoEncoderToolkit.HVAEs.quadratic_tempering - Function
quadratic_tempering(βₒ::AbstractFloat, k::Int, K::Int)

Compute the inverse temperature βₖ at a given stage k of a tempering schedule with K total stages, using a quadratic tempering scheme.

Tempering is a technique used in sampling algorithms to improve mixing and convergence. It involves running parallel chains of the algorithm at different "temperatures", and swapping states between the chains. The "temperature" of a chain is controlled by an inverse temperature parameter β, which is varied according to a tempering schedule.

In a quadratic tempering schedule, the inverse temperature βₖ at stage k is computed as βₖ = ((1 - 1 / √(βₒ)) * (k / K)^2 + 1 / √(βₒ))^(-2), where βₒ is the initial inverse temperature. This schedule starts at βₒ when k = 0 and increases monotonically with k, reaching 1 when k = K.

Arguments

  • βₒ::AbstractFloat: The initial inverse temperature.
  • k::Int: The current stage of the tempering schedule.
  • K::Int: The total number of stages in the tempering schedule.

Returns

  • βₖ::AbstractFloat: The inverse temperature at stage k.
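
Below is a minimal sketch consistent with the schedule described above, shown only for illustration (the package provides its own quadratic_tempering; the toy_ name is hypothetical):

# Quadratic tempering: βₖ interpolates from βₒ at k = 0 to 1 at k = K.
toy_quadratic_tempering(βₒ, k, K) = ((1 - 1 / √βₒ) * (k / K)^2 + 1 / √βₒ)^(-2)

toy_quadratic_tempering(0.3, 0, 10)   # ≈ 0.3 (initial inverse temperature)
toy_quadratic_tempering(0.3, 10, 10)  # ≈ 1.0 (final inverse temperature)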
AutoEncoderToolkit.HVAEs.null_tempering - Function
null_tempering(βₒ::T, k::Int, K::Int) where {T<:AbstractFloat}

Return the initial inverse temperature βₒ. This function is used in the context of tempered Hamiltonian Monte Carlo (HMC) methods, where tempering involves running HMC at different "temperatures" to improve mixing and convergence.

In this case, null_tempering is a simple tempering schedule that does not actually change the temperature—it always returns the initial inverse temperature βₒ. This can be useful as a default or placeholder tempering schedule.

Arguments

  • βₒ::AbstractFloat: The initial inverse temperature.
  • k::Int: The current step in the tempering schedule. Not used in this function, but included for compatibility with other tempering schedules.
  • K::Int: The total number of steps in the tempering schedule. Not used in this function, but included for compatibility with other tempering schedules.

Returns

  • β::T: The inverse temperature for the current step, which is always βₒ in this case.

Example

βₒ = 0.5
k = 1
K = 10
β = null_tempering(βₒ, k, K)  # β will be 0.5
AutoEncoderToolkit.HVAEs.leapfrog_tempering_step - Function
leapfrog_tempering_step(
    x::AbstractArray,
    zₒ::AbstractVecOrMat,
    decoder::AbstractVariationalDecoder,
    decoder_output::NamedTuple;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=reconstruction_loglikelihood,
        latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
)

Combines the leapfrog and tempering steps into a single function for the Hamiltonian Variational Autoencoder (HVAE).

Arguments

  • x::AbstractArray: The data to be processed. If Array, the last dimension must be of size 1.
  • zₒ::AbstractVecOrMat: The initial latent variable.
  • decoder::AbstractVariationalDecoder: The decoder of the HVAE model.
  • decoder_output::NamedTuple: The output of the decoder.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}: The step size for the leapfrog steps in the HMC algorithm. This can be a scalar or an array. Default is 0.0001.
  • K::Int: The number of leapfrog steps to perform in the Hamiltonian Monte Carlo (HMC) algorithm. Default is 3.
  • βₒ::Number: The initial inverse temperature for the tempering schedule. Default is 0.3f0.
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function. Default is a NamedTuple with reconstruction_loglikelihood and latent_logprior.
  • tempering_schedule::Function: The function to compute the inverse temperature at each step in the HMC algorithm. Defaults to quadratic_tempering. This function must take three arguments: First, βₒ, an initial inverse temperature, second, k, the current step in the tempering schedule, and third, K, the total number of steps in the tempering schedule.

Returns

  • A NamedTuple with the following keys:
    • z_init: The initial latent variable.
    • ρ_init: The initial momentum variable.
    • z_final: The final latent variable after K leapfrog steps.
    • ρ_final: The final momentum variable after K leapfrog steps.
  • The decoder output at the final latent variable is also returned. Note: This is not in the same named tuple as the other outputs, but as a separate output.

Description

The function first samples a random momentum variable γₒ from a standard normal distribution and scales it by the inverse square root of the initial inverse temperature βₒ to obtain the initial momentum variable ρₒ. Then, it performs K leapfrog steps, each followed by a tempering step, to generate a new sample from the latent space.

Note

Ensure the input data x and the initial latent variable zₒ match the expected input dimensionality for the HVAE model.

leapfrog_tempering_step(
    x::AbstractArray,
    zₒ::AbstractVecOrMat,
    hvae::HVAE;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=reconstruction_loglikelihood,
        latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
)

Combines the leapfrog and tempering steps into a single function for the Hamiltonian Variational Autoencoder (HVAE).

Arguments

  • x::AbstractArray: The data to be processed. If Array, the last dimension must be of size 1.
  • zₒ::AbstractVecOrMat: The initial latent variable.
  • hvae::HVAE: An HVAE model that contains the decoder.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}: The step size for the leapfrog steps in the HMC algorithm. This can be a scalar or an array. Default is 0.0001.
  • K::Int: The number of leapfrog steps to perform in the Hamiltonian Monte Carlo (HMC) algorithm. Default is 3.
  • βₒ::Number: The initial inverse temperature for the tempering schedule. Default is 0.3f0.
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function. Default is a NamedTuple with reconstruction_loglikelihood and latent_logprior.
  • tempering_schedule::Function: The function to compute the inverse temperature at each step in the HMC algorithm. Defaults to quadratic_tempering. This function must take three arguments: First, βₒ, an initial inverse temperature, second, k, the current step in the tempering schedule, and third, K, the total number of steps in the tempering schedule.

Returns

  • A NamedTuple with the following keys:
    • z_init: The initial latent variable.
    • ρ_init: The initial momentum variable.
    • z_final: The final latent variable after K leapfrog steps.
    • ρ_final: The final momentum variable after K leapfrog steps.
  • The decoder output at the final latent variable is also returned. Note: This is not in the same named tuple as the other outputs, but as a separate output.

Description

The function first samples a random momentum variable γₒ from a standard normal distribution and scales it by the inverse square root of the initial inverse temperature βₒ to obtain the initial momentum variable ρₒ. Then, it performs K leapfrog steps, each followed by a tempering step, to generate a new sample from the latent space.

Note

Ensure the input data x and the initial latent variable zₒ match the expected input dimensionality for the HVAE model.

AutoEncoderToolkit.HVAEs._log_p̄ - Function
_log_p̄(
    x::AbstractArray,
    hvae::HVAE{VAE{E,D}},
    hvae_outputs::NamedTuple;
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    logprior::Function=spherical_logprior,
    prefactor::AbstractArray=ones(Float32, 3),
)

This is an internal function used in hamiltonian_elbo to compute the numerator of the unbiased estimator of the marginal likelihood. The function computes the sum of the log likelihood of the data given the latent variables, the log prior of the latent variables, and the log prior of the momentum variables.

    log p̄ = log p(x | zₖ) + log p(zₖ) + log p(ρₖ)

Arguments

  • x::AbstractArray: The input data. If Array, the last dimension must contain each of the data points.
  • hvae::HVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractGaussianLogDecoder}}: The Hamiltonian Variational Autoencoder (HVAE) model.
  • hvae_outputs::NamedTuple: The outputs of the HVAE, including the final latent variables zₖ and the final momentum variables ρₖ.

Optional Keyword Arguments

  • reconstruction_loglikelihood::Function: The function to compute the log likelihood of the data given the latent variables. Default is decoder_loglikelihood.
  • logprior::Function: The function to compute the log prior of the latent variables. Default is spherical_logprior.
  • prefactor::AbstractArray: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.

Returns

  • log_p̄::AbstractVector: The first term of the log of the unbiased estimator of the marginal likelihood for each data point.

Note

This is an internal function and should not be called directly. It is used as part of the hamiltonian_elbo function.

AutoEncoderToolkit.HVAEs._log_q̄ - Function
_log_q̄(
    hvae::HVAE,
    hvae_outputs::NamedTuple,
    βₒ::Number;
    logprior::Function=spherical_logprior,
    prefactor::AbstractArray=ones(Float32, 3),
)

This is an internal function used in hamiltonian_elbo to compute the second term of the unbiased estimator of the marginal likelihood. The function computes the sum of the log posterior of the initial latent variables and the log prior of the initial momentum variables, minus a term that depends on the dimensionality of the latent space and the initial temperature.

log q̄ = log q(zₒ | x) + log p(ρₒ | zₒ) - d/2 log(βₒ)

Arguments

  • hvae::HVAE: The Hamiltonian Variational Autoencoder (HVAE) model.
  • hvae_outputs::NamedTuple: The outputs of the HVAE, including the initial latent variables zₒ and the initial momentum variables ρₒ.
  • βₒ::Number: The initial inverse temperature for the tempering steps.

Optional Keyword Arguments

  • logprior::Function: The function to compute the log prior of the momentum variables. Default is spherical_logprior.
  • prefactor::AbstractArray: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

Returns

  • log_q̄::Vector: The second term of the log of the unbiased estimator of the marginal likelihood for each data point.

Note

This is an internal function and should not be called directly. It is used as part of the hamiltonian_elbo function.

AutoEncoderToolkit.HVAEs.hamiltonian_elbo - Function
hamiltonian_elbo(
    hvae::HVAE,
    x::AbstractArray;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
    return_outputs::Bool=false,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)

Compute the Hamiltonian Monte Carlo (HMC) estimate of the evidence lower bound (ELBO) for a Hamiltonian Variational Autoencoder (HVAE).

This function takes as input an HVAE and a vector of input data x. It performs K HMC steps with a leapfrog integrator and a tempering schedule to estimate the ELBO. The ELBO is computed as the difference between the log p̄ and log q̄ as

elbo = mean(log p̄ - log q̄),

Arguments

  • hvae::HVAE: The HVAE used to encode the input data and decode the latent space.
  • x::AbstractArray: The input data. If Array, the last dimension must contain each of the data points.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}: The step size for the leapfrog integrator (default is 0.0001).
  • K::Int: The number of HMC steps (default is 3).
  • βₒ::Number: The initial inverse temperature (default is 0.3).
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function. Defaults to a NamedTuple with :reconstruction_loglikelihood set to decoder_loglikelihood and :latent_logprior set to spherical_logprior.
  • tempering_schedule::Function: The tempering schedule function used in the HMC (default is quadratic_tempering).
  • return_outputs::Bool: Whether to return the outputs of the HVAE. Defaults to false. NOTE: This is necessary to avoid computing the forward pass twice when computing the loss function with regularization.
  • logp_prefactor::AbstractArray: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.
  • logq_prefactor::AbstractArray: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

Returns

  • elbo::Number: The HMC estimate of the ELBO. If return_outputs is true, also returns the outputs of the HVAE.
hamiltonian_elbo(
    hvae::HVAE,
    x_in::AbstractArray,
    x_out::AbstractArray;
    ϵ::Union{<:Number,<:AbstractVector}=Float32(1E-4),
    K::Int=3,
    βₒ::Number=0.3f0,
    ∇U_kwargs::Union{Dict,NamedTuple}=(
        reconstruction_loglikelihood=decoder_loglikelihood,
        latent_logprior=spherical_logprior,
    ),
    tempering_schedule::Function=quadratic_tempering,
    return_outputs::Bool=false,
    logp_prefactor::AbstractArray=ones(Float32, 3),
    logq_prefactor::AbstractArray=ones(Float32, 3),
)

Compute the Hamiltonian Monte Carlo (HMC) estimate of the evidence lower bound (ELBO) for a Hamiltonian Variational Autoencoder (HVAE).

This function takes as input an HVAE, input data x_in, and target data x_out. It performs K HMC steps with a leapfrog integrator and a tempering schedule to estimate the ELBO. The ELBO is computed as the difference between log p̄ and log q̄ as

elbo = mean(log p̄ - log q̄),

Arguments

  • hvae::HVAE: The HVAE used to encode the input data and decode the latent space.
  • x_in::AbstractArray: The input data. If Array, the last dimension must contain each of the data points.
  • x_out::AbstractArray: The data against which the reconstruction is compared. If Array, the last dimension must contain each of the data points.

Optional Keyword Arguments

  • ϵ::Union{<:Number,<:AbstractVector}: The step size for the leapfrog integrator (default is 0.0001).
  • K::Int: The number of HMC steps (default is 3).
  • βₒ::Number: The initial inverse temperature (default is 0.3).
  • ∇U_kwargs::Union{Dict,NamedTuple}: Additional keyword arguments to be passed to the ∇potential_energy function. Defaults to a NamedTuple with :reconstruction_loglikelihood set to decoder_loglikelihood and :latent_logprior set to spherical_logprior.
  • tempering_schedule::Function: The tempering schedule function used in the HMC (default is quadratic_tempering).
  • return_outputs::Bool: Whether to return the outputs of the HVAE. Defaults to false. NOTE: This is necessary to avoid computing the forward pass twice when computing the loss function with regularization.
  • logp_prefactor::AbstractArray: A 3-element array to scale the log likelihood, log prior of the latent variables, and log prior of the momentum variables. Default is an array of ones.
  • logq_prefactor::AbstractArray: A 3-element array to scale the log posterior of the initial latent variables, log prior of the initial momentum variables, and the tempering Jacobian term. Default is an array of ones.

Returns

  • elbo::Number: The HMC estimate of the ELBO. If return_outputs is true, also returns the outputs of the HVAE.
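
Finally, a brief usage sketch; the data array x is a placeholder:

# HMC estimate of the ELBO for a batch of data (columns are data points).
elbo = AutoEncoderToolkit.HVAEs.hamiltonian_elbo(hvae, x; K=3, ϵ=1f-4, βₒ=0.3f0)

# With return_outputs=true, the HVAE forward-pass outputs are returned alongside
# the ELBO, avoiding a second forward pass when adding a regularization term.
result = AutoEncoderToolkit.HVAEs.hamiltonian_elbo(hvae, x; return_outputs=true)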