InfoMax VAE
The InfoMax VAE is a variant of the Variational Autoencoder (VAE) that explicitly maximizes the mutual information between the latent space representation and the input data. The main difference between the InfoMax VAE and the MMD-VAE (InfoVAE) is that rather than using the Maximum-Mean Discrepancy (MMD) as a measure of the "distance" between the latent space distribution and the prior, the InfoMax VAE explicitly models the mutual information between latent representations and data inputs via a separate neural network. The loss function for this separate network takes the form of a variational lower bound on the mutual information between the latent space and the input data.
Because of the need for this separate network, the InfoMaxVAE struct in AutoEncoderToolkit.jl takes two arguments to construct: the original VAE struct and a network to compute the mutual information. To properly deploy all relevant functions associated with this second network, we also provide a MutualInfoChain struct.

Furthermore, because of the two networks and the way the training algorithm is set up, training an InfoMax VAE involves two separate loss functions: one for the MutualInfoChain and one for the InfoMaxVAE.
References
Rezaabad, A. L. & Vishwanath, S. Learning Representations by Maximizing Mutual Information in Variational Autoencoders. Preprint at http://arxiv.org/abs/1912.13361 (2020).
MutualInfoChain struct

AutoEncoderToolkit.InfoMaxVAEs.MutualInfoChain — Type

MutualInfoChain

A MutualInfoChain is used to compute the variational mutual information when training an InfoMaxVAE. The chain is composed of a series of layers that must end with a single output: the mutual information between the latent variables and the input data.
Arguments
- data::Union{Flux.Dense,Flux.Chain}: The data layer of the MutualInfoChain. This layer is used to input the data.
- latent::Union{Flux.Dense,Flux.Chain}: The latent layer of the MutualInfoChain. This layer is used to input the latent variables.
- mlp::Flux.Chain: A multi-layer perceptron (MLP) that is used to compute the mutual information between the inputs and the latent representations. The MLP takes as input the latent variables and outputs a scalar representing the estimated variational mutual information.
Citation
Rezaabad, A. L. & Vishwanath, S. Learning Representations by Maximizing Mutual Information in Variational Autoencoders. in 2020 IEEE International Symposium on Information Theory (ISIT) 2729–2734 (IEEE, 2020). doi:10.1109/ISIT44484.2020.9174424.
Note
If the input data is not a flat array, make sure to include a flattening layer within data.
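Example

As a minimal sketch of how the three fields fit together, the following builds a MutualInfoChain by hand from plain Flux layers. The layer sizes (784-dimensional flattened inputs, a 2-dimensional latent space, 32-unit embeddings) are illustrative assumptions, as is the premise that the chain concatenates the data and latent embeddings before the MLP; adjust the MLP input width if the internal wiring differs in your version of the package.

using Flux
using AutoEncoderToolkit.InfoMaxVAEs: MutualInfoChain

# Illustrative sizes: 784-dim flattened inputs, 2-dim latent space
data_layer = Flux.Dense(784 => 32, Flux.relu)   # embeds the data x
latent_layer = Flux.Dense(2 => 32, Flux.relu)   # embeds the latent code z
# MLP on the (assumed concatenated) embeddings, ending in a single output
mlp = Flux.Chain(Flux.Dense(64 => 32, Flux.relu), Flux.Dense(32 => 1))

mi = MutualInfoChain(data_layer, latent_layer, mlp)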
InfoMaxVAE struct

AutoEncoderToolkit.InfoMaxVAEs.InfoMaxVAE — Type

`InfoMaxVAE <: AbstractVariationalAutoEncoder`

Struct encapsulating an InfoMax variational autoencoder (InfoMaxVAE), an architecture designed to enhance the VAE framework by maximizing mutual information between the inputs and the latent representations, as per the methods described by Rezaabad and Vishwanath (2020).

The model aims to learn representations that preserve mutual information with the input data, arguably capturing more meaningful factors of variation.
Fields
- vae::VAE: The core variational autoencoder, consisting of an encoder that maps input data into a latent space representation, and a decoder that attempts to reconstruct the input from the latent representation.
- mi::MutualInfoChain: A multi-layer perceptron (MLP) that estimates the mutual information between the input data and the latent representations.
Usage
The InfoMaxVAE struct is utilized in a similar manner to a standard VAE, with the added capability of mutual information maximization as part of the training process. This involves an additional loss term that considers the output of the mi network to encourage latent representations that are informative about the input data.
Example
# Assuming definitions for `encoder`, `decoder`, and `mi` are provided:
info_max_vae = InfoMaxVAE(VAE(encoder, decoder), mi)
# During training, one would maximize both the variational lower bound and the
# mutual information estimate provided by `mi`.
Citation
Rezaabad, A. L. & Vishwanath, S. Learning Representations by Maximizing Mutual Information in Variational Autoencoders. in 2020 IEEE International Symposium on Information Theory (ISIT) 2729–2734 (IEEE, 2020). doi:10.1109/ISIT44484.2020.9174424.
Forward pass
Mutual Information Network
AutoEncoderToolkit.InfoMaxVAEs.MutualInfoChain — Method

(mi::MutualInfoChain)(x::AbstractArray, z::AbstractVecOrMat)

Forward pass function for the MutualInfoChain, which applies the chain to an input x and its corresponding latent representation z.
Arguments
- x::AbstractArray: The input array to be processed. The last dimension represents each data sample.
- z::AbstractVecOrMat: The latent representation of the input data. The last dimension represents each data sample.
Returns
- The result of applying the MutualInfoChain to the input data and the latent representation simultaneously.
Description
This function applies the MutualInfoChain to the input data x and the latent representation z, producing the scalar variational mutual information estimate for each sample.
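Example

A quick usage sketch, assuming the mi chain built above (784-dimensional inputs, 2-dimensional latents); the batch size of 16 is arbitrary:

x = randn(Float32, 784, 16)  # 16 data samples, one per column
z = randn(Float32, 2, 16)    # matching latent codes
g = mi(x, z)                 # variational mutual information estimates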
InfoMax VAE
AutoEncoderToolkit.InfoMaxVAEs.InfoMaxVAE — Method

(vae::InfoMaxVAE)(x::AbstractArray; latent::Bool=false)

Processes the input data x through an InfoMaxVAE, which consists of an encoder, a decoder, and a multi-layer perceptron (MLP) to estimate variational mutual information.
Arguments
- x::AbstractArray: The input data. If an array, the last dimension contains each data sample.
Optional Keyword Arguments
- latent::Bool: If true, returns a NamedTuple with latent variables and mutual information estimations along with the reconstruction. Defaults to false.
- seed::Union{Nothing,Int}: Optional argument. The seed for the random number generator used for shuffling the latent codes. If not provided, a random seed will be used.
Returns
- If latent=false: The decoder output as a NamedTuple.
- If latent=true: A NamedTuple with the :vae field that contains the outputs of the VAE, and the :mi field that contains the estimate of the variational mutual information. Note that this estimate requires shuffling the latent codes between data samples. Therefore, it is only meaningful for batch data cases.
Description
This function first encodes the input x to obtain the parameters of the approximate posterior distribution over the latent space. It then samples from this distribution using the reparametrization trick. The sampled latent vectors are then decoded, and the MutualInfoChain is used to estimate the mutual information.
Note
Ensure the input data x matches the expected input dimensionality for the encoder in the InfoMaxVAE.
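Example

A forward-pass sketch following the constructor example above; as there, encoder and decoder are assumed to be defined, and the input and batch sizes are illustrative:

# Assuming definitions for `encoder`, `decoder`, and `mi` are provided:
infomaxvae = InfoMaxVAE(VAE(encoder, decoder), mi)
x_batch = randn(Float32, 784, 64)   # batch of 64 flattened samples

# Plain forward pass: returns the decoder output
recon = infomaxvae(x_batch)

# With latent=true: NamedTuple with the VAE outputs and the MI estimate
outputs = infomaxvae(x_batch; latent=true)
outputs.vae  # outputs of the underlying VAE
outputs.mi   # variational mutual information estimate for the batch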
Loss functions
Mutual Information Network
AutoEncoderToolkit.InfoMaxVAEs.miloss — Function

miloss(
vae::VAE,
mi::MutualInfoChain,
x::AbstractArray;
regularization::Union{Function,Nothing}=nothing,
reg_strength::Float32=1.0f0,
seed::Union{Nothing,Int}=nothing
)
Calculates the loss for training the MutualInfoChain in the InfoMaxVAE algorithm to estimate mutual information between the input x and the latent representation z. The loss function is based on a variational approximation of mutual information, using the MutualInfoChain's output g(x, z). The variational mutual information is then calculated as the difference between the MutualInfoChain's output for the true x and latent z, and the exponentiated average of the MLP's output for x and the shuffled latent z_shuffle, adjusted for the regularization term if provided.
Arguments
- vae::VAE: The variational autoencoder.
- mi::MutualInfoChain: The MutualInfoChain used for estimating mutual information.
- x::AbstractArray: The input vector for the VAE.
Optional Keyword Arguments
- regularization::Union{Function,Nothing}=nothing: A regularization function applied to the MLP's output.
- reg_strength::Float32=1.0f0: The strength of the regularization term.
- seed::Union{Nothing,Int}=nothing: The seed for the random number generator used for shuffling the latent codes. If not provided, a random seed will be used.
Returns
- Float32: The computed loss, representing negative variational mutual information, adjusted by the regularization term.
Description
The function computes the loss as follows:

loss = -sum(I(x; z)) + sum(exp(I(x; z̃) - 1)) + reg_strength * reg_term

where I(x; z) is the MLP's output representing an estimation of mutual information for true x and latent z, and z̃ represents shuffled latent variables, meaning the latent codes are randomly swapped between data points.

The function is used to separately train the MLP to estimate mutual information, which is a component of the larger InfoMaxVAE model.
Notes
- This function takes the vae and mi instances of an InfoMaxVAE model as separate arguments to be able to compute a gradient only with respect to the mi parameters.
- Ensure that the dimensionality of the input data x aligns with the encoder's expected input in the VAE.
- InfoMaxVAEs fully depend on batch training as the estimation of mutual information depends on shuffling the latent codes. This method works for large enough batches (≥ 64 samples).
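Example

A sketch of this training split, assuming the infomaxvae and x_batch from the forward-pass example above: gradients of miloss flow only into the MutualInfoChain, with the VAE passed as a constant.

∇mi = Flux.gradient(infomaxvae.mi) do m
    miloss(infomaxvae.vae, m, x_batch)
end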
InfoMax VAE
AutoEncoderToolkit.InfoMaxVAEs.infomaxloss — Function

infomaxloss(
vae::VAE,
mi::MutualInfoChain,
x::AbstractArray;
β=1.0f0,
α=1.0f0,
n_samples::Int=1,
reconstruction_loglikelihood::Function=decoder_loglikelihood,
kl_divergence::Function=encoder_kl,
regularization::Union{Function,Nothing}=nothing,
reg_strength::Float32=1.0f0,
seed::Union{Nothing,Int}=nothing
)
Computes the loss for an InfoMax variational autoencoder (VAE) with mutual information constraints, by averaging over n_samples latent space samples.
The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence, the variational mutual information between input and latent representations, and possibly a regularization term, defined as:

loss = -⟨log p(x|z)⟩ + β × Dₖₗ[qᵩ(z|x) || p(z)] - α × I(x;z) + reg_strength × reg_term

where:
- ⟨log p(x|z)⟩ is the expected log likelihood of the probabilistic decoder.
- Dₖₗ[qᵩ(z|x) || p(z)] is the KL divergence between the approximated encoder and the prior over the latent space.
- I(x;z) is the variational mutual information between the inputs x and the latent variables z.
Arguments
- vae::VAE: A VAE model with encoder and decoder networks.
- mi::MutualInfoChain: A MutualInfoChain instance used to estimate the mutual information term.
- x::AbstractArray: Input data. The last dimension represents each data sample.
Optional Keyword Arguments
- β::Float32=1.0f0: Weighting factor for the KL-divergence term, used for annealing.
- α::Float32=1.0f0: Weighting factor for the mutual information term.
- n_samples::Int=1: The number of samples to draw from the latent space when computing the loss.
- reconstruction_loglikelihood::Function=decoder_loglikelihood: A function that computes the log likelihood of the decoder's output.
- kl_divergence::Function=encoder_kl: A function that computes the KL divergence between the encoder's output and the prior.
- regularization::Union{Function,Nothing}=nothing: A function that computes the regularization term based on the VAE outputs. Should return a Float32.
- reg_strength::Float32=1.0f0: The strength of the regularization term.
- seed::Union{Nothing,Int}: The seed for the random number generator used for shuffling the latent codes. If not provided, a random seed will be used.
Returns
- Float32: The computed average loss value for the input x and its reconstructed counterparts over n_samples samples, including possible regularization terms and the mutual information constraint.
Notes
- This function takes the vae and mi instances of an InfoMaxVAE model as separate arguments to be able to compute a gradient only with respect to the vae parameters.
- Ensure that the input data x matches the expected input dimensionality for the encoder in the VAE.
- InfoMaxVAEs fully depend on batch training as the estimation of mutual information depends on shuffling the latent codes. This method works for large enough batches (≥ 64 samples).
infomaxloss(
vae::VAE,
mi::MutualInfoChain,
x_in::AbstractArray,
x_out::AbstractArray;
β=1.0f0,
α=1.0f0,
n_samples::Int=1,
reconstruction_loglikelihood::Function=decoder_loglikelihood,
kl_divergence::Function=encoder_kl,
regularization::Union{Function,Nothing}=nothing,
reg_strength::Float32=1.0f0,
seed::Union{Nothing,Int}=nothing
)
Computes the loss for an InfoMax variational autoencoder (VAE) with mutual information constraints, by averaging over n_samples latent space samples.
The loss function combines the reconstruction loss with the Kullback-Leibler (KL) divergence, the variational mutual information between input and latent representations, and possibly a regularization term, defined as:

loss = -⟨log p(x|z)⟩ + β × Dₖₗ[qᵩ(z|x) || p(z)] - α × I(x;z) + reg_strength × reg_term

where:
- ⟨log p(x|z)⟩ is the expected log likelihood of the probabilistic decoder.
- Dₖₗ[qᵩ(z|x) || p(z)] is the KL divergence between the approximated encoder and the prior over the latent space.
- I(x;z) is the variational mutual information between the inputs x and the latent variables z.
Arguments
- vae::VAE: A VAE model with encoder and decoder networks.
- mi::MutualInfoChain: A MutualInfoChain instance used to estimate the mutual information term.
- x_in::AbstractArray: Input matrix. The last dimension represents each data sample.
- x_out::AbstractArray: Output matrix against which reconstructions are compared. The last dimension represents each data sample.
Optional Keyword Arguments
- β::Float32=1.0f0: Weighting factor for the KL-divergence term, used for annealing.
- α::Float32=1.0f0: Weighting factor for the mutual information term.
- n_samples::Int=1: The number of samples to draw from the latent space when computing the loss.
- reconstruction_loglikelihood::Function=decoder_loglikelihood: A function that computes the log likelihood of the decoder's output.
- kl_divergence::Function=encoder_kl: A function that computes the KL divergence between the encoder's output and the prior.
- regularization::Union{Function,Nothing}=nothing: A function that computes the regularization term based on the VAE outputs. Should return a Float32.
- reg_strength::Float32=1.0f0: The strength of the regularization term.
- seed::Union{Nothing,Int}: The seed for the random number generator used for shuffling the latent codes. If not provided, a random seed will be used.
Returns
- Float32: The computed average loss value for the input x_in and its reconstructions compared against x_out over n_samples samples, including possible regularization terms and the mutual information constraint.
Notes
- This function takes the vae and mi instances of an InfoMaxVAE model as separate arguments to be able to compute a gradient only with respect to the vae parameters.
- Ensure that the input data x_in matches the expected input dimensionality for the encoder in the VAE.
- InfoMaxVAEs fully depend on batch training as the estimation of mutual information depends on shuffling the latent codes. This method works for large enough batches (≥ 64 samples).
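Example

The complementary half of the training split, under the same assumptions as the miloss example above: gradients of infomaxloss flow only into the VAE, with the MutualInfoChain held fixed.

∇vae = Flux.gradient(infomaxvae.vae) do v
    infomaxloss(v, infomaxvae.mi, x_batch; β=1.0f0, α=1.0f0)
end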
Training
AutoEncoderToolkit.InfoMaxVAEs.train! — Function

train!(
infomaxvae, x, opt;
infomaxloss_function=infomaxloss,
infomaxloss_kwargs,
miloss_function=miloss,
miloss_kwargs,
loss_return::Bool=false,
verbose::Bool=false
)
Customized training function to update parameters of an InfoMax variational autoencoder (VAE) given a loss function of the specified form.
The InfoMax VAE loss function can be defined as:

loss_infoMax = -⟨log p(x|z)⟩ + β Dₖₗ(qᵩ(z|x) || p(z)) - α [⟨g(x, z)⟩ - ⟨exp(g(x, z) - 1)⟩]

where ⟨log p(x|z)⟩ is the expected log likelihood of the probabilistic decoder, Dₖₗ[qᵩ(z|x) || p(z)] is the KL divergence between the approximated encoder distribution and the prior over the latent space, and g(x, z) is the output of the MutualInfoChain estimating the mutual information between the input data and the latent representation.

This function simultaneously optimizes two neural networks: the VAE itself and the MutualInfoChain used to compute the mutual information between input and latent variables.
Arguments
- infomaxvae::InfoMaxVAE: Struct containing the elements of an InfoMax VAE.
- x::AbstractArray: Matrix containing the data on which to evaluate the loss function. Each column represents a single data point.
- opt::NamedTuple: State of the optimizer for updating parameters. Typically initialized using Flux.Train.setup.
Optional Keyword Arguments
- infomaxloss_function::Function: The loss function to be used during training for the VAE, defaulting to infomaxloss.
- infomaxloss_kwargs::NamedTuple: Additional keyword arguments to be passed to the VAE loss function.
- miloss_function::Function: The loss function to be used during training for the MutualInfoChain estimating the variational mutual information, defaulting to miloss.
- miloss_kwargs::NamedTuple: Additional keyword arguments to be passed to the MutualInfoChain loss function.
- loss_return::Bool: If true, the function returns the loss values for the VAE and MutualInfoChain. Defaults to false.
- verbose::Bool: If true, the function prints the loss values for the VAE and MutualInfoChain. Defaults to false.
Description
Performs one step of gradient descent on the InfoMaxVAE loss function to jointly train the VAE and MutualInfoChain. The VAE parameters are updated to minimize the InfoMaxVAE loss, while the MutualInfoChain parameters are updated to maximize the estimated mutual information. The function allows for customization of loss hyperparameters during training.
Notes
- Ensure that the dimensionality of the input data x aligns with the encoder's expected input in the VAE.
- InfoMaxVAEs fully depend on batch training as the estimation of mutual information depends on shuffling the latent codes. This method works best for large enough batches (≥ 64 samples).
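Example

A minimal end-to-end training sketch, assuming the infomaxvae from the earlier examples and data arranged with one sample per column. The optimizer state is created with Flux.setup; the dataset, batch size, learning rate, and number of epochs are all illustrative.

using Flux
using AutoEncoderToolkit.InfoMaxVAEs: train!

data = randn(Float32, 784, 1024)  # illustrative dataset

# Optimizer state for the full InfoMaxVAE (VAE + MutualInfoChain)
opt_state = Flux.setup(Flux.Adam(1.0f-3), infomaxvae)

for epoch in 1:10
    # MI estimation needs reasonably large batches (≥ 64 samples)
    for x_batch in Flux.DataLoader(data; batchsize=64, shuffle=true)
        # One joint step: updates the VAE via infomaxloss and the
        # MutualInfoChain via miloss
        train!(infomaxvae, x_batch, opt_state)
    end
end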
Other Functions
AutoEncoderToolkit.InfoMaxVAEs.shuffle_latent — Function

shuffle_latent(z::AbstractMatrix, seed::Int=Random.seed!())
Shuffle the elements of the second dimension of a matrix representing latent space points.
Arguments
- z::AbstractMatrix: A matrix representing latent codes. Each column corresponds to a single latent code.

Optional Keyword Arguments
- seed::Union{Nothing,Int}: Optional argument. The seed for the random number generator. If not provided, a random seed will be used.
Returns
- AbstractMatrix: A new matrix with the second dimension shuffled.
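Example

A quick illustration of the shuffling (values and sizes are arbitrary):

using AutoEncoderToolkit.InfoMaxVAEs: shuffle_latent

z = Float32[1 2 3 4; 10 20 30 40]  # 2-dim latent codes for 4 samples
z_shuffled = shuffle_latent(z)     # same codes, columns in random order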
AutoEncoderToolkit.InfoMaxVAEs.variational_mutual_info — Function

variational_mutual_info(mi, x, z, z_shuffle)

Compute a variational approximation of the mutual information between the input x and the latent code z using a MutualInfoChain. Note that this estimate requires shuffling the latent codes between data samples. Therefore, it only applies to batch data cases. A single sample will not provide a meaningful estimate.
Arguments
- mi::MutualInfoChain: A MutualInfoChain instance used to estimate mutual information.
- x::AbstractArray: Array of input data. The last dimension represents each data sample.
- z::AbstractMatrix: Matrix of corresponding latent representations of the input data.
- z_shuffle::AbstractMatrix: Matrix of latent representations where the second dimension has been shuffled.
Returns
- Float32: An approximation of the mutual information between the input data and its corresponding latent representation.
References
Rezaabad, A. L. & Vishwanath, S. Learning Representations by Maximizing Mutual Information in Variational Autoencoders. Preprint at http://arxiv.org/abs/1912.13361 (2020).
variational_mutual_info(infomaxvae, x, z, z_shuffle)
Compute a variational approximation of the mutual information between the input x and the latent code z using an InfoMaxVAE instance. Note that this estimate requires shuffling the latent codes between data samples. Therefore, it only applies to batch data cases. A single sample will not provide a meaningful estimate.
Arguments
- infomaxvae::InfoMaxVAE: An InfoMaxVAE instance used to estimate mutual information.
- x::AbstractArray: Array of input data. The last dimension represents each data sample.
- z::AbstractMatrix: Matrix of corresponding latent representations of the input data.
- z_shuffle::AbstractMatrix: Matrix of latent representations where the second dimension has been shuffled.
Returns
- Float32: An approximation of the mutual information between the input data and its corresponding latent representation.
References
Rezaabad, A. L. & Vishwanath, S. Learning Representations by Maximizing Mutual Information in Variational Autoencoders. Preprint at http://arxiv.org/abs/1912.13361 (2020).
variational_mutual_info(
infomaxvae::InfoMaxVAE,
x::AbstractArray;
seed::Union{Nothing,Int}=nothing
)
Compute a variational approximation of the mutual information between the input x and the latent code z using an InfoMaxVAE instance. This function also shuffles the latent codes between data samples to provide a meaningful estimate even for a single data sample.
Arguments
- infomaxvae::InfoMaxVAE: An InfoMaxVAE instance used to estimate mutual information.
- x::AbstractArray: Array of input data. The last dimension represents each data sample.
Optional Keyword Arguments
- seed::Union{Nothing,Int}: Optional argument. The seed for the random number generator used for shuffling the latent codes. If not provided, a random seed will be used.
Returns
- Float32: An approximation of the mutual information between the input data and its corresponding latent representation.
References
Rezaabad, A. L. & Vishwanath, S. Learning Representations by Maximizing Mutual Information in Variational Autoencoders. Preprint at http://arxiv.org/abs/1912.13361 (2020).
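Example

A usage sketch of the convenience method that handles the shuffling internally, assuming the infomaxvae and input size from the earlier examples:

using AutoEncoderToolkit.InfoMaxVAEs: variational_mutual_info

x_batch = randn(Float32, 784, 64)  # batch of flattened samples

# Estimate of I(x; z) for the batch; latent codes are shuffled internally
mi_estimate = variational_mutual_info(infomaxvae, x_batch)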
Default initializations
AutoEncoderToolkit.jl provides default initializations for the MutualInfoChain. Although it gives the user less flexibility, it can be useful for quick prototyping.
AutoEncoderToolkit.InfoMaxVAEs.MutualInfoChain — Method

MutualInfoChain(
size_input::Union{Int,Vector{<:Int}},
n_latent::Int,
mlp_neurons::Vector{<:Int},
mlp_activations::Vector{<:Function},
output_activation::Function;
init::Function = Flux.glorot_uniform
)
Constructs a default MutualInfoChain.
Arguments
- size_input::Union{Int,Vector{<:Int}}: Number (or shape) of input features to the MutualInfoChain.
- n_latent::Int: The dimensionality of the latent space.
- mlp_neurons::Vector{<:Int}: A vector of integers where each element represents the number of neurons in the corresponding hidden layer of the MLP.
- mlp_activations::Vector{<:Function}: A vector of activation functions to be used in the hidden layers. Length must match that of mlp_neurons.
- output_activation::Function: Activation function for the output neuron of the MLP.
Optional Keyword Arguments
- init::Function: Initialization function for the weights of all layers in the MutualInfoChain. Defaults to Flux.glorot_uniform.
Returns
- MutualInfoChain: A MutualInfoChain instance with the specified MLP architecture.
Notes
The function will throw an error if the number of provided activation functions does not match the number of layers specified in mlp_neurons.
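Example

A usage sketch of the default constructor; the sizes and activation functions are illustrative:

# MutualInfoChain for 784-dimensional inputs and a 2-dimensional latent
# space, with two hidden layers of 128 and 64 neurons
mi = MutualInfoChain(
    784,                     # size_input
    2,                       # n_latent
    [128, 64],               # mlp_neurons
    [Flux.relu, Flux.relu],  # mlp_activations (one per hidden layer)
    identity,                # output_activation
)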