Examples
This section contains example scripts that run inference with the models included in the model module for different experimental designs. They are meant as a guide to help new users get their analysis running quickly.
We invite users to check the full documentation to adapt the inference to their specific needs, and to open an issue in the GitHub repository to report bugs or ask questions.
General package imports
All of the examples below use the same libraries, so we suggest adding the following imports at the beginning of every inference script.
# Import project package
import BayesFitUtils
# Import library package
import BarBay
# Import libraries to manipulate data
import DataFrames as DF
import CSV
# Import library to perform Bayesian inference
import Turing
import AdvancedVI
# Import AutoDiff backend
using ReverseDiff
# Import statistical libraries
import Random
import StatsBase
import Distributions
Random.seed!(42)
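The seed set at the end of the imports fixes the state of Julia's global random number generator. As a minimal sketch of what that buys, using only the Random standard library (no BarBay or Turing calls are involved), reseeding with the same value reproduces the same draws:

```julia
import Random

# Seed the global RNG and draw three numbers
Random.seed!(42)
first_draw = rand(3)

# Reseed with the same value and draw again
Random.seed!(42)
second_draw = rand(3)

# Identical seeds give identical draws
println(first_draw == second_draw)  # true
```

Whether a full inference run is bit-for-bit reproducible also depends on factors such as multithreading, so treat the seed as a necessary rather than sufficient step.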
Selecting the AutoDiff backend
Recall that we can use either ForwardDiff or ReverseDiff as the AutoDiff backend. The choice usually depends on the number of parameters to be inferred: ReverseDiff is usually faster for a large number of parameters, while ForwardDiff is faster for a small number of parameters. To choose the backend, we use the :advi keyword argument of the BarBay.vi.advi function. For ForwardDiff, we use
:advi => Turing.ADVI(n_samples, n_steps)
as ForwardDiff is the default backend. For ReverseDiff, we use
:advi => Turing.ADVI{AdvancedVI.ReverseDiffAD{false}}(n_samples, n_steps)
where the false indicates that we are not using the cache for the random number tape. See the AdvancedVI.jl repository for more information.
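The rule of thumb above can be written down as a small helper. This is only an illustrative sketch: the function name suggest_backend and the threshold of 100 parameters are assumptions made here, not part of BarBay, Turing, or AdvancedVI, and the right cutoff for a given model is best found by timing both backends.

```julia
# Hypothetical helper (not part of BarBay): suggest an AutoDiff backend
# from the number of parameters. The default threshold is an arbitrary
# placeholder and should be tuned by benchmarking.
function suggest_backend(n_params::Integer; threshold::Integer=100)
    n_params < 0 && throw(ArgumentError("n_params must be non-negative"))
    # Few parameters: forward mode; many parameters: reverse mode
    return n_params <= threshold ? :ForwardDiff : :ReverseDiff
end

println(suggest_backend(20))      # ForwardDiff
println(suggest_backend(50_000))  # ReverseDiff
```

The returned symbol is just a label; the actual selection still happens through the :advi entry described above.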
Single dataset single environment variational inference
This example covers the case of a single dataset produced by a series of growth-dilution cycles in a single environment.
The dataset should look something like
| time | barcode | count | neutral | freq |
|------|------------|-------|---------|-------------|
| 3 | neutral025 | 12478 | TRUE | 0.000543716 |
| 4 | neutral025 | 10252 | TRUE | 0.00034368 |
| 5 | neutral025 | 2883 | TRUE | 6.74E-05 |
| 1 | mut001 | 1044 | FALSE | 7.97E-05 |
| 2 | mut001 | 2010 | FALSE | 0.000121885 |
| 3 | mut001 | 766 | FALSE | 3.34E-05 |
| 4 | mut001 | 216 | FALSE | 7.24E-06 |
| 5 | mut001 | 120 | FALSE | 2.81E-06 |
| 1 | mut002 | 51484 | FALSE | 0.003930243 |
The script to analyze the data then looks like
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI hyperparameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define number of samples and steps
n_samples = 1
n_steps = 3_000
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directories
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directory
if !isdir("./output/")
    mkdir("./output/")
end # if
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Loading the data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
println("Loading data...")
# Import data
data = CSV.read(
    "path/to/data/tidy_data.csv", DF.DataFrame
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Obtain priors on expected errors from neutral measurements
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Compute naive priors from neutral strains
naive_priors = BarBay.stats.naive_prior(data)
# Select standard deviation parameters
s_pop_prior = hcat(
    naive_priors[:s_pop_prior],
    repeat([0.05], length(naive_priors[:s_pop_prior]))
)
logσ_pop_prior = hcat(
    naive_priors[:logσ_pop_prior],
    repeat([1.0], length(naive_priors[:logσ_pop_prior]))
)
logσ_bc_prior = [StatsBase.mean(naive_priors[:logσ_pop_prior]), 1.0]
logλ_prior = hcat(
    naive_priors[:logλ_prior],
    repeat([3.0], length(naive_priors[:logλ_prior]))
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI function parameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
param = Dict(
    :data => data,
    :outputname => "./output/advi_meanfield_" *
                   "$(lpad(n_samples, 2, "0"))samples_$(n_steps)steps",
    :model => BarBay.model.fitness_normal,
    :model_kwargs => Dict(
        :s_pop_prior => s_pop_prior,
        :logσ_pop_prior => logσ_pop_prior,
        :logσ_bc_prior => logσ_bc_prior,
        :s_bc_prior => [0.0, 1.0],
        :logλ_prior => logλ_prior,
    ),
    :advi => Turing.ADVI(n_samples, n_steps),
    :opt => Turing.TruncatedADAGrad()
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Perform optimization
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Run inference
println("Running Variational Inference...")
@time BarBay.vi.advi(; param...)
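Two steps of the script above are plain Julia and can be checked in isolation. The sketch below uses made-up numbers in place of naive_priors[:s_pop_prior] (an assumption for illustration only) to show the shape produced by the hcat pattern, and then evaluates the output name for the hyperparameters used above. Reading the first column as prior means and the second as standard deviations follows the comments in the script, but is an interpretation rather than something this sketch verifies.

```julia
# Placeholder values standing in for naive_priors[:s_pop_prior]
prior_means = [0.10, 0.12, 0.09, 0.11]

# Same pattern as in the script: one row per entry, with the estimate in
# column 1 and a fixed value (0.05) repeated in column 2
s_pop_prior = hcat(prior_means, repeat([0.05], length(prior_means)))
println(size(s_pop_prior))  # (4, 2)

# Output name built exactly as in the script
n_samples = 1
n_steps = 3_000
outputname = "./output/advi_meanfield_" *
             "$(lpad(n_samples, 2, "0"))samples_$(n_steps)steps"
println(outputname)  # ./output/advi_meanfield_01samples_3000steps
```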
Single dataset single environment MCMC sampling
If the number of barcodes is relatively small, one can try a more exact sampling of the posterior with MCMC. BarBay is structured such that the changes between fitting a model with ADVI and with MCMC are minimal. Here is an example script that fits the same dataset as before using Dynamic HMC:
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define MCMC hyperparameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
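# Import the sampler package (needed for DynamicHMC.NUTS below; it is not
# part of the general package imports)
import DynamicHMC
# Define number of steps and walkers (chains)
n_steps = 3_000
n_walkers = 4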
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directories
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directory
if !isdir("./output/")
    mkdir("./output/")
end # if
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Loading the data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
println("Loading data...")
# Import data
data = CSV.read(
    "path/to/data/tidy_data.csv", DF.DataFrame
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Obtain priors on expected errors from neutral measurements
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Compute naive priors from neutral strains
naive_priors = BarBay.stats.naive_prior(data)
# Select standard deviation parameters
s_pop_prior = hcat(
    naive_priors[:s_pop_prior],
    repeat([0.05], length(naive_priors[:s_pop_prior]))
)
logσ_pop_prior = hcat(
    naive_priors[:logσ_pop_prior],
    repeat([1.0], length(naive_priors[:logσ_pop_prior]))
)
logσ_bc_prior = [StatsBase.mean(naive_priors[:logσ_pop_prior]), 1.0]
logλ_prior = hcat(
    naive_priors[:logλ_prior],
    repeat([3.0], length(naive_priors[:logλ_prior]))
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Initialize MCMC sampling
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
println("Initializing MCMC sampling...\n")
# Define function parameters
param = Dict(
    :data => data,
    :n_walkers => n_walkers,
    :n_steps => n_steps,
    :outputname => "./output/chain_joint_fitness_$(n_steps)steps_$(lpad(n_walkers, 2, "0"))walkers",
    :model => BarBay.model.fitness_normal,
    :model_kwargs => Dict(
        :s_pop_prior => s_pop_prior,
        :logσ_pop_prior => logσ_pop_prior,
        :logσ_bc_prior => logσ_bc_prior,
        :s_bc_prior => [0.0, 1.0],
        :logλ_prior => logλ_prior,
    ),
    :sampler => Turing.externalsampler(DynamicHMC.NUTS()),
    :ensemble => Turing.MCMCThreads(),
    :rm_T0 => false,
)
# Run inference
println("Running Inference...")
@time BarBay.mcmc.mcmc_sample(; param...)
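The script above requests a threaded ensemble, which can only run chains in parallel if the Julia session was started with enough threads. The following sketch checks this before sampling. The helper enough_threads is hypothetical (it is not part of BarBay or Turing), and it assumes one thread per walker is the desired setup.

```julia
# Hypothetical helper (not part of BarBay): check whether enough threads
# are available to run every chain in parallel.
function enough_threads(n_walkers::Integer; available::Integer=Threads.nthreads())
    return available >= n_walkers
end

n_walkers = 4
if !enough_threads(n_walkers)
    # Sampling can still proceed, but chains will not all run at once
    @warn "Only $(Threads.nthreads()) thread(s) available for $(n_walkers) chains; " *
          "consider starting Julia with `julia --threads $(n_walkers)`."
end
```

The number of threads is fixed when Julia starts, so this check is most useful at the top of a script, before the data are loaded.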
Multi-environment single dataset variational inference
When the growth-dilution cycles of an experiment were performed in different environments, the data should include a column indicating the environment label. The dataset then looks something like
| time | env | barcode | count | neutral | freq |
|------|-----|------------|-------|---------|-------------|
| 1 | 1 | neutral100 | 7327 | TRUE | 0.000399781 |
| 2 | 1 | neutral100 | 4034 | TRUE | 0.000228517 |
| 3 | 2 | neutral100 | 5135 | TRUE | 0.000257352 |
| 4 | 3 | neutral100 | 2011 | TRUE | 6.80E-05 |
| 5 | 1 | neutral100 | 1225 | TRUE | 3.39E-05 |
| 6 | 2 | neutral100 | 693 | TRUE | 1.93E-05 |
| 7 | 3 | neutral100 | 152 | TRUE | 4.08E-06 |
| 1 | 1 | mut001 | 268 | FALSE | 1.46E-05 |
| 2 | 1 | mut001 | 187 | FALSE | 1.06E-05 |
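To make the role of the env column concrete, the sketch below recovers the sequence of environments over time from a few rows copied from the table. It uses plain named tuples instead of a DataFrame to stay self-contained, and it assumes that all barcodes share the same environment at a given time point, which is consistent with the rows shown but not stated explicitly.

```julia
# (time, env) pairs copied from the neutral100 rows of the table above
rows = [
    (time=3, env=2), (time=1, env=1), (time=2, env=1), (time=4, env=3),
    (time=5, env=1), (time=6, env=2), (time=7, env=3),
]

# Environment label for each time point, ordered by time
sorted_rows = sort(rows; by=r -> r.time)
env_sequence = [r.env for r in sorted_rows]

# Distinct environments in order of first appearance
environments = unique(env_sequence)

println(env_sequence)  # [1, 1, 2, 3, 1, 2, 3]
println(environments)  # [1, 2, 3]
```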
A basic script to analyze this dataset then takes the form
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI hyperparameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define number of samples and steps
n_samples = 1
n_steps = 10_000
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directories
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directory
if !isdir("./output/")
    mkdir("./output/")
end # if
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Loading the data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
println("Loading data...")
# Import data
data = CSV.read("path/to/data/tidy_data.csv", DF.DataFrame)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Obtain priors on expected errors from neutral measurements
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Compute naive priors from neutral strains
naive_priors = BarBay.stats.naive_prior(data; pseudocount=1)
# Select standard deviation parameters
s_pop_prior = hcat(
    naive_priors[:s_pop_prior],
    repeat([0.05], length(naive_priors[:s_pop_prior]))
)
logσ_pop_prior = hcat(
    naive_priors[:logσ_pop_prior],
    repeat([1.0], length(naive_priors[:logσ_pop_prior]))
)
logσ_bc_prior = [StatsBase.mean(naive_priors[:logσ_pop_prior]), 1.0]
logλ_prior = hcat(
    naive_priors[:logλ_prior],
    repeat([3.0], length(naive_priors[:logλ_prior]))
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI function parameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
param = Dict(
    :data => data,
    :outputname => "./output/advi_meanfield_$(lpad(n_samples, 2, "0"))samples_$(n_steps)steps",
    :model => BarBay.model.multienv_fitness_normal,
    :model_kwargs => Dict(
        :s_pop_prior => s_pop_prior,
        :logσ_pop_prior => logσ_pop_prior,
        :logσ_bc_prior => logσ_bc_prior,
        :s_bc_prior => [0.0, 1.0],
        :logλ_prior => logλ_prior,
    ),
    :env_col => :env,
    :advi => Turing.ADVI(n_samples, n_steps),
    :opt => Turing.TruncatedADAGrad()
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Perform optimization
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Run inference
println("Running Variational Inference...")
@time BarBay.vi.advi(; param...)
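The scripts in this section guard mkdir with an isdir check. A possible alternative, sketched here with Julia's built-in mkpath, does both in one call and also creates missing parent directories. The nested directory name is an arbitrary example, and the sketch works inside a temporary directory so that it leaves nothing behind.

```julia
# Work inside a temporary directory so the sketch leaves no files behind
out_dir = joinpath(mktempdir(), "output", "advi")

# mkpath creates any missing parent directories and does not error if
# the directory already exists, so no isdir check is needed
mkpath(out_dir)
mkpath(out_dir)  # calling it a second time is harmless

println(isdir(out_dir))  # true
```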
Hierarchical model for multiple experimental replicates variational inference
If there is more than one experimental replicate, the dataset must include a column indicating the replicate ID for each observation. The dataset ends up looking like
| time | barcode | count | neutral | count_sum | freq | rep |
|------|------------|-------|---------|-----------|-------------|------|
| 1 | neutral001 | 9967 | TRUE | 19321304 | 0.000515855 | R1 |
| 2 | neutral001 | 3749 | TRUE | 18224218 | 0.000205715 | R1 |
| 3 | neutral001 | 3516 | TRUE | 23317980 | 0.000150785 | R1 |
| 4 | neutral001 | 2217 | TRUE | 31261050 | 7.09E-05 | R1 |
| 5 | neutral001 | 1027 | TRUE | 38335591 | 2.68E-05 | R1 |
| 1 | neutral002 | 8676 | TRUE | 19321304 | 0.000449038 | R1 |
| 2 | neutral002 | 6019 | TRUE | 18224218 | 0.000330275 | R1 |
| 3 | neutral002 | 2245 | TRUE | 23317980 | 9.63E-05 | R1 |
| 4 | neutral002 | 2179 | TRUE | 31261050 | 6.97E-05 | R1 |
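The freq and count_sum columns in the table above are related in a simple way: each frequency matches the barcode count divided by count_sum. This relation is inferred from the tabulated values rather than stated in the text, and reading count_sum as the total count per time point is likewise an assumption. A minimal check with three rows copied from the table:

```julia
# Rows copied from the table above
rows = [
    (time=1, barcode="neutral001", count=9967, count_sum=19321304),
    (time=2, barcode="neutral001", count=3749, count_sum=18224218),
    (time=1, barcode="neutral002", count=8676, count_sum=19321304),
]

# Frequency = barcode count divided by the total count
freqs = [r.count / r.count_sum for r in rows]

for (r, f) in zip(rows, freqs)
    println(r.barcode, " t=", r.time, " freq=", round(f; sigdigits=6))
end
```

The computed values agree with the freq column (0.000515855, 0.000205715, and 0.000449038) to the precision shown in the table.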
To analyze multiple experimental replicates jointly, we can use a hierarchical model. The basic script to implement this looks something like
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI hyperparameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define number of samples and steps
n_samples = 1
n_steps = 10_000
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directories
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directory
if !isdir("./output/")
    mkdir("./output/")
end # if
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Loading the data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
println("Loading data...")
# Import data
data = CSV.read("path/to/data/tidy_data.csv", DF.DataFrame)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Obtain priors on expected errors from neutral measurements
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Compute naive priors from neutral strains
naive_priors = BarBay.stats.naive_prior(data; rep_col=:rep, pseudocount=1)
# Select standard deviation parameters
s_pop_prior = hcat(
    naive_priors[:s_pop_prior],
    repeat([0.05], length(naive_priors[:s_pop_prior]))
)
logσ_pop_prior = hcat(
    naive_priors[:logσ_pop_prior],
    repeat([1.0], length(naive_priors[:logσ_pop_prior]))
)
logσ_bc_prior = [StatsBase.mean(naive_priors[:logσ_pop_prior]), 1.0]
logλ_prior = hcat(
    naive_priors[:logλ_prior],
    repeat([3.0], length(naive_priors[:logλ_prior]))
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI function parameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
param = Dict(
    :data => data,
    :outputname => "./output/advi_meanfield_" *
                   "$(lpad(n_samples, 2, "0"))samples_$(n_steps)steps",
    :model => BarBay.model.replicate_fitness_normal,
    :model_kwargs => Dict(
        :s_pop_prior => s_pop_prior,
        :logσ_pop_prior => logσ_pop_prior,
        :logσ_bc_prior => logσ_bc_prior,
        :s_bc_prior => [0.0, 1.0],
        :logλ_prior => logλ_prior,
        :logτ_prior => [-2.0, 0.5],
    ),
    :advi => Turing.ADVI(n_samples, n_steps),
    :opt => Turing.TruncatedADAGrad(),
    :rep_col => :rep,
    :fullrank => false
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Perform optimization
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Run inference
println("Running Variational Inference...")
@time BarBay.vi.advi(; param...)
Hierarchical model for multiple barcodes mapping to same genotype variational inference
When multiple barcodes map to the same genotype within a single experiment, the dataset must include a column indicating the genotype each barcode belongs to. The dataset ends up looking something like
| time | barcode | count | neutral | count_sum | freq | genotype |
|------|------------|-------|---------|-----------|--------------|--------------|
| 1 | neutral001 | 6649 | TRUE | 17418514 | 0.00038172 | genotype000 |
| 2 | neutral001 | 6245 | TRUE | 16007352 | 0.000390133 | genotype000 |
| 3 | neutral001 | 6323 | TRUE | 22075763 | 0.000286423 | genotype000 |
| 4 | neutral001 | 2345 | TRUE | 27743357 | 8.45E-05 | genotype000 |
| 5 | neutral001 | 1379 | TRUE | 34253492 | 4.03E-05 | genotype000 |
| 1 | neutral002 | 5160 | TRUE | 17418514 | 0.000296237 | genotype000 |
| 2 | neutral002 | 4078 | TRUE | 16007352 | 0.000254758 | genotype000 |
| 3 | neutral002 | 3386 | TRUE | 22075763 | 0.000153381 | genotype000 |
| 4 | neutral002 | 2821 | TRUE | 27743357 | 0.000101682 | genotype000 |
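The genotype column defines a many-to-one map from barcodes to genotypes. As a sketch of that grouping in plain Julia: only the two neutral barcodes and genotype000 come from the table above, while the mutant barcodes and the other genotype labels are made-up examples added for illustration.

```julia
# Barcode-to-genotype pairs. The neutral rows come from the table above;
# the mutant rows are hypothetical examples.
barcode_genotype = [
    ("neutral001", "genotype000"),
    ("neutral002", "genotype000"),
    ("mut001", "genotype001"),
    ("mut002", "genotype001"),
    ("mut003", "genotype002"),
]

# Group barcodes by the genotype they map to
barcodes_by_genotype = Dict{String,Vector{String}}()
for (barcode, genotype) in barcode_genotype
    push!(get!(barcodes_by_genotype, genotype, String[]), barcode)
end

# Print the groups in a stable order
for genotype in sort(collect(keys(barcodes_by_genotype)))
    println(genotype, " => ", barcodes_by_genotype[genotype])
end
```

Genotypes with several barcodes are the ones where a hierarchical model has information to share across barcodes.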
As with experimental replicates, we can implement a hierarchical model for this experimental design. The basic script to implement this looks like
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI hyperparameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define number of samples and steps
n_samples = 1
n_steps = 10_000
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directories
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Generate output directory
if !isdir("./output/")
    mkdir("./output/")
end # if
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Loading the data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
println("Loading data...")
# Import data
data = CSV.read("path/to/data/tidy_data.csv", DF.DataFrame)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Obtain priors on expected errors from neutral measurements
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Compute naive priors from neutral strains
naive_priors = BarBay.stats.naive_prior(data; pseudocount=1)
# Select standard deviation parameters
s_pop_prior = hcat(
    naive_priors[:s_pop_prior],
    repeat([0.05], length(naive_priors[:s_pop_prior]))
)
logσ_pop_prior = hcat(
    naive_priors[:logσ_pop_prior],
    repeat([1.0], length(naive_priors[:logσ_pop_prior]))
)
logσ_bc_prior = [StatsBase.mean(naive_priors[:logσ_pop_prior]), 1.0]
logλ_prior = hcat(
    naive_priors[:logλ_prior],
    repeat([3.0], length(naive_priors[:logλ_prior]))
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define ADVI function parameters
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
param = Dict(
    :data => data,
    :outputname => "./output/advi_meanfield_hierarchicalgenotypes_" *
                   "$(lpad(n_samples, 2, "0"))samples_$(n_steps)steps",
    :model => BarBay.model.genotype_fitness_normal,
    :model_kwargs => Dict(
        :s_pop_prior => s_pop_prior,
        :logσ_pop_prior => logσ_pop_prior,
        :logσ_bc_prior => logσ_bc_prior,
        :s_bc_prior => [0.0, 1.0],
        :logλ_prior => logλ_prior,
    ),
    :genotype_col => :genotype,
    :advi => Turing.ADVI(n_samples, n_steps),
    :opt => Turing.TruncatedADAGrad()
)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Perform optimization
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Run inference
println("Running Variational Inference...")
@time dist = BarBay.vi.advi(; param...)