# Import project package
import Antibiotic
# Import package to handle DataFrames
import DataFrames as DF
import CSV
# Import library for Bayesian inference
import Turing
# Import library to list files
import Glob
# Load CairoMakie for plotting
using CairoMakie
import ColorSchemes
# Import packages for posterior sampling
import PairPlots
# Import basic math libraries
import LsqFit
import StatsBase
import LinearAlgebra
import Random
# Activate backend
CairoMakie.activate!()
# Set custom plotting style
Antibiotic.viz.theme_makie!()
Bayesian Inference of \(IC_{50}\) Values
Bayesian inference, Antibiotic resistance, \(IC_{50}\)
(c) This work is licensed under a Creative Commons Attribution License CC-BY 4.0. All code contained herein is licensed under an MIT license.
Bayesian Inference of \(IC_{50}\) values
In this notebook, we will perform Bayesian inference on the \(IC_{50}\) values of the antibiotic resistance landscape. For this, we will use the raw OD620
measurements provided by Iwasawa et al. (2022).
Let’s begin by loading the data into a DataFrame.
# Load data into a DataFrame
df = CSV.read("./iwasawa_data/iwasawa_tidy.csv", DF.DataFrame)

# Show the first few rows
first(df, 5)
| Row | antibiotic | col | OD | design | concentration_ugmL | day | strain_num | env | plate | blank |
|---|---|---|---|---|---|---|---|---|---|---|
| | String7 | String7 | Float64 | Int64 | Float64 | Int64 | Int64 | String15 | Int64 | Bool |
| 1 | TET | col1 | 0.0464 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 2 | KM | col1 | 0.0432 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 3 | NFLX | col1 | 0.0414 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 4 | SS | col1 | 0.0415 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
| 5 | PLM | col1 | 0.0427 | 3 | 0.0 | 1 | 19 | TETE4_in_KM | 10 | true |
To double-check that the structure of the table makes sense, let’s plot the time series for one example to see if the sequence agrees with the expectation.
# Define data to use
# Remove blank measurements
data = df[
    (df.antibiotic.=="KM").&(df.env.=="Parent_in_KM").&(df.strain_num.==13).&.!(df.blank).&(df.concentration_ugmL.>0),
    :]

# Group data by day
df_group = DF.groupby(data, :day)
# Initialize figure
fig = Figure(size=(500, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="OD₆₂₀",
    xscale=log2
)

# Define colors for plot
colors = get(ColorSchemes.Blues_9, LinRange(0.25, 1, length(df_group)))

# Loop through days
for (i, d) in enumerate(df_group)
    # Sort data by concentration
    DF.sort!(d, :concentration_ugmL)
    # Plot scatter line
    scatterlines!(
        ax, d.concentration_ugmL, d.OD, color=colors[i], label="$(first(d.day))"
    )
end # for

# Add legend to plot
fig[1, 2] = Legend(
    fig, ax, "day", framevisible=false, nbanks=3, labelsize=10
)

fig
The functional form used by the authors to fit the data is \[ f(x) = \frac{a} {1+\exp \left[b\left(\log _2 x-\log _2 IC_{50}\right)\right]} + c \tag{1}\]
where \(a\), \(b\), and \(c\) are nuisance parameters of the model, \(IC_{50}\) is the parameter of interest, and \(x\) is the antibiotic concentration. We can define a function to compute this model.
@doc raw"""
logistic(x, a, b, c, ic50)
Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.
This function implements the following equation:
f(x) = a / (1 + exp(b * (log₂(x) - log₂(IC₅₀)))) + c
# Arguments
- `x`: Antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `ic50`: IC₅₀ parameter (concentration at which the effect is halfway between
the minimum and maximum)
# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.
Note: This function is vectorized and can handle array inputs for `x`.
"""
function logistic(x, a, b, c, ic50)
return @. a / (1.0 + exp(b * (log2(x) - log2(ic50)))) + c
end
To test the function, let’s plot the model for a set of parameters.
# Define parameters
a = 1.0
b = 1.0
c = 0.0
ic50 = 1.0

# Define concentration range
x = 10 .^ LinRange(-2.5, 2.5, 50)

# Compute model
y = logistic(x, a, b, c, ic50)

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration (a.u.)",
    ylabel="optical density",
    xscale=log10
)
# Plot model
lines!(ax, x, y)

fig
The function seems to work as expected. However, notice that in Eq. (1), the \(\mathrm{IC}_{50}\) goes into the logarithm. This is the parameter we will fit for. Thus, let’s define a function that takes \(\log_2(\mathrm{IC}_{50})\) and \(\log_2(x)\) as input instead.
@doc raw"""
logistic_log2(log2x, a, b, c, log2ic50)
Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth, using log2 inputs for concentration and
IC₅₀.
This function implements the following equation:
f(x) = a / (1 + exp(b * (log₂(x) - log₂(IC₅₀)))) + c
# Arguments
- `log2x`: log₂ of the antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `log2ic50`: log₂ of the IC₅₀ parameter
# Returns
The computed effect (e.g., optical density) for the given log₂ antibiotic
concentration and parameters.
Note: This function is vectorized and can handle array inputs for `log2x`.
"""
function logistic_log2(log2x, a, b, c, log2ic50)
return @. a / (1.0 + exp(b * (log2x - log2ic50))) + c
end
Bayesian model
Given the model presented in Eq. (1) and the data, our objective is to infer the values of all parameters. By Bayes' theorem, we write
\[ \pi(IC_{50}, a, b, c \mid \text{data}) = \frac{\pi(\text{data} \mid IC_{50}, a, b, c) \pi(IC_{50}, a, b, c)} {\pi(\text{data})}, \tag{2}\]
where \(\text{data}\) consists of the pairs of antibiotic concentration and optical density.
Likelihood \(\pi(\text{data} \mid IC_{50}, a, b, c)\)
Let’s begin by defining the likelihood function. For simplicity, we assume each datum is independent and identically distributed (i.i.d.) and write
\[ \pi(\text{data} \mid IC_{50}, a, b, c) = \prod_{i=1}^n \pi(d_i \mid IC_{50}, a, b, c), \tag{3}\]
where \(d_i = (x_i, y_i)\) is the \(i\)-th pair of antibiotic concentration and optical density, respectively, and \(n\) is the total number of data points. As a first pass, we assume that our experimental measurements can be expressed as
\[ y_i = f(x_i, IC_{50}, a, b, c) + \epsilon_i, \tag{4}\]
where \(\epsilon_i\) is the experimental error. Furthermore, we assume that the experimental error is normally distributed, i.e.,
\[ \epsilon_i \sim \mathcal{N}(0, \sigma^2), \tag{5}\]
where \(\sigma^2\) is an unknown variance parameter that must be included in our inference. Notice that we assume the same variance parameter for all data points since \(\sigma^2\) is not indexed by \(i\).
Given this likelihood function, we must update our inference on the parameters as
\[ \pi(IC_{50}, a, b, c, \sigma^2 \mid \text{data}) = \frac{\pi(\text{data} \mid IC_{50}, a, b, c, \sigma^2) \pi(IC_{50}, a, b, c, \sigma^2)} {\pi(\text{data})}, \tag{6}\]
to include the new parameter \(\sigma^2\).
Our likelihood function is then of the form \[ y_i \mid IC_{50}, a, b, c, \sigma^2 \sim \mathcal{N}(f(x_i, IC_{50}, a, b, c), \sigma^2). \tag{7}\]
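For concreteness, combining Eq. (3) with Eq. (7), the log-likelihood takes the familiar least-squares form
\[
\ln \pi(\text{data} \mid IC_{50}, a, b, c, \sigma^2) =
- \frac{n}{2} \ln\left(2 \pi \sigma^2\right)
- \frac{1}{2 \sigma^2} \sum_{i=1}^n
\left[ y_i - f(x_i, IC_{50}, a, b, c) \right]^2.
\]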
Prior \(\pi(IC_{50}, a, b, c, \sigma^2)\)
For the prior, we assume that all parameters are independent and write \[ \pi(IC_{50}, a, b, c, \sigma^2) = \pi(IC_{50}) \pi(a) \pi(b) \pi(c) \pi(\sigma^2). \tag{8}\]
Let’s detail each prior.
- \(IC_{50}\): The \(IC_{50}\) is a strictly positive parameter. However, we will fit for \(\log_2(IC_{50})\). Thus, we will use a normal prior for \(\log_2(IC_{50})\). This means we have
\[ \log_2(IC_{50}) \sim \mathcal{N}( \mu_{\log_2(IC_{50})}, \sigma_{\log_2(IC_{50})}^2 ). \tag{9}\]
- \(a\): This nuisance parameter scales the logistic function. Again, the natural scale for this parameter is a strictly positive real number. Thus, we will use a lognormal prior for \(a\). This means we have
\[ a \sim \text{LogNormal}(\mu_a, \sigma_a^2). \tag{10}\]
- \(b\): This parameter controls the steepness of the logistic function. Again, the natural scale for this parameter is a strictly positive real number. Thus, we will use a lognormal prior for \(b\). This means we have
\[ b \sim \text{LogNormal}(\mu_b, \sigma_b^2). \tag{11}\]
- \(c\): This parameter controls the minimum value of the logistic function. Since this is a strictly positive real number that does not necessarily scale with the data, we will use a half-normal prior for \(c\). This means we have
\[ c \sim \text{Half-}\mathcal{N}(0, \sigma_c^2). \tag{12}\]
- \(\sigma^2\): This parameter controls the variance of the experimental error. Since this is a strictly positive real number that does not necessarily scale with the data, we will use a half-normal prior for \(\sigma^2\). This means we have
\[ \sigma^2 \sim \text{Half-}\mathcal{N}(0, \sigma_{\sigma^2}^2). \tag{13}\]
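Before translating these priors into code, a quick prior-predictive check can help confirm they produce sensible titration curves. The sketch below is not part of the original analysis: it draws a handful of parameter sets from the default standard normal/log-normal priors used later in the model and plots only the implied mean curves (the noise term \(\sigma^2\) is omitted).
# Prior-predictive sketch: sample parameters from the priors above and plot the
# titration curves they imply (mean curves only, no measurement noise)
Random.seed!(42)
# Define concentration grid on the log₂ scale
log2x_prior = LinRange(-5, 5, 100)
# Initialize figure and axis
fig = Figure(size=(350, 300))
ax = Axis(fig[1, 1], xlabel="log₂ antibiotic concentration", ylabel="optical density")
# Draw parameter sets from the priors and plot the corresponding curves
for _ in 1:50
    log2ic50_prior = rand(Turing.Normal(0, 1))
    a_prior = rand(Turing.LogNormal(0, 1))
    b_prior = rand(Turing.LogNormal(0, 1))
    c_prior = rand(Turing.truncated(Turing.Normal(0, 1), 0, Inf))
    lines!(
        ax,
        log2x_prior,
        logistic_log2(log2x_prior, a_prior, b_prior, c_prior, log2ic50_prior),
        color=(:gray, 0.3)
    )
end # for
fig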
With all of this in place, we are ready to define a Turing
model to perform Bayesian inference on the parameters of the model.
Turing.@model function logistic_model(
    log2x, y, prior_params::NamedTuple=NamedTuple()
)
    # Define default prior parameters
    default_params = (
        log2ic50=(0, 1),
        a=(0, 1),
        b=(0, 1),
        c=(0, 1),
        σ²=(0, 1)
    )

    # Merge default parameters with provided parameters
    params = merge(default_params, prior_params)

    # Define priors
    log2ic50 ~ Turing.Normal(params.log2ic50...)
    a ~ Turing.LogNormal(params.a...)
    b ~ Turing.LogNormal(params.b...)
    c ~ Turing.truncated(Turing.Normal(params.c...), 0, Inf)
    σ² ~ Turing.truncated(Turing.Normal(params.σ²...), 0, Inf)

    # Define likelihood
    y ~ Turing.MvNormal(
        logistic_log2(log2x, a, b, c, log2ic50),
        LinearAlgebra.Diagonal(fill(σ², length(y)))
    )
end
Having defined the model, we can now perform inference. First, let’s perform inference on simulated data. For this, we first simulate a single titration curve with known parameters.
Random.seed!(42)
# Define ground truth parameters
log2ic50_true = 0.5
a_true = 1.0
b_true = 10.0
σ²_true = 0.01
c_true = 0.1

# Simulate data
log2x = LinRange(-2.5, 2.5, 15)
# Define mean of data
ŷ = logistic_log2(log2x, a_true, b_true, c_true, log2ic50_true)
# Add noise
y = ŷ .+ randn(length(log2x)) * √(σ²_true)

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    xscale=log2
)
# Plot data
scatterlines!(ax, 2.0 .^ log2x, y)

fig
With the data simulated, let's perform inference. For this, we will use the NUTS sampler. To run multiple chains in parallel, we will use the MCMCThreads option (this only parallelizes across chains if Julia was started with more than one thread).
Random.seed!(42)
# Perform inference
model = logistic_model(log2x, y)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

# Run sampler
chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.8
┌ Info: Found initial step size
└ ϵ = 0.8
┌ Info: Found initial step size
└ ϵ = 0.4
┌ Info: Found initial step size
└ ϵ = 0.8
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:39
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:13
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:04
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:13
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 9.39 seconds
Compute duration = 36.29 seconds
parameters = log2ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
log2ic50 0.4723 0.0718 0.0005 23407.8672 14077.0993 1.0003 ⋯
a 0.9668 0.0621 0.0005 16034.1809 16294.6985 1.0006 ⋯
b 8.7902 4.6027 0.0449 13409.6206 13035.6999 1.0003 ⋯
c 0.1192 0.0449 0.0004 15396.5130 11022.7335 1.0004 ⋯
σ² 0.0111 0.0076 0.0001 12720.8496 16686.4143 1.0005 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
log2ic50 0.3173 0.4352 0.4745 0.5141 0.6048
a 0.8482 0.9279 0.9649 1.0042 1.0961
b 3.0834 5.7843 7.9162 10.6434 20.2429
c 0.0294 0.0902 0.1194 0.1475 0.2093
σ² 0.0038 0.0065 0.0091 0.0132 0.0302
Let’s look at the posterior distribution of the parameters. For this, we will use the PairPlots.jl
package.
# Plot corner plot for chains
PairPlots.pairplot(
    chain[n_burnin+1:end, :, :],
    PairPlots.Truth(
        (;
            log2ic50=log2ic50_true,
            a=a_true,
            b=b_true,
            c=c_true,
            σ²=σ²_true
        )
    )
)
All parameters are well-constrained by the data, and we are able to recover the ground truth values.
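As a quick numerical check of this recovery, we can compare the posterior median and 95% credible interval of \(\log_2(IC_{50})\) against the ground truth. This is a minimal sketch, assuming the chain can be indexed by parameter symbol as done in the posterior predictive code below.
# Minimal recovery check: posterior median and 95% credible interval vs. ground truth
log2ic50_draws = vec(chain[:log2ic50])
log2ic50_quantiles = StatsBase.quantile(log2ic50_draws, [0.025, 0.5, 0.975])
println(
    "log₂(IC₅₀): median = $(round(log2ic50_quantiles[2], digits=3)), " *
    "95% CI = [$(round(log2ic50_quantiles[1], digits=3)), " *
    "$(round(log2ic50_quantiles[3], digits=3))], truth = $(log2ic50_true)"
)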
Let’s plot the posterior predictive checks with the data to see how the model performs.
Random.seed!(42)
# Initialize matrix to store samples
y_samples = Array{Float64}(undef, length(log2x), n_samples)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_log2(
        log2x,
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:log2ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(log2x)) * sqrt(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(ax, 2.0 .^ log2x, y_samples[:, i], color=(:gray, 0.05))
end # for

# Plot data
scatterlines!(ax, 2.0 .^ log2x, y)

fig
The fit looks excellent. Let’s now perform inference on the Iwasawa data. We will use one example dataset. Moreover, we will set informative priors for the parameters.
Random.seed!(42)
# Group data by antibiotic, environment, and day
df_group = DF.groupby(
    df[(.!df.blank).&(df.concentration_ugmL.>0), :],
    [:antibiotic, :env, :day]
)

# Extract data
data = df_group[2]

# Define prior parameters
prior_params = (
    a=(log(0.1), 0.1),
    b=(0, 1),
    c=(0, 1),
    σ²=(0, 0.1)
)

# Perform inference
model = logistic_model(
    log2.(data.concentration_ugmL),
    data.OD,
    prior_params
)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:16
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:05
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:05
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 4.78 seconds
Compute duration = 17.88 seconds
parameters = log2ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
log2ic50 3.1585 0.0611 0.0004 32325.5670 25247.0891 1.0001 ⋯
a 0.1013 0.0074 0.0000 28541.0328 27631.5354 1.0001 ⋯
b 15.8931 10.4326 0.0655 31581.5518 26387.0900 1.0000 ⋯
c 0.0873 0.0073 0.0000 27591.1350 24819.1854 1.0002 ⋯
σ² 0.0015 0.0002 0.0000 37231.0264 30213.0584 1.0000 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
log2ic50 3.0427 3.1192 3.1589 3.1963 3.2817
a 0.0872 0.0962 0.1012 0.1063 0.1162
b 4.6324 9.2475 13.4282 19.5750 41.4435
c 0.0726 0.0824 0.0874 0.0923 0.1013
σ² 0.0011 0.0013 0.0014 0.0016 0.0020
Let’s look at the corner plot for the chains.
# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])
This looks good. Let’s now plot the posterior predictive check.
Random.seed!(42)
# Define unique concentrations
unique_concentrations = sort(unique(data.concentration_ugmL))

# Initialize matrix to store samples
y_samples = Array{Float64}(
    undef, length(unique_concentrations), n_samples
)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_log2(
        log2.(unique_concentrations),
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:log2ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(
        ax,
        unique_concentrations,
        y_samples[:, i],
        color=(ColorSchemes.Paired_12[1], 0.5)
    )
end # for

# Plot data
scatter!(ax, data.concentration_ugmL, data.OD)

fig
The presence of outlier measurements inflates the uncertainty of the posterior predictive checks.
Let's look into how to detect outliers in the data.
Residual-based outlier detection
Our naive approach to detecting outliers is to fit the logistic model deterministically (via least squares) and then flag points whose residuals lie more than 3 standard deviations away from the fitted curve. This is known as “residual-based outlier detection”.
# Method that accepts the fit parameters as a single vector (as returned by LsqFit)
function logistic_log2(log2x, params)
    return logistic_log2(log2x, params...)
end
"""
fit_logistic_and_detect_outliers(log2x, y; threshold=3)
Fit a logistic model to the given data and detect outliers based on residuals.
This function performs the following steps:
1. Fits a logistic model to the log2-transformed x-values and y-values.
2. Calculates residuals between the fitted model and actual y-values.
3. Identifies outliers as points with residuals exceeding a specified threshold.
# Arguments
- `log2x`: Array of log2-transformed x-values (typically concentrations).
- `y`: Array of y-values (typically optical density measurements).
- `threshold`: Number of standard deviations beyond which a point is considered an outlier. Default is 3.
# Returns
- A boolean array indicating which points are outliers (true for outliers).
# Notes
- The function uses a logistic model of the form:
y = a / (1 + exp(b * (log2x - log2ic50))) + c
- Initial parameter guesses are made based on the input data.
- The LsqFit package is used for curve fitting.
- Outliers are determined by comparing the absolute residuals to the threshold * standard deviation of residuals.
"""
function fit_logistic_and_detect_outliers(log2x, y; threshold=3)
    # Initial parameter guess
    p0 = [0.1, 1.0, maximum(y) - minimum(y), StatsBase.median(log2x)]

    # Fit the logistic model
    fit = LsqFit.curve_fit(logistic_log2, log2x, y, p0)

    # Calculate residuals
    residuals = y - logistic_log2(log2x, fit.param)

    # Calculate standard deviation of residuals
    σ = StatsBase.std(residuals)

    # Identify outliers
    outliers_idx = abs.(residuals) .> threshold * σ

    # Return outlier indices
    return outliers_idx
end
Let’s test this function by trying to remove the outliers from the data we used in the previous section.
# Locate outliers
outliers_idx = fit_logistic_and_detect_outliers(
    log2.(data.concentration_ugmL), data.OD, threshold=2
)

# Plot the results
fig = Figure(size=(450, 300))

ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    # xscale=log2
)

# Plot the original data
scatter!(
    ax,
    data.concentration_ugmL,
    data.OD,
    color=ColorSchemes.Paired_12[1],
    label="data"
)

# Plot the cleaned data (shifted by 0.5 on the x-axis so both sets are visible)
scatter!(
    ax,
    data.concentration_ugmL[.!outliers_idx] .+ 0.5,
    data.OD[.!outliers_idx],
    color=ColorSchemes.Paired_12[2],
    label="cleaned data"
)

# Add legend
Legend(fig[1, 2], ax)

fig
The detection of outliers worked really well. Let’s now perform inference on the cleaned data.
Random.seed!(42)
# Group data by antibiotic, environment, and day
df_group = DF.groupby(
    df[(.!df.blank).&(df.concentration_ugmL.>0), :],
    [:antibiotic, :env, :day]
)

# Extract data
data = df_group[2]

# Find outliers
outliers_idx = fit_logistic_and_detect_outliers(
    log2.(data.concentration_ugmL), data.OD, threshold=2
)
# Remove outliers
data_clean = data[.!outliers_idx, :]

# Define prior parameters
prior_params = (
    a=(log(0.1), 0.1),
    b=(0, 1),
    c=(0, 1),
    σ²=(0, 0.1)
)

# Perform inference
model = logistic_model(
    log2.(data_clean.concentration_ugmL),
    data_clean.OD,
    prior_params
)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:17
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:06
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:05
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 5.01 seconds
Compute duration = 19.13 seconds
parameters = log2ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
log2ic50 3.1001 0.0297 0.0002 25150.3685 27935.5938 1.0002 ⋯
a 0.1163 0.0038 0.0000 16860.1510 22570.9089 1.0001 ⋯
b 6.7350 1.0786 0.0069 26759.0743 25218.0167 1.0001 ⋯
c 0.0712 0.0034 0.0000 16323.0552 21352.9296 1.0001 ⋯
σ² 0.0001 0.0000 0.0000 32838.7703 28866.3125 1.0001 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
log2ic50 3.0410 3.0804 3.1004 3.1202 3.1580
a 0.1089 0.1138 0.1163 0.1188 0.1238
b 5.0241 5.9818 6.5979 7.3261 9.2492
c 0.0643 0.0689 0.0712 0.0735 0.0778
σ² 0.0001 0.0001 0.0001 0.0001 0.0002
Let’s look again at the corner plot for the chains.
# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])
Compared to the previous example where we didn’t remove the outliers, the posterior distributions are more concentrated, especially for the \(b\) parameter.
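To make this comparison quantitative, one could keep both chains around instead of overwriting `chain`. The sketch below assumes they were saved under the hypothetical names `chain_raw` and `chain_clean` and compares the posterior spread of \(b\).
# Hypothetical comparison: assumes the chain fit to the raw data was stored as
# `chain_raw` and the chain fit to the cleaned data as `chain_clean`
b_std_raw = StatsBase.std(vec(chain_raw[:b]))
b_std_clean = StatsBase.std(vec(chain_clean[:b]))
println("posterior std of b: raw = $(round(b_std_raw, digits=2)), " *
    "cleaned = $(round(b_std_clean, digits=2))")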
Let’s now plot the posterior predictive check.
Random.seed!(42)
# Define unique concentrations
unique_concentrations = sort(unique(data_clean.concentration_ugmL))

# Initialize matrix to store samples
y_samples = Array{Float64}(
    undef, length(unique_concentrations), n_samples
)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_log2(
        log2.(unique_concentrations),
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:log2ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(
        ax,
        unique_concentrations,
        y_samples[:, i],
        color=(ColorSchemes.Paired_12[1], 0.5)
    )
end # for

# Plot data
scatter!(ax, data_clean.concentration_ugmL, data_clean.OD)

fig
This looks much better. Removing the outliers improved the fit significantly.
Alternative parameterization
In the data, there are several measurements at zero antibiotic concentration, which are incompatible with the log-scale model since \(\log_2 0\) diverges. However, we can rewrite Equation 1 without the explicit logarithms as
\[ f(x) = \frac{a} {1+ \left(\frac{x}{\mathrm{IC}_{50}}\right)^b} + c \tag{14}\]
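The two forms are equivalent up to a rescaling of the slope parameter, since
\[
\exp\left[b\left(\log_2 x - \log_2 IC_{50}\right)\right]
= \exp\left[\frac{b}{\ln 2} \ln\left(\frac{x}{IC_{50}}\right)\right]
= \left(\frac{x}{IC_{50}}\right)^{b / \ln 2},
\]
so the exponent \(b\) in Equation 14 plays the role of \(b / \ln 2 \approx 1.44\, b\) from Equation 1. This is roughly consistent with the fits on the cleaned Iwasawa data, which yield a slope of about 6.7 under the log₂ parameterization (above) and about 9.6 under this one (below).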
Let’s define a new model using this parameterization.
@doc raw"""
logistic_alt(x, a, b, c, ic50)
Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.
This function implements the following equation:
f(x) = a / (1 + (x / IC₅₀) ^ b) + c
# Arguments
- `x`: Antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `ic50`: IC₅₀ parameter (concentration at which the effect is halfway between
the minimum and maximum)
# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.
Note: This function is vectorized and can handle array inputs for `x`.
"""
function logistic_alt(x, a, b, c, ic50)
return @. a / (1.0 + (x / ic50)^b) + c
end
function logistic_alt(x, params)
return logistic_alt(x, params...)
end
We will use the previous function to detect outliers in the data. Now, we re-define the Bayesian model using this parameterization.
Turing.@model function logistic_alt_model(
    x, y, prior_params::NamedTuple=NamedTuple()
)
    # Define default prior parameters
    default_params = (
        ic50=(0, 1),
        a=(0, 1),
        b=(0, 1),
        c=(0, 1),
        σ²=(0, 1)
    )

    # Merge default parameters with provided parameters
    params = merge(default_params, prior_params)

    # Define priors
    ic50 ~ Turing.LogNormal(params.ic50...)
    a ~ Turing.LogNormal(params.a...)
    b ~ Turing.LogNormal(params.b...)
    c ~ Turing.truncated(Turing.Normal(params.c...), 0, Inf)
    σ² ~ Turing.truncated(Turing.Normal(params.σ²...), 0, Inf)

    # Define likelihood
    y ~ Turing.MvNormal(
        logistic_alt(x, a, b, c, ic50),
        LinearAlgebra.Diagonal(fill(σ², length(y)))
    )
end
Let’s now perform inference on synthetic data using this new parameterization.
Random.seed!(42)
# Define ground truth parameters
ic50_true = 2^0.5
a_true = 1.0
b_true = 10.0
σ²_true = 0.01
c_true = 0.1

# Simulate data
x = 2 .^ LinRange(-2.5, 2.5, 15)
# Define mean of data
ŷ = logistic_alt(x, a_true, b_true, c_true, ic50_true)
# Add noise
y = ŷ + randn(length(x)) * sqrt(σ²_true)

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    xscale=log2
)
# Plot data
scatterlines!(ax, x, y)

fig
With the data simulated, let’s perform inference on the data.
Random.seed!(42)
# Perform inference
model = logistic_alt_model(x, y)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 1.6
┌ Info: Found initial step size
└ ϵ = 0.8
┌ Info: Found initial step size
└ ϵ = 0.4
┌ Info: Found initial step size
└ ϵ = 0.8
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:19
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:06
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:06
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 5.49 seconds
Compute duration = 21.43 seconds
parameters = ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
ic50 1.3615 0.0783 0.0006 21996.9818 15834.2853 1.0001 ⋯
a 0.9776 0.0641 0.0005 14338.2651 16755.4871 1.0000 ⋯
b 7.5525 2.9801 0.0243 15179.8530 18757.0920 1.0000 ⋯
c 0.1183 0.0456 0.0004 14321.6452 10118.2654 1.0002 ⋯
σ² 0.0102 0.0066 0.0001 14265.0218 17759.7256 1.0002 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
ic50 1.1980 1.3177 1.3639 1.4079 1.5098
a 0.8545 0.9372 0.9761 1.0169 1.1094
b 3.6070 5.6275 6.9973 8.7969 14.6326
c 0.0281 0.0886 0.1180 0.1470 0.2112
σ² 0.0037 0.0062 0.0085 0.0121 0.0270
Let’s look at the posterior distribution of the parameters.
# Plot corner plot for chains
PairPlots.pairplot(
    chain[n_burnin+1:end, :, :],
    PairPlots.Truth(
        (;
            ic50=ic50_true,
            a=a_true,
            b=b_true,
            c=c_true,
            σ²=σ²_true
        )
    )
)
The model is still able to recover the ground truth parameters. Let’s now plot the posterior predictive check.
Random.seed!(42)
# Initialize matrix to store samples
y_samples = Array{Float64}(undef, length(x), n_samples)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_alt(
        x,
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(x)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(ax, x, y_samples[:, i], color=(:gray, 0.05))
end # for

# Plot data
scatterlines!(ax, x, y)

fig
Everything looks good. Let’s again test it on the Iwasawa et al. (2022) data.
Random.seed!(42)
# Define prior parameters
prior_params = (
    a=(log(0.1), 0.1),
    b=(0, 1),
    c=(0, 1),
    σ²=(0, 0.01)
)
# Define data
data = df_group[2]
# Clean data
outliers_idx = fit_logistic_and_detect_outliers(
    log2.(data.concentration_ugmL), data.OD, threshold=2
)
data_clean = data[.!outliers_idx, :]

# Perform inference
model = logistic_alt_model(
    data_clean.concentration_ugmL,
    data_clean.OD,
    prior_params
)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.00625
┌ Info: Found initial step size
└ ϵ = 0.00625
┌ Info: Found initial step size
└ ϵ = 0.0125
┌ Info: Found initial step size
└ ϵ = 0.00625
Sampling (4 threads) 25%|███████▌ | ETA: 0:00:18
Sampling (4 threads) 50%|███████████████ | ETA: 0:00:06
Sampling (4 threads) 75%|██████████████████████▌ | ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:06
Chains MCMC chain (11000×17×4 Array{Float64, 3}):
Iterations = 1001:1:12000
Number of chains = 4
Samples per chain = 11000
Wall duration = 5.55 seconds
Compute duration = 21.2 seconds
parameters = ic50, a, b, c, σ²
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
ic50 8.5831 0.1769 0.0011 25568.1627 25963.0962 1.0003 ⋯
a 0.1165 0.0038 0.0000 18041.1443 20901.4168 1.0002 ⋯
b 9.6366 1.5316 0.0097 26259.3918 25025.7026 1.0000 ⋯
c 0.0710 0.0035 0.0000 17294.7757 21118.9348 1.0002 ⋯
σ² 0.0001 0.0000 0.0000 29855.3383 28397.5187 1.0001 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
ic50 8.2417 8.4631 8.5828 8.6994 8.9339
a 0.1091 0.1139 0.1164 0.1190 0.1241
b 7.2053 8.5622 9.4488 10.4964 13.1592
c 0.0640 0.0687 0.0711 0.0734 0.0777
σ² 0.0001 0.0001 0.0001 0.0001 0.0002
Let’s look at the posterior distribution of the parameters.
# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])
Nothing looks obviously wrong. Let’s plot the posterior predictive check.
Random.seed!(42)
# Define unique concentrations
unique_concentrations = sort(unique(data_clean.concentration_ugmL))

# Initialize matrix to store samples
y_samples = Array{Float64}(
    undef, length(unique_concentrations), n_samples
)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_alt(
        unique_concentrations,
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(
        ax,
        unique_concentrations,
        y_samples[:, i],
        color=(ColorSchemes.Paired_12[1], 0.5)
    )
end # for

# Plot data
scatter!(ax, data_clean.concentration_ugmL, data_clean.OD)

fig
With this parameterization, the model performs as well as the previous one.
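This notebook only fits a single example group; mapping the full resistance landscape means repeating the procedure for every (antibiotic, environment, day) combination. The following is a rough sketch of such a loop, not part of the original analysis: it reuses `df_group`, the outlier detector, and the `prior_params` defined above, runs a single chain per group for simplicity, and collects the posterior median \(IC_{50}\) into a summary table.
# Sketch: fit the alternative model to every (antibiotic, env, day) group and
# collect the posterior median IC₅₀ into a summary DataFrame
ic50_summary = DF.DataFrame(
    antibiotic=String[], env=String[], day=Int[], ic50_median=Float64[]
)

for data in df_group
    # Remove outliers with the residual-based detector
    outliers_idx = fit_logistic_and_detect_outliers(
        log2.(data.concentration_ugmL), data.OD, threshold=2
    )
    data_clean = data[.!outliers_idx, :]
    # Fit the alternative-parameterization model (single chain, no progress bar)
    model = logistic_alt_model(
        data_clean.concentration_ugmL, data_clean.OD, prior_params
    )
    chain = Turing.sample(
        model, Turing.NUTS(), n_burnin + n_samples; progress=false
    )
    # Store the posterior median IC₅₀ for this group
    push!(ic50_summary, (
        String(first(data.antibiotic)),
        String(first(data.env)),
        first(data.day),
        StatsBase.median(vec(chain[n_burnin+1:end, :, :][:ic50]))
    ))
end # for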
Conclusion
In this notebook, we have performed Bayesian inference on the data from Iwasawa et al. (2022) using two different parameterizations of the logistic function. Both parameterizations perform equally well, and we are able to recover the ground truth parameters from simulated data. Furthermore, implementing an outlier detection scheme improved the fit significantly.