β-Variational Autoencoder

Variational autoencoders (VAEs), first introduced by Kingma and Welling in 2014, are generative models that learn to encode high-dimensional data into a low-dimensional latent space. The main idea behind VAEs is to learn, via variational inference, a probabilistic mapping from the input data to the latent space, which allows new data points to be generated by sampling from the latent space.

Their β-VAE counterpart, introduced by Higgins et al. in 2017, is a variant of the original VAE that adds a hyperparameter β controlling the relative weight of the KL divergence term against the reconstruction loss. By adjusting β, the user can trade off reconstruction quality against the disentanglement of the latent space.

In terms of implementation, the VAE struct in AutoEncoderToolkit.jl is a simple feedforward network composed of a variational encoder and a variational decoder. This means that the encoder has a log-posterior function and a KL divergence function associated with it, while the decoder has a log-likelihood function associated with it.

References

VAE

Kingma, D. P. & Welling, M. Auto-Encoding Variational Bayes. Preprint at http://arxiv.org/abs/1312.6114 (2014).

β-VAE

Higgins, I. et al. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. ICLR (2017).

VAE struct

AutoEncoderToolkit.VAEs.VAE (Type)

struct VAE{E<:AbstractVariationalEncoder, D<:AbstractVariationalDecoder}

Variational autoencoder (VAE) model defined for Flux.jl

Fields

  • encoder::E: Neural network that encodes the input into the latent space. E is a subtype of AbstractVariationalEncoder.
  • decoder::D: Neural network that decodes the latent representation back to the original input space. D is a subtype of AbstractVariationalDecoder.

A VAE consists of an encoder and decoder network with a bottleneck latent space in between. The encoder compresses the input into a low-dimensional probabilistic representation q(z|x). The decoder tries to reconstruct the original input from a sampled point in the latent space p(x|z).

Forward pass

AutoEncoderToolkit.VAEs.VAE (Method)
    (vae::VAE)(x::AbstractArray; latent::Bool=false)

Perform the forward pass of a Variational Autoencoder (VAE).

This function takes as input a VAE and an array of input data x. It first runs the input through the encoder to obtain the mean and log standard deviation of the latent variables. It then uses the reparameterization trick to sample from the latent distribution. Finally, it runs the latent sample through the decoder to obtain the output.
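
As a rough sketch of this reparameterization step, assuming a Gaussian encoder that returns a mean µ and a log standard deviation logσ (the helper below is illustrative, not part of the package):

# Sample z ~ q(z|x) = N(µ, exp(logσ)²) while keeping the sample
# differentiable with respect to µ and logσ
function reparameterize(µ::AbstractArray, logσ::AbstractArray)
    ε = randn(Float32, size(µ))   # noise drawn from N(0, I)
    return µ .+ exp.(logσ) .* ε   # z = µ + σ ⊙ ε
end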

Arguments

  • vae::VAE: The VAE used to encode the input data and decode the latent space.
  • x::AbstractArray: The input data. The last dimension is taken as having each of the samples in a batch.

Optional Keyword Arguments

  • latent::Bool: Whether to return the latent variables along with the decoder output. If true, the function returns a tuple containing the encoder outputs, the latent sample, and the decoder outputs. If false, the function only returns the decoder outputs. Defaults to false.

Returns

  • If latent is true, returns a tuple containing:
    • encoder: The outputs of the encoder.
    • z: The latent sample.
    • decoder: The outputs of the decoder.
  • If latent is false, returns the outputs of the decoder.

Example

# Define a VAE. The encoder and decoder must be subtypes of
# AbstractVariationalEncoder and AbstractVariationalDecoder,
# e.g., a JointGaussianLogEncoder and a SimpleGaussianDecoder
encoder = JointGaussianLogEncoder(
    Flux.Chain(Flux.Dense(784, 400, relu)),   # shared layers
    Flux.Dense(400, 20), Flux.Dense(400, 20)  # µ and logσ layers
)
decoder = SimpleGaussianDecoder(Flux.Chain(Flux.Dense(20, 400, relu), Flux.Dense(400, 784)))
vae = VAE(encoder, decoder)

# Define input data
x = rand(Float32, 784)

# Perform the forward pass
outputs = vae(x, latent=true)

Loss function

AutoEncoderToolkit.VAEs.loss (Function)
loss(
    vae::VAE,
    x::AbstractArray;
    β::Number=1.0f0,
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    kl_divergence::Function=encoder_kl,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::Union{NamedTuple,Dict}=Dict(),
    reg_strength::Number=1.0f0
)

Computes the loss for the variational autoencoder (VAE).

The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence, and possibly a regularization term, defined as:

loss = -⟨log π(x|z)⟩ + β × Dₖₗ[qᵩ(z|x) || π(z)] + reg_strength × reg_term

Where:

  • π(x|z) is a probabilistic decoder: π(x|z) = N(f(z), σ²I)
  • f(z) is the function defining the mean of the decoder π(x|z)
  • qᵩ(z|x) is the approximate encoder: qᵩ(z|x) = N(g(x), h(x))
  • g(x) and h(x) define the mean and covariance of the encoder, respectively.
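
For a diagonal-Gaussian encoder and a standard normal prior π(z), the KL term above has a well-known closed form. A minimal sketch of the two terms (illustrative helper names; the package computes them through the functions listed in the keyword arguments below):

# Closed-form Dₖₗ[N(µ, diag(exp(logσ)²)) || N(0, I)]
gaussian_kl(µ, logσ) = 0.5f0 * sum(@. exp(2logσ) + µ^2 - 1 - 2logσ)

# β-weighted negative ELBO, given the reconstruction log-likelihood
# logp = ⟨log π(x|z)⟩ and the encoder outputs
neg_elbo(logp, µ, logσ; β=1.0f0) = -logp + β * gaussian_kl(µ, logσ)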

Arguments

  • vae::VAE: A VAE model with encoder and decoder networks.
  • x::AbstractArray: Input data. The last dimension is taken as having each of the samples in a batch.

Optional Keyword Arguments

  • β::Number=1.0f0: Weighting factor for the KL-divergence term, used for annealing.
  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function that computes the reconstruction log likelihood.
  • kl_divergence::Function=encoder_kl: A function that computes the Kullback-Leibler divergence between the encoder output and a standard normal.
  • reg_function::Union{Function, Nothing}=nothing: A function that computes the regularization term based on the VAE outputs. Should return a Float32. This function must take as input the VAE outputs and the keyword arguments provided in reg_kwargs.
  • reg_kwargs::Union{NamedTuple,Dict}=Dict(): Keyword arguments to pass to the regularization function.
  • reg_strength::Number=1.0f0: The strength of the regularization term.

Returns

  • T: The computed average loss value for the input x and its reconstruction, including any regularization terms.

Note

  • Ensure that the input data x matches the expected input dimensionality for the encoder in the VAE.
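
A usage sketch, reusing the vae and x defined in the forward-pass example above:

# Average loss over the input with the default settings
batch_loss = loss(vae, x)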
loss(
    vae::VAE,
    x_in::AbstractArray,
    x_out::AbstractArray;
    β::Number=1.0f0,
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    kl_divergence::Function=encoder_kl,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::Union{NamedTuple,Dict}=Dict(),
    reg_strength::Number=1.0f0
)

Computes the loss for the variational autoencoder (VAE).

The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence and possibly a regularization term, defined as:

loss = -⟨log π(x_out|z)⟩ + β × Dₖₗ[qᵩ(z|x_in) || π(z)] + reg_strength × reg_term

Where:

  • π(x_out|z) is a probabilistic decoder: π(x_out|z) = N(f(z), σ²I)
  • f(z) is the function defining the mean of the decoder π(x_out|z)
  • qᵩ(z|x_in) is the approximate encoder: qᵩ(z|x_in) = N(g(x_in), h(x_in))
  • g(x_in) and h(x_in) define the mean and covariance of the encoder, respectively.

Arguments

  • vae::VAE: A VAE model with encoder and decoder networks.
  • x_in::AbstractArray: Input data to the VAE encoder. The last dimension is taken as having each of the samples in a batch.
  • x_out::AbstractArray: Target data to compute the reconstruction error. The last dimension is taken as having each of the samples in a batch.

Optional Keyword Arguments

  • β::Number=1.0f0: Weighting factor for the KL-divergence term, used for annealing.
  • reconstruction_loglikelihood::Function=decoder_loglikelihood: A function that computes the reconstruction log likelihood.
  • kl_divergence::Function=encoder_kl: A function that computes the Kullback-Leibler divergence.
  • reg_function::Union{Function, Nothing}=nothing: A function that computes the regularization term based on the VAE outputs. Should return a Float32. This function must take as input the VAE outputs and the keyword arguments provided in reg_kwargs.
  • reg_kwargs::Union{NamedTuple,Dict}=Dict(): Keyword arguments to pass to the regularization function.
  • reg_strength::Number=1.0f0: The strength of the regularization term.

Returns

  • T: The computed average loss value between the reconstruction of the input x_in and the target x_out, including any regularization terms.

Note

  • Ensure that the input data x_in and x_out match the expected input dimensionality for the encoder in the VAE.
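
This two-argument method is useful whenever the reconstruction target differs from the encoder input, as in a denoising setup. A sketch (x_clean is assumed to be a pre-built data batch):

# Encode a corrupted input but score the reconstruction against the
# clean target
x_noisy = x_clean .+ 0.1f0 .* randn(Float32, size(x_clean))
denoising_loss = loss(vae, x_noisy, x_clean)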

Note

Both loss methods accept the optional β argument, which turns a vanilla VAE into a β-VAE: setting β to any value other than the default 1.0 re-weights the KL divergence term.
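
For example (the value 4.0f0 is an arbitrary illustration):

# β > 1 up-weights the KL term, encouraging a more disentangled
# latent space at the cost of reconstruction quality
beta_vae_loss = loss(vae, x; β=4.0f0)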

Training

AutoEncoderToolkit.VAEs.train! (Function)
train!(vae, x, opt; loss_function, loss_kwargs, verbose, loss_return)

Customized training function to update parameters of a variational autoencoder given a specified loss function.

Arguments

  • vae::VAE: A struct containing the elements of a variational autoencoder.
  • x::AbstractArray: Data on which to evaluate the loss function. The last dimension is taken as having each of the samples in a batch.
  • opt::NamedTuple: State of the optimizer for updating parameters. Typically initialized using Flux.Train.setup.

Optional Keyword Arguments

  • loss_function::Function=loss: The loss function used for training. It should accept the VAE model, data x, and keyword arguments in that order.
  • loss_kwargs::Union{NamedTuple,Dict} = Dict(): Arguments for the loss function. These might include parameters like σ or β, depending on the specific loss function in use.
  • verbose::Bool=false: If true, the loss value will be printed during training.
  • loss_return::Bool=false: If true, the loss value will be returned after training.

Description

Trains the VAE by:

  1. Computing the gradient of the loss w.r.t the VAE parameters.
  2. Updating the VAE parameters using the optimizer.
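
A minimal sketch of these two steps using standard Flux.jl machinery (illustrative; train! wraps this logic together with the keyword handling described above):

# 1. Compute the loss and its gradient with respect to the VAE parameters
loss_val, grads = Flux.withgradient(vae) do model
    loss(model, x)
end
# 2. Update the parameters in place using the optimizer state
Flux.update!(opt, vae, grads[1])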

Examples

opt = Flux.setup(Flux.Adam(1e-3), vae)
for x in dataloader
    train!(vae, x, opt; loss_kwargs=Dict(:β => 1.0f0), verbose=true)
end

train!(vae, x_in, x_out, opt; loss_function, loss_kwargs, verbose, loss_return)

Customized training function to update parameters of a variational autoencoder given a loss function.

Arguments

  • vae::VAE: A struct containing the elements of a variational autoencoder.
  • x_in::AbstractArray: Input data for the loss function. The last dimension is taken as having each of the samples in a batch.
  • x_out::AbstractArray: Target output data for the loss function, i.e., the reconstruction target corresponding to x_in. The last dimension is taken as having each of the samples in a batch.
  • opt::NamedTuple: State of the optimizer for updating parameters. Typically initialized using Flux.Train.setup.

Optional Keyword Arguments

  • loss_function::Function=loss: The loss function used for training. It should accept the VAE model, data x_in, x_out, and keyword arguments in that order.
  • loss_kwargs::Union{NamedTuple,Dict} = Dict(): Arguments for the loss function. These might include parameters like σ or β, depending on the specific loss function in use.
  • verbose::Bool=false: Whether to print the loss value after each training step.
  • loss_return::Bool=false: Whether to return the loss value after each training step.

Description

Trains the VAE by:

  1. Computing the gradient of the loss w.r.t the VAE parameters.
  2. Updating the VAE parameters using the optimizer.

Examples

opt = Flux.setup(Flux.Adam(1e-3), vae)
for (x_in, x_out) in dataloader
    train!(vae, x_in, x_out, opt)
end