β-Variational Autoencoder
Variational Autoencoders, first introduced by Kingma and Welling in 2014, are a type of generative model that learns to encode high-dimensional data into a low-dimensional latent space. The main idea behind VAEs is to learn a probabilistic mapping (via variational inference) from the input data to the latent space, which allows for the generation of new data points by sampling from the latent space.
A popular variant, the β-VAE, introduced by Higgins et al. in 2017, extends the original VAE with a hyperparameter β that controls the relative importance of the reconstruction loss and the KL divergence term in the loss function. By adjusting β, the user can control the trade-off between reconstruction quality and the disentanglement of the latent space.
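In the notation of Higgins et al., the training objective makes this trade-off explicit, with β = 1 recovering the standard VAE:

```math
\mathcal{L}(\theta, \phi; x, \beta) =
\mathbb{E}_{q_\phi(z|x)}\left[ \log p_\theta(x|z) \right]
- \beta \, D_{KL}\left( q_\phi(z|x) \,\|\, p(z) \right)
```

The `loss` function documented below is, up to the optional regularization term, the negative of this objective.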
In terms of implementation, the `VAE` struct in `AutoEncoderToolkit.jl` is a simple feedforward network composed of variational encoder and decoder parts. This means that the encoder has a log-posterior function and a KL divergence function associated with it, while the decoder has a log-likelihood function associated with it (these correspond to the `encoder_kl` and `decoder_loglikelihood` defaults used by the loss function below).
References
VAE
Kingma, D. P. & Welling, M. Auto-Encoding Variational Bayes. Preprint at http://arxiv.org/abs/1312.6114 (2014).
β-VAE
Higgins, I. et al. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. (2017).
`VAE` struct
`AutoEncoderToolkit.VAEs.VAE` — Type

```julia
struct VAE{E<:AbstractVariationalEncoder, D<:AbstractVariationalDecoder}
```
Variational autoencoder (VAE) model defined for Flux.jl
Fields
- `encoder::E`: Neural network that encodes the input into the latent space. `E` is a subtype of `AbstractVariationalEncoder`.
- `decoder::D`: Neural network that decodes the latent representation back to the original input space. `D` is a subtype of `AbstractVariationalDecoder`.
A VAE consists of an encoder and decoder network with a bottleneck latent space in between. The encoder compresses the input into a low-dimensional probabilistic representation q(z|x). The decoder tries to reconstruct the original input from a sampled point in the latent space p(x|z).
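To make the encode-sample-decode pipeline concrete, here is a minimal sketch using plain Flux layers (not the package's encoder/decoder types) of how q(z|x) is sampled via the reparameterization trick:

```julia
using Flux

# Encoder maps x to the parameters of q(z|x) = N(μ, diag(σ²))
enc = Flux.Chain(Flux.Dense(784, 400, relu), Flux.Dense(400, 2 * 20))
# Decoder maps a latent sample z to the mean of p(x|z)
dec = Flux.Chain(Flux.Dense(20, 400, relu), Flux.Dense(400, 784))

x = rand(Float32, 784)
h = enc(x)
μ, logσ = h[1:20], h[21:40]   # split encoder output into mean and log-std

# Reparameterization trick: z = μ + σ ⊙ ε with ε ~ N(0, I), so sampling
# stays differentiable with respect to μ and logσ
ε = randn(Float32, 20)
z = μ .+ exp.(logσ) .* ε

x̂ = dec(z)                    # reconstruction: the mean of p(x|z)
```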
Forward pass
`AutoEncoderToolkit.VAEs.VAE` — Method

```julia
(vae::VAE)(x::AbstractArray; latent::Bool=false)
```
Perform the forward pass of a Variational Autoencoder (VAE).
This function takes as input a VAE and a vector or matrix of input data `x`. It first runs the input through the encoder to obtain the mean and log standard deviation of the latent variables. It then uses the reparameterization trick to sample from the latent distribution. Finally, it runs the latent sample through the decoder to obtain the output.
Arguments
- `vae::VAE`: The VAE used to encode the input data and decode the latent space.
- `x::AbstractArray`: The input data. If an array, the last dimension contains each of the samples in a batch.
Optional Keyword Arguments
- `latent::Bool`: Whether to return the latent variables along with the decoder output. If `true`, the function returns a tuple containing the encoder outputs, the latent sample, and the decoder outputs. If `false`, the function only returns the decoder outputs. Defaults to `false`.
Returns
- If `latent` is `true`, returns a tuple containing:
  - `encoder`: The outputs of the encoder.
  - `z`: The latent sample.
  - `decoder`: The outputs of the decoder.
- If `latent` is `false`, returns the outputs of the decoder.
Example
```julia
# Define a VAE (schematic: per the Fields section above, `encoder` and
# `decoder` should in practice be subtypes of AbstractVariationalEncoder
# and AbstractVariationalDecoder)
vae = VAE(
    encoder=Flux.Chain(Flux.Dense(784, 400, relu), Flux.Dense(400, 20)),
    decoder=Flux.Chain(Flux.Dense(20, 400, relu), Flux.Dense(400, 784))
)

# Define input data
x = rand(Float32, 784)

# Perform the forward pass, returning encoder outputs, latent sample,
# and decoder outputs
outputs = vae(x, latent=true)
```
Loss function
`AutoEncoderToolkit.VAEs.loss` — Function

```julia
loss(
    vae::VAE,
    x::AbstractArray;
    β::Number=1.0f0,
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    kl_divergence::Function=encoder_kl,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::NamedTuple=NamedTuple(),
    reg_strength::Number=1.0f0
)
```
Computes the loss for the variational autoencoder (VAE).
The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence, and possibly a regularization term, defined as:

loss = -⟨log π(x|z)⟩ + β × Dₖₗ[qᵩ(z|x) ‖ π(z)] + reg_strength × reg_term

where:
- π(x|z) is a probabilistic decoder: π(x|z) = N(f(z), σ²I)
- f(z) is the function defining the mean of the decoder π(x|z)
- qᵩ(z|x) is the approximate encoder posterior: qᵩ(z|x) = N(g(x), h(x))
- g(x) and h(x) define the mean and covariance of the encoder, respectively.
Arguments
- `vae::VAE`: A VAE model with encoder and decoder networks.
- `x::AbstractArray`: Input data. The last dimension is taken as having each of the samples in a batch.
Optional Keyword Arguments
- `β::Number=1.0f0`: Weighting factor for the KL-divergence term, used for annealing.
- `reconstruction_loglikelihood::Function=decoder_loglikelihood`: A function that computes the reconstruction log likelihood.
- `kl_divergence::Function=encoder_kl`: A function that computes the Kullback-Leibler divergence between the encoder output and a standard normal.
- `reg_function::Union{Function, Nothing}=nothing`: A function that computes the regularization term based on the VAE outputs. Should return a `Float32`. This function must take as input the VAE outputs and the keyword arguments provided in `reg_kwargs`.
- `reg_kwargs::NamedTuple=NamedTuple()`: Keyword arguments to pass to the regularization function.
- `reg_strength::Number=1.0f0`: The strength of the regularization term.
Returns
- `T`: The computed average loss value for the input `x` and its reconstructed counterparts, including possible regularization terms.
Note
- Ensure that the input data `x` matches the expected input dimensionality for the encoder in the VAE.
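Since `β` is exposed as a keyword, down-weighting or annealing the KL term is a one-liner. A minimal sketch, assuming `vae` is constructed as in the forward-pass example above:

```julia
x_batch = rand(Float32, 784, 64)  # 64 samples along the last dimension

L_vanilla = AutoEncoderToolkit.VAEs.loss(vae, x_batch)            # β = 1
L_annealed = AutoEncoderToolkit.VAEs.loss(vae, x_batch; β=0.5f0)  # down-weighted KL
```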
```julia
loss(
    vae::VAE,
    x_in::AbstractArray,
    x_out::AbstractArray;
    β::Number=1.0f0,
    reconstruction_loglikelihood::Function=decoder_loglikelihood,
    kl_divergence::Function=encoder_kl,
    reg_function::Union{Function,Nothing}=nothing,
    reg_kwargs::NamedTuple=NamedTuple(),
    reg_strength::Number=1.0f0
)
```
Computes the loss for the variational autoencoder (VAE).
The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence and possibly a regularization term, defined as:

loss = -⟨log π(x_out|z)⟩ + β × Dₖₗ[qᵩ(z|x_in) ‖ π(z)] + reg_strength × reg_term

where:
- π(x_out|z) is a probabilistic decoder: π(x_out|z) = N(f(z), σ²I)
- f(z) is the function defining the mean of the decoder π(x_out|z)
- qᵩ(z|x_in) is the approximate encoder posterior: qᵩ(z|x_in) = N(g(x_in), h(x_in))
- g(x_in) and h(x_in) define the mean and covariance of the encoder, respectively.
Arguments
- `vae::VAE`: A VAE model with encoder and decoder networks.
- `x_in::AbstractArray`: Input data to the VAE encoder. The last dimension is taken as having each of the samples in a batch.
- `x_out::AbstractArray`: Target data to compute the reconstruction error. The last dimension is taken as having each of the samples in a batch.
Optional Keyword Arguments
- `β::Number=1.0f0`: Weighting factor for the KL-divergence term, used for annealing.
- `reconstruction_loglikelihood::Function=decoder_loglikelihood`: A function that computes the reconstruction log likelihood.
- `kl_divergence::Function=encoder_kl`: A function that computes the Kullback-Leibler divergence.
- `reg_function::Union{Function, Nothing}=nothing`: A function that computes the regularization term based on the VAE outputs. Should return a `Float32`. This function must take as input the VAE outputs and the keyword arguments provided in `reg_kwargs`.
- `reg_kwargs::NamedTuple=NamedTuple()`: Keyword arguments to pass to the regularization function.
- `reg_strength::Number=1.0f0`: The strength of the regularization term.
Returns
- `T`: The computed average loss value for the input `x_in` and its reconstructed counterpart `x_out`, including possible regularization terms.
Note
- Ensure that the input data `x_in` and `x_out` match the expected dimensionality of the encoder input and decoder output, respectively.
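The two-argument form supports, for instance, denoising-style training, where the encoder sees a corrupted input while the reconstruction is scored against the clean target. A sketch, with `vae` as above:

```julia
x_clean = rand(Float32, 784, 64)
x_noisy = x_clean .+ 0.1f0 .* randn(Float32, 784, 64)

# Encode the noisy input; score the reconstruction against the clean target
L = AutoEncoderToolkit.VAEs.loss(vae, x_noisy, x_clean)
```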
The `loss` function includes the `β` optional argument that can turn a vanilla VAE into a β-VAE by changing the default value of `β` from `1.0` to any other value.
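For instance, a sketch with `vae` and `x_batch` as in the examples above:

```julia
# β > 1 up-weights the KL term, trading reconstruction quality for a more
# disentangled latent space (Higgins et al., 2017)
L_βvae = AutoEncoderToolkit.VAEs.loss(vae, x_batch; β=4.0f0)
```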
Training
`AutoEncoderToolkit.VAEs.train!` — Function

```julia
train!(vae, x, opt; loss_function, loss_kwargs, verbose, loss_return)
```
Customized training function to update parameters of a variational autoencoder given a specified loss function.
Arguments
- `vae::VAE`: A struct containing the elements of a variational autoencoder.
- `x::AbstractArray`: Data on which to evaluate the loss function. The last dimension is taken as having each of the samples in a batch.
- `opt::NamedTuple`: State of the optimizer for updating parameters. Typically initialized using `Flux.Train.setup`.
Optional Keyword Arguments
- `loss_function::Function=loss`: The loss function used for training. It should accept the VAE model, data `x`, and keyword arguments in that order.
- `loss_kwargs::NamedTuple=NamedTuple()`: Arguments for the loss function. These might include parameters like `σ` or `β`, depending on the specific loss function in use.
- `verbose::Bool=false`: If `true`, the loss value will be printed during training.
- `loss_return::Bool=false`: If `true`, the loss value will be returned after training.
Description
Trains the VAE by:
- Computing the gradient of the loss w.r.t. the VAE parameters.
- Updating the VAE parameters using the optimizer.
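A sketch of a training loop using this method, assuming `vae` and `x_batch` as in the examples above; passing `β` through `loss_kwargs` trains a β-VAE:

```julia
using Flux

opt_state = Flux.Train.setup(Flux.Adam(1f-3), vae)

for epoch in 1:10
    # β = 4 trains a β-VAE; drop `loss_kwargs` for a vanilla VAE
    AutoEncoderToolkit.VAEs.train!(
        vae, x_batch, opt_state;
        loss_kwargs=(β=4.0f0,), verbose=true
    )
end
```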
```julia
train!(
    vae, x_in, x_out, opt;
    loss_function, loss_kwargs, verbose, loss_return
)
```
Customized training function to update parameters of a variational autoencoder given a loss function.
Arguments
- `vae::VAE`: A struct containing the elements of a variational autoencoder.
- `x_in::AbstractArray`: Input data for the loss function. The last dimension is taken as having each of the samples in a batch.
- `x_out::AbstractArray`: Target output data for the loss function, corresponding to the `x_in` samples. The last dimension is taken as having each of the samples in a batch.
- `opt::NamedTuple`: State of the optimizer for updating parameters. Typically initialized using `Flux.Train.setup`.
Optional Keyword Arguments
- `loss_function::Function=loss`: The loss function used for training. It should accept the VAE model, data `x_in`, `x_out`, and keyword arguments in that order.
- `loss_kwargs::NamedTuple=NamedTuple()`: Arguments for the loss function. These might include parameters like `σ` or `β`, depending on the specific loss function in use.
- `verbose::Bool=false`: Whether to print the loss value after each training step.
- `loss_return::Bool=false`: Whether to return the loss value after each training step.
Description
Trains the VAE by:
- Computing the gradient of the loss w.r.t. the VAE parameters.
- Updating the VAE parameters using the optimizer.
Examples
```julia
opt = Flux.Train.setup(Flux.Adam(1e-3), vae)

for (x_in, x_out) in dataloader
    train!(vae, x_in, x_out, opt)
end
```