# Import project package
import Antibiotic
# Import package to handle DataFrames
import DataFrames as DF
import CSV
# Import library for Bayesian inference
import Turing
# Import library to list files
import Glob
# Load CairoMakie for plotting
using CairoMakie
import ColorSchemes
# Import packages for posterior sampling
import PairPlots
# Import basic math libraries
import LsqFit
import StatsBase
import LinearAlgebra
import Random
# Activate backend
CairoMakie.activate!()
# Set custom plotting style
Antibiotic.viz.theme_makie!()
Bayesian Inference of \(IC_{50}\) Values
Bayesian inference, Antibiotic resistance, \(IC_{50}\)
(c) This work is licensed under a Creative Commons Attribution License CC-BY 4.0. All code contained herein is licensed under an MIT license.
Bayesian Inference of \(IC_{50}\) values
In this notebook, we will perform Bayesian inference on the \(IC_{50}\) values of the antibiotic resistance landscape. For this, we will use the raw OD620 measurements provided by Iwasawa et al. (2022).
Let’s begin by loading the data into a DataFrame.
# Load data into a DataFrame
df = CSV.read("./iwasawa_data/iwasawa_tidy.csv", DF.DataFrame)
first(df, 5)
| Row | antibiotic | col | OD | design | concentration_ugmL | day | strain_num | env | plate | blank |
|---|---|---|---|---|---|---|---|---|---|---|
|  | String7 | String7 | Float64 | Int64 | Float64 | Int64 | Int64 | String15 | Int64 | Bool |
| 1 | TET | col1 | 0.0464 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 2 | KM | col1 | 0.0432 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 3 | NFLX | col1 | 0.0414 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 4 | SS | col1 | 0.0415 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 5 | PLM | col1 | 0.0427 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
To double-check that the structure of the table makes sense, let’s plot the titration curves (OD₆₂₀ versus antibiotic concentration, one curve per day) for one example and see whether they agree with expectations.
# Define data to use (excluding blank measurements and zero concentrations)
data = df[
(df.antibiotic.=="KM").&(df.env.=="Parent_in_KM").&(df.strain_num.==13).&.!(df.blank).&(df.concentration_ugmL.>0),
:]
# Group data by day
df_group = DF.groupby(data, :day)
# Initialize figure
fig = Figure(size=(500, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="OD₆₂₀",
xscale=log2
)
# Define colors for plot
colors = get(ColorSchemes.Blues_9, LinRange(0.25, 1, length(df_group)))
# Loop through days
for (i, d) in enumerate(df_group)
# Sort data by concentration
DF.sort!(d, :concentration_ugmL)
# Plot scatter line
scatterlines!(
ax, d.concentration_ugmL, d.OD, color=colors[i], label="$(first(d.day))"
)
end # for
# Add legend to plot
fig[1, 2] = Legend(
fig, ax, "day", framevisible=false, nbanks=3, labelsize=10
)
fig
The functional form used by the authors to fit the data is \[ f(x) = \frac{a} {1+\exp \left[b\left(\log _2 x-\log _2 IC_{50}\right)\right]} + c \tag{1}\]
where \(a\), \(b\), and \(c\) are nuisance parameters of the model, \(IC_{50}\) is the parameter of interest, and \(x\) is the antibiotic concentration. We can define a function to compute this model.
@doc raw"""
logistic(x, a, b, c, ic50)
Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.
This function implements the following equation:
f(x) = a / (1 + exp(b * (log₂(x) - log₂(IC₅₀)))) + c
# Arguments
- `x`: Antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `ic50`: IC₅₀ parameter (concentration at which the effect is halfway between
the minimum and maximum)
# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.
Note: This function is vectorized and can handle array inputs for `x`.
"""
function logistic(x, a, b, c, ic50)
return @. a / (1.0 + exp(b * (log2(x) - log2(ic50)))) + c
end
To test the function, let’s plot the model for a set of parameters.
# Define parameters
a = 1.0
b = 1.0
c = 0.0
ic50 = 1.0
# Define concentration range
x = 10 .^ LinRange(-2.5, 2.5, 50)
# Compute model
y = logistic(x, a, b, c, ic50)
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration (a.u.)",
ylabel="optical density",
xscale=log10
)
# Plot model
lines!(ax, x, y)
fig
The function seems to work as expected. However, notice that in Eq. (1), the \(\mathrm{IC}_{50}\) goes into the logarithm. This is the parameter we will fit for. Thus, let’s define a function that takes \(\log_2(\mathrm{IC}_{50})\) and \(\log_2(x)\) as input instead.
@doc raw"""
logistic_log2(log2x, a, b, c, log2ic50)
Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth, using log2 inputs for concentration and
IC₅₀.
This function implements the following equation:
f(x) = a / (1 + exp(b * (log₂(x) - log₂(IC₅₀)))) + c
# Arguments
- `log2x`: log₂ of the antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `log2ic50`: log₂ of the IC₅₀ parameter
# Returns
The computed effect (e.g., optical density) for the given log₂ antibiotic
concentration and parameters.
Note: This function is vectorized and can handle array inputs for `log2x`.
"""
function logistic_log2(log2x, a, b, c, log2ic50)
return @. a / (1.0 + exp(b * (log2x - log2ic50))) + c
end
Bayesian model
Given the model presented in Eq. (1), and the data, our objective is to infer the value of all parameters. By Bayes theorem, we write
\[ \pi(IC_{50}, a, b, c \mid \text{data}) = \frac{\pi(\text{data} \mid IC_{50}, a, b, c) \pi(IC_{50}, a, b, c)} {\pi(\text{data})}, \tag{2}\]
where \(\text{data}\) consists of the pairs of antibiotic concentration and optical density.
Likelihood \(\pi(\text{data} \mid IC_{50}, a, b, c)\)
Let’s begin by defining the likelihood function. For simplicity, we assume each datum is independent and identically distributed (i.i.d.) and write
\[ \pi(\text{data} \mid IC_{50}, a, b, c) = \prod_{i=1}^n \pi(d_i \mid IC_{50}, a, b, c), \tag{3}\]
where \(d_i = (x_i, y_i)\) is the \(i\)-th pair of antibiotic concentration and optical density, respectively, and \(n\) is the total number of data points. As a first pass, we assume that our experimental measurements can be expressed as
\[ y_i = f(x_i, IC_{50}, a, b, c) + \epsilon_i, \tag{4}\]
where \(\epsilon_i\) is the experimental error. Furthermore, we assume that the experimental error is normally distributed, i.e.,
\[ \epsilon_i \sim \mathcal{N}(0, \sigma^2), \tag{5}\]
where \(\sigma^2\) is an unknown variance parameter that must be included in our inference. Notice that we assume the same variance parameter for all data points since \(\sigma^2\) is not indexed by \(i\).
Given this likelihood function, we must update our inference on the parameters as
\[ \pi(IC_{50}, a, b, c, \sigma^2 \mid \text{data}) = \frac{\pi(\text{data} \mid IC_{50}, a, b, c, \sigma^2) \pi(IC_{50}, a, b, c, \sigma^2)} {\pi(\text{data})}, \tag{6}\]
to include the new parameter \(\sigma^2\).
Our likelihood function is then of the form \[ y_i \mid IC_{50}, a, b, c, \sigma^2 \sim \mathcal{N}(f(x_i, IC_{50}, a, b, c), \sigma^2). \tag{7}\]
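To make Eq. (7) concrete, here is a minimal sketch of the log-likelihood it implies: a sum of independent Gaussian log-densities centered on the model prediction. This sketch is an addition (not part of the original analysis) and assumes the logistic_log2 function defined above together with the Distributions.jl package.
# Import Distributions for the Normal log-density (assumed available)
import Distributions
function loglikelihood_sketch(log2x, y, log2ic50, a, b, c, σ²)
    # Model prediction at each log₂ concentration
    ŷ = logistic_log2(log2x, a, b, c, log2ic50)
    # Sum of Normal(ŷᵢ, σ) log-densities evaluated at the observations
    return sum(Distributions.logpdf.(Distributions.Normal.(ŷ, √σ²), y))
end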
Prior \(\pi(IC_{50}, a, b, c, \sigma^2)\)
For the prior, we assume that all parameters are independent and write \[ \pi(IC_{50}, a, b, c, \sigma^2) = \pi(IC_{50}) \pi(a) \pi(b) \pi(c) \pi(\sigma^2). \tag{8}\]
Let’s detail each prior; a quick prior-predictive sketch of these choices follows the list.
- \(IC_{50}\): The \(IC_{50}\) is a strictly positive parameter. However, we will fit for \(\log_2(IC_{50})\). Thus, we will use a normal prior for \(\log_2(IC_{50})\). This means we have
\[ \log_2(IC_{50}) \sim \mathcal{N}( \mu_{\log_2(IC_{50})}, \sigma_{\log_2(IC_{50})}^2 ). \tag{9}\]
- \(a\): This nuisance parameter scales the logistic function. Again, the natural scale for this parameter is a strictly positive real number. Thus, we will use a lognormal prior for \(a\). This means we have
\[ a \sim \text{LogNormal}(\mu_a, \sigma_a^2). \tag{10}\]
- \(b\): This parameter controls the steepness of the logistic function. Again, the natural scale for this parameter is a strictly positive real number. Thus, we will use a lognormal prior for \(b\). This means we have
\[ b \sim \text{LogNormal}(\mu_b, \sigma_b^2). \tag{11}\]
- \(c\): This parameter controls the minimum value of the logistic function. Since this is a strictly positive real number that does not necessarily scale with the data, we will use a half-normal prior for \(c\). This means we have
\[ c \sim \text{Half-}\mathcal{N}(0, \sigma_c^2). \tag{12}\]
- \(\sigma^2\): This parameter controls the variance of the experimental error. Since this is a strictly positive real number that does not necessarily scale with the data, we will use a half-normal prior for \(\sigma^2\). This means we have
\[ \sigma^2 \sim \text{Half-}\mathcal{N}(0, \sigma_{\sigma^2}^2). \tag{13}\]
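Before encoding these priors in a model, a quick prior-predictive sketch can help confirm that they produce plausible titration curves. This is an addition (not part of the original analysis); it assumes the logistic_log2 function defined above, uses the distribution constructors re-exported by Turing, and picks unit hyperparameters (location 0, scale 1) purely as an illustrative assumption.
Random.seed!(42)
# Define log₂ concentration range for the prior-predictive curves
log2x_prior = LinRange(-2.5, 2.5, 50)
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(fig[1, 1], xlabel="log₂ antibiotic concentration", ylabel="optical density")
# Plot curves implied by draws from the assumed priors
for _ in 1:50
    # Sample each parameter from its prior
    log2ic50 = rand(Turing.Normal(0, 1))
    a = rand(Turing.LogNormal(0, 1))
    b = rand(Turing.LogNormal(0, 1))
    c = rand(Turing.truncated(Turing.Normal(0, 1), 0, Inf))
    # Plot the implied titration curve
    lines!(ax, log2x_prior, logistic_log2(log2x_prior, a, b, c, log2ic50), color=(:gray, 0.3))
end # for
fig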
With all of this in place, we are ready to define a Turing model to perform Bayesian inference on the parameters of the model.
Turing.@model function logistic_model(
log2x, y, prior_params::NamedTuple=NamedTuple()
)
# Define default prior parameters
default_params = (
log2ic50=(0, 1),
a=(0, 1),
b=(0, 1),
c=(0, 1),
σ²=(0, 1)
)
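# Note: each tuple above is splatted into the corresponding distribution
# below as (location, scale); the scale entries are standard deviations,
# not variances.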
# Merge default parameters with provided parameters
params = merge(default_params, prior_params)
# Define priors
log2ic50 ~ Turing.Normal(params.log2ic50...)
a ~ Turing.LogNormal(params.a...)
b ~ Turing.LogNormal(params.b...)
c ~ Turing.truncated(Turing.Normal(params.c...), 0, Inf)
σ² ~ Turing.truncated(Turing.Normal(params.σ²...), 0, Inf)
# Define likelihood
y ~ Turing.MvNormal(
logistic_log2(log2x, a, b, c, log2ic50),
LinearAlgebra.Diagonal(fill(σ², length(y)))
)
end
Having defined the model, we can now perform inference. First, let’s perform inference on simulated data. For this, we first simulate a single titration curve with known parameters.
Random.seed!(42)
# Define ground truth parameters
log2ic50_true = 0.5
a_true = 1.0
b_true = 10.0
σ²_true = 0.01
c_true = 0.1
# Simulate data
log2x = LinRange(-2.5, 2.5, 15)
# Define mean of data
ŷ = logistic_log2(log2x, a_true, b_true, c_true, log2ic50_true)
# Add noise
y = ŷ .+ randn(length(log2x)) * √(σ²_true)
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
xscale=log2
)
# Plot data
scatterlines!(ax, 2.0 .^ log2x, y)
fig
With the data simulated, let’s perform inference using the NUTS sampler. To run multiple chains in parallel, we will use the MCMCThreads option.
Random.seed!(42)
# Perform inference
model = logistic_model(log2x, y)
# Define number of steps
n_burnin = 10_000
n_samples = 1_000
# Run sampler using
chain = Turing.sample(
model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.8
┌ Info: Found initial step size
└ ϵ = 0.8
┌ Info: Found initial step size
└ ϵ = 0.4
┌ Info: Found initial step size
└ ϵ = 0.8
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:39
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:13
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:04
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:13
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 9.39 seconds
Compute duration = 36.29 seconds
parameters = log2ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
log2ic50 0.4723 0.0718 0.0005 23407.8672 14077.0993 1.0003 ⋯
a 0.9668 0.0621 0.0005 16034.1809 16294.6985 1.0006 ⋯
b 8.7902 4.6027 0.0449 13409.6206 13035.6999 1.0003 ⋯
c 0.1192 0.0449 0.0004 15396.5130 11022.7335 1.0004 ⋯
σ² 0.0111 0.0076 0.0001 12720.8496 16686.4143 1.0005 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
log2ic50 0.3173 0.4352 0.4745 0.5141 0.6048
a 0.8482 0.9279 0.9649 1.0042 1.0961
b 3.0834 5.7843 7.9162 10.6434 20.2429
c 0.0294 0.0902 0.1194 0.1475 0.2093
σ² 0.0038 0.0065 0.0091 0.0132 0.0302
Let’s look at the posterior distribution of the parameters. For this, we will use the PairPlots.jl package.
# Plot corner plot for chains
PairPlots.pairplot(
chain[n_burnin+1:end, :, :],
PairPlots.Truth(
(;
log2ic50=log2ic50_true,
a=a_true,
b=b_true,
c=c_true,
σ²=σ²_true
)
)
)
All parameters are well-constrained by the data, and we are able to recover the ground truth values.
Let’s plot the posterior predictive checks with the data to see how the model performs.
Random.seed!(42)
# Initialize matrix to store samples
y_samples = Array{Float64}(undef, length(log2x), n_samples)
# Loop through samples
for i in 1:n_samples
# Generate mean for sample
y_samples[:, i] = logistic_log2(
log2x,
chain[:a][i],
chain[:b][i],
chain[:c][i],
chain[:log2ic50][i]
)
# Add noise
y_samples[:, i] .+= randn(length(log2x)) * sqrt(chain[:σ²][i])
end # for
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
title="Posterior Predictive Check",
xscale=log2
)
# Plot samples
for i in 1:n_samples
lines!(ax, 2.0 .^ log2x, y_samples[:, i], color=(:gray, 0.05))
end # for
# Plot data
scatterlines!(ax, 2.0 .^ log2x, y)
fig
The fit looks excellent. Let’s now perform inference on the Iwasawa data, using a single example dataset and informative priors for the parameters.
Random.seed!(42)
# Group data by antibiotic, environment, and day
df_group = DF.groupby(
df[(.!df.blank).&(df.concentration_ugmL.>0), :],
[:antibiotic, :env, :day]
)
# Extract data
data = df_group[2]
# Define prior parameters
prior_params = (
a=(log(0.1), 0.1),
b=(0, 1),
c=(0, 1),
σ²=(0, 0.1)
)
# Perform inference
model = logistic_model(
log2.(data.concentration_ugmL),
data.OD,
prior_params
)
# Define number of steps
n_burnin = 10_000
n_samples = 1_000
chain = Turing.sample(
model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:16
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:05
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:05
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 4.78 seconds
Compute duration = 17.88 seconds
parameters = log2ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
log2ic50 3.1585 0.0611 0.0004 32325.5670 25247.0891 1.0001 ⋯
a 0.1013 0.0074 0.0000 28541.0328 27631.5354 1.0001 ⋯
b 15.8931 10.4326 0.0655 31581.5518 26387.0900 1.0000 ⋯
c 0.0873 0.0073 0.0000 27591.1350 24819.1854 1.0002 ⋯
σ² 0.0015 0.0002 0.0000 37231.0264 30213.0584 1.0000 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
log2ic50 3.0427 3.1192 3.1589 3.1963 3.2817
a 0.0872 0.0962 0.1012 0.1063 0.1162
b 4.6324 9.2475 13.4282 19.5750 41.4435
c 0.0726 0.0824 0.0874 0.0923 0.1013
σ² 0.0011 0.0013 0.0014 0.0016 0.0020
Let’s look at the corner plot for the chains.
# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])
This looks good. Let’s now plot the posterior predictive check.
Random.seed!(42)
# Define unique concentrations
unique_concentrations = sort(unique(data.concentration_ugmL))
# Initialize matrix to store samples
y_samples = Array{Float64}(
undef, length(unique_concentrations), n_samples
)
# Loop through samples
for i in 1:n_samples
# Generate mean for sample
y_samples[:, i] = logistic_log2(
log2.(unique_concentrations),
chain[:a][i],
chain[:b][i],
chain[:c][i],
chain[:log2ic50][i]
)
# Add noise
y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
title="Posterior Predictive Check",
xscale=log2
)
# Plot samples
for i in 1:n_samples
lines!(
ax,
unique_concentrations,
y_samples[:, i],
color=(ColorSchemes.Paired_12[1], 0.5)
)
end # for
# Plot data
scatter!(ax, data.concentration_ugmL, data.OD)
fig
The presence of outlier measurements inflates the uncertainty of the posterior predictive checks.
Let’s look into how to detect outliers in the data.
Residual-based outlier detection
A simple approach to detect outliers is to fit the logistic model deterministically (via least squares) and then flag points whose residuals from the fitted curve exceed a chosen number of standard deviations of the residuals (3 by default). This is known as “residual-based outlier detection”.
function logistic_log2(log2x, params)
return logistic_log2(log2x, params...)
end
"""
fit_logistic_and_detect_outliers(log2x, y; threshold=3)
Fit a logistic model to the given data and detect outliers based on residuals.
This function performs the following steps:
1. Fits a logistic model to the log2-transformed x-values and y-values.
2. Calculates residuals between the fitted model and actual y-values.
3. Identifies outliers as points with residuals exceeding a specified threshold.
# Arguments
- `log2x`: Array of log2-transformed x-values (typically concentrations).
- `y`: Array of y-values (typically optical density measurements).
- `threshold`: Number of standard deviations beyond which a point is considered an outlier. Default is 3.
# Returns
- A boolean array indicating which points are outliers (true for outliers).
# Notes
- The function uses a logistic model of the form:
y = a / (1 + exp(b * (log2x - log2ic50))) + c
- Initial parameter guesses are made based on the input data.
- The LsqFit package is used for curve fitting.
- Outliers are determined by comparing the absolute residuals to the threshold * standard deviation of residuals.
"""
function fit_logistic_and_detect_outliers(log2x, y; threshold=3)
# Initial parameter guess
p0 = [0.1, 1.0, maximum(y) - minimum(y), StatsBase.median(log2x)]
# Fit the logistic model
fit = LsqFit.curve_fit(logistic_log2, log2x, y, p0)
# Calculate residuals
residuals = y - logistic_log2(log2x, fit.param)
# Calculate standard deviation of residuals
σ = StatsBase.std(residuals)
# Identify outliers
outliers_idx = abs.(residuals) .> threshold * σ
# Return outlier indices
return outliers_idx
end
Let’s test this function by trying to remove the outliers from the data we used in the previous section.
# Locate outliers
outliers_idx = fit_logistic_and_detect_outliers(
log2.(data.concentration_ugmL), data.OD, threshold=2
)
# Plot the results
fig = Figure(size=(450, 300))
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
# xscale=log2
)
# Plot the original data
scatter!(
ax,
data.concentration_ugmL,
data.OD,
color=ColorSchemes.Paired_12[1],
label="data"
)
# Plot the cleaned data
scatter!(
ax,
data.concentration_ugmL[.!outliers_idx] .+ 0.5,
data.OD[.!outliers_idx], color=ColorSchemes.Paired_12[2],
label="cleaned data"
)
# Add legend
Legend(fig[1, 2], ax)
fig
The detection of outliers worked really well. Let’s now perform inference on the cleaned data.
Random.seed!(42)
# Group data by antibiotic, environment, and day
df_group = DF.groupby(
df[(.!df.blank).&(df.concentration_ugmL.>0), :],
[:antibiotic, :env, :day]
)
# Extract data
data = df_group[2]
# Find outliers
outliers_idx = fit_logistic_and_detect_outliers(
log2.(data.concentration_ugmL), data.OD, threshold=2
)
# Remove outliers
data_clean = data[.!outliers_idx, :]
# Define prior parameters
prior_params = (
a=(log(0.1), 0.1),
b=(0, 1),
c=(0, 1),
σ²=(0, 0.1)
)
# Perform inference
model = logistic_model(
log2.(data_clean.concentration_ugmL),
data_clean.OD,
prior_params
)
# Define number of steps
n_burnin = 10_000
n_samples = 1_000
chain = Turing.sample(
model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:17
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:06
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:05
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 5.01 seconds
Compute duration = 19.13 seconds
parameters = log2ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
log2ic50 3.1001 0.0297 0.0002 25150.3685 27935.5938 1.0002 ⋯
a 0.1163 0.0038 0.0000 16860.1510 22570.9089 1.0001 ⋯
b 6.7350 1.0786 0.0069 26759.0743 25218.0167 1.0001 ⋯
c 0.0712 0.0034 0.0000 16323.0552 21352.9296 1.0001 ⋯
σ² 0.0001 0.0000 0.0000 32838.7703 28866.3125 1.0001 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
log2ic50 3.0410 3.0804 3.1004 3.1202 3.1580
a 0.1089 0.1138 0.1163 0.1188 0.1238
b 5.0241 5.9818 6.5979 7.3261 9.2492
c 0.0643 0.0689 0.0712 0.0735 0.0778
σ² 0.0001 0.0001 0.0001 0.0001 0.0002
Let’s look again at the corner plot for the chains.
# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])
Compared to the previous example where we didn’t remove the outliers, the posterior distributions are more concentrated, especially for the \(b\) parameter.
Let’s now plot the posterior predictive check.
Random.seed!(42)
# Define unique concentrations
unique_concentrations = sort(unique(data_clean.concentration_ugmL))
# Initialize matrix to store samples
y_samples = Array{Float64}(
undef, length(unique_concentrations), n_samples
)
# Loop through samples
for i in 1:n_samples
# Generate mean for sample
y_samples[:, i] = logistic_log2(
log2.(unique_concentrations),
chain[:a][i],
chain[:b][i],
chain[:c][i],
chain[:log2ic50][i]
)
# Add noise
y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
title="Posterior Predictive Check",
xscale=log2
)
# Plot samples
for i in 1:n_samples
lines!(
ax,
unique_concentrations,
y_samples[:, i],
color=(ColorSchemes.Paired_12[1], 0.5)
)
end # for
# Plot data
scatter!(ax, data_clean.concentration_ugmL, data_clean.OD)
fig
This looks much better. Removing the outliers improved the fit significantly.
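As a short addition (not part of the original analysis), here is a sketch of how one might summarize the quantity of interest from this fit: convert the posterior \(\log_2(IC_{50})\) samples back to the concentration scale and report the posterior median and a 95% credible interval. The indexing below assumes the chain layout (iterations × chains) used elsewhere in this notebook.
# Convert post-burn-in posterior log₂(IC₅₀) samples (all chains) to µg/mL
ic50_samples = 2 .^ vec(chain[:log2ic50][n_burnin+1:end, :])
# Compute posterior median and 95% credible interval
ic50_median = StatsBase.median(ic50_samples)
ic50_ci = StatsBase.quantile(ic50_samples, [0.025, 0.975])
# Report the summary
println("IC₅₀ ≈ $(round(ic50_median, digits=2)) µg/mL (95% CI: $(round.(ic50_ci, digits=2)))")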
Alternative parameterization
The data contain several measurements at zero concentration, which are incompatible with the log-scale parameterization since \(\log_2(0)\) is undefined. However, we can rewrite Eq. (1) in terms of the concentration itself, rather than its logarithm, as
\[ f(x) = \frac{a} {1+ \left(\frac{x}{\mathrm{IC}_{50}}\right)^b} + c \tag{14}\]
Let’s define a new model using this parameterization.
@doc raw"""
logistic_alt(x, a, b, c, ic50)
Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.
This function implements the following equation:
f(x) = a / (1 + (x / IC₅₀) ^ b) + c
# Arguments
- `x`: Antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `ic50`: IC₅₀ parameter (concentration at which the effect is halfway between
the minimum and maximum)
# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.
Note: This function is vectorized and can handle array inputs for `x`.
"""
function logistic_alt(x, a, b, c, ic50)
return @. a / (1.0 + (x / ic50)^b) + c
end
function logistic_alt(x, params)
return logistic_alt(x, params...)
end
We will use the previous function to detect outliers in the data. Next, we re-define the Bayesian model using this parameterization, after a quick check of how its slope parameter relates to the one in Eq. (1).
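The following sketch is an addition (not part of the original analysis) that makes the relationship between the two forms explicit: since \((x / IC_{50})^{b'} = \exp\left[b' \ln 2 \left(\log_2 x - \log_2 IC_{50}\right)\right]\), Eq. (14) matches Eq. (1) when the slopes satisfy \(b' = b / \ln 2\), so slope estimates from the two fits are not directly comparable without rescaling.
# Compare the two parameterizations on a grid of concentrations
x_check = 2 .^ LinRange(-3, 3, 20)
# Eq. (1) with a = 1, b = 10, c = 0.1, log₂(IC₅₀) = 0.5
y_log2 = logistic_log2(log2.(x_check), 1.0, 10.0, 0.1, 0.5)
# Eq. (14) with the rescaled slope b′ = 10 / ln(2) and IC₅₀ = 2^0.5
y_alt = logistic_alt(x_check, 1.0, 10.0 / log(2), 0.1, 2^0.5)
# Maximum absolute difference (should be at numerical precision)
maximum(abs.(y_log2 .- y_alt))
With this in mind, we define the model below.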
Turing.@model function logistic_alt_model(
x, y, prior_params::NamedTuple=NamedTuple()
)
# Define default prior parameters
default_params = (
ic50=(0, 1),
a=(0, 1),
b=(0, 1),
c=(0, 1),
σ²=(0, 1)
)
# Merge default parameters with provided parameters
params = merge(default_params, prior_params)
# Define priors
ic50 ~ Turing.LogNormal(params.ic50...)
a ~ Turing.LogNormal(params.a...)
b ~ Turing.LogNormal(params.b...)
c ~ Turing.truncated(Turing.Normal(params.c...), 0, Inf)
σ² ~ Turing.truncated(Turing.Normal(params.σ²...), 0, Inf)
# Define likelihood
y ~ Turing.MvNormal(
logistic_alt(x, a, b, c, ic50),
LinearAlgebra.Diagonal(fill(σ², length(y)))
)
end
Let’s now perform inference on synthetic data using this new parameterization.
Random.seed!(42)
# Define ground truth parameters
ic50_true = 2^0.5
a_true = 1.0
b_true = 10.0
σ²_true = 0.01
c_true = 0.1
# Simulate data
x = 2 .^ LinRange(-2.5, 2.5, 15)
# Define mean of data
ŷ = logistic_alt(x, a_true, b_true, c_true, ic50_true)
# Add noise
y = ŷ + randn(length(x)) * sqrt(σ²_true)
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
xscale=log2
)
# Plot data
scatterlines!(ax, x, y)
fig
With the data simulated, let’s perform inference on the data.
Random.seed!(42)
# Perform inference
model = logistic_alt_model(x, y)
# Define number of steps
n_burnin = 10_000
n_samples = 1_000
chain = Turing.sample(
model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 1.6
┌ Info: Found initial step size
└ ϵ = 0.8
┌ Info: Found initial step size
└ ϵ = 0.4
┌ Info: Found initial step size
└ ϵ = 0.8
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:19
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:06
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:06
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 5.49 seconds
Compute duration = 21.43 seconds
parameters = ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
ic50 1.3615 0.0783 0.0006 21996.9818 15834.2853 1.0001 ⋯
a 0.9776 0.0641 0.0005 14338.2651 16755.4871 1.0000 ⋯
b 7.5525 2.9801 0.0243 15179.8530 18757.0920 1.0000 ⋯
c 0.1183 0.0456 0.0004 14321.6452 10118.2654 1.0002 ⋯
σ² 0.0102 0.0066 0.0001 14265.0218 17759.7256 1.0002 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
ic50 1.1980 1.3177 1.3639 1.4079 1.5098
a 0.8545 0.9372 0.9761 1.0169 1.1094
b 3.6070 5.6275 6.9973 8.7969 14.6326
c 0.0281 0.0886 0.1180 0.1470 0.2112
σ² 0.0037 0.0062 0.0085 0.0121 0.0270
Let’s look at the posterior distribution of the parameters.
# Plot corner plot for chains
PairPlots.pairplot(
chain[n_burnin+1:end, :, :],
PairPlots.Truth(
(;
ic50=ic50_true,
a=a_true,
b=b_true,
c=c_true,
σ²=σ²_true
)
)
)
The model is still able to recover the ground truth parameters. Let’s now plot the posterior predictive check.
Random.seed!(42)
# Initialize matrix to store samples
y_samples = Array{Float64}(undef, length(x), n_samples)
# Loop through samples
for i in 1:n_samples
# Generate mean for sample
y_samples[:, i] = logistic_alt(
x,
chain[:a][i],
chain[:b][i],
chain[:c][i],
chain[:ic50][i]
)
# Add noise
y_samples[:, i] .+= randn(length(x)) * √(chain[:σ²][i])
end # for
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
title="Posterior Predictive Check",
xscale=log2
)
# Plot samples
for i in 1:n_samples
lines!(ax, x, y_samples[:, i], color=(:gray, 0.05))
end # for
# Plot data
scatterlines!(ax, x, y)
fig
Everything looks good. Let’s again test it on the Iwasawa et al. (2022) data.
Random.seed!(42)
# Define prior parameters
prior_params = (
a=(log(0.1), 0.1),
b=(0, 1),
c=(0, 1),
σ²=(0, 0.01)
)
# Define data
data = df_group[2]
# Clean data
outliers_idx = fit_logistic_and_detect_outliers(
log2.(data.concentration_ugmL), data.OD, threshold=2
)
data_clean = data[.!outliers_idx, :]
# Perform inference
model = logistic_alt_model(
data_clean.concentration_ugmL,
data_clean.OD,
prior_params
)
# Define number of steps
n_burnin = 10_000
n_samples = 1_000
chain = Turing.sample(
model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.00625
┌ Info: Found initial step size
└ ϵ = 0.00625
┌ Info: Found initial step size
└ ϵ = 0.0125
┌ Info: Found initial step size
└ ϵ = 0.00625
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:18
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:06
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:06
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 5.55 seconds
Compute duration = 21.2 seconds
parameters = ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
ic50 8.5831 0.1769 0.0011 25568.1627 25963.0962 1.0003 ⋯
a 0.1165 0.0038 0.0000 18041.1443 20901.4168 1.0002 ⋯
b 9.6366 1.5316 0.0097 26259.3918 25025.7026 1.0000 ⋯
c 0.0710 0.0035 0.0000 17294.7757 21118.9348 1.0002 ⋯
σ² 0.0001 0.0000 0.0000 29855.3383 28397.5187 1.0001 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
ic50 8.2417 8.4631 8.5828 8.6994 8.9339
a 0.1091 0.1139 0.1164 0.1190 0.1241
b 7.2053 8.5622 9.4488 10.4964 13.1592
c 0.0640 0.0687 0.0711 0.0734 0.0777
σ² 0.0001 0.0001 0.0001 0.0001 0.0002
Let’s look at the posterior distribution of the parameters.
# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])
Nothing looks obviously wrong. Let’s plot the posterior predictive check.
Random.seed!(42)
# Define unique concentrations
unique_concentrations = sort(unique(data_clean.concentration_ugmL))
# Initialize matrix to store samples
y_samples = Array{Float64}(
undef, length(unique_concentrations), n_samples
)
# Loop through samples
for i in 1:n_samples
# Generate mean for sample
y_samples[:, i] = logistic_alt(
unique_concentrations,
chain[:a][i],
chain[:b][i],
chain[:c][i],
chain[:ic50][i]
)
# Add noise
y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for
# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
fig[1, 1],
xlabel="antibiotic concentration",
ylabel="optical density",
title="Posterior Predictive Check",
xscale=log2
)
# Plot samples
for i in 1:n_samples
lines!(
ax,
unique_concentrations,
y_samples[:, i],
color=(ColorSchemes.Paired_12[1], 0.5)
)
end # for
# Plot data
scatter!(ax, data_clean.concentration_ugmL, data_clean.OD)
fig
With this parameterization, the model performs as well as the previous one; indeed, the posterior mean \(IC_{50} \approx 8.58\) µg/mL matches \(2^{3.10} \approx 8.6\) µg/mL implied by the log₂-parameterization fit to the same cleaned data.
Conclusion
In this notebook, we have performed Bayesian inference on the data from Iwasawa et al. (2022) using two different parameterizations of the logistic function. Both parameterizations perform equally well, and we are able to recover the ground truth parameters from simulated data. Furthermore, implementing an outlier detection scheme improved the fit significantly.