\(IC_{50}\)" /> Bayesian Inference of IC_{50} Values

Bayesian Inference of \(IC_{50}\) Values

Author
Affiliation

Manuel Razo-Mejia

Department of Biology, Stanford University

Keywords

Bayesian inference, Antibiotic resistance, \(IC_{50}\)

(c) This work is licensed under a Creative Commons Attribution License CC-BY 4.0. All code contained herein is licensed under an MIT license.

# Import project package
import Antibiotic

# Import package to handle DataFrames
import DataFrames as DF
import CSV

# Import library for Bayesian inference
import Turing

# Import library to list files
import Glob

# Load CairoMakie for plotting
using CairoMakie
import ColorSchemes

# Import packages for posterior sampling
import PairPlots

# Import basic math libraries
import LsqFit
import StatsBase
import LinearAlgebra
import Random

# Activate backend
CairoMakie.activate!()

# Set custom plotting style
Antibiotic.viz.theme_makie!()

Bayesian Inference of \(IC_{50}\) values

In this notebook, we will perform Bayesian inference on the \(IC_{50}\) values of the antibiotic resistance landscape. For this, we will use the raw OD620 measurements provided by Iwasawa et al. (2022).

Let’s begin by loading the data into a DataFrame.

# Load data into a DataFrame
df = CSV.read("./iwasawa_data/iwasawa_tidy.csv", DF.DataFrame)

first(df, 5)
5×10 DataFrame
 Row │ antibiotic  col      OD       design  concentration_ugmL  day    strain_num  env          plate  blank
     │ String7     String7  Float64  Int64   Float64             Int64  Int64       String15     Int64  Bool
─────┼────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ TET         col1      0.0464       3                 0.0      1          19  TETE4_in_KM     10   true
   2 │ KM          col1      0.0432       3                 0.0      1          19  TETE4_in_KM     10   true
   3 │ NFLX        col1      0.0414       3                 0.0      1          19  TETE4_in_KM     10   true
   4 │ SS          col1      0.0415       3                 0.0      1          19  TETE4_in_KM     10   true
   5 │ PLM         col1      0.0427       3                 0.0      1          19  TETE4_in_KM     10   true

To double-check that the structure of the table makes sense, let’s plot the time series for one example to see if the sequence agrees with the expectation.

# Define data to use: KM measurements for strain 13 in the Parent_in_KM
# environment, excluding blank wells and zero-concentration measurements
data = df[
    (df.antibiotic.=="KM").&(df.env.=="Parent_in_KM").&(df.strain_num.==13).&.!(df.blank).&(df.concentration_ugmL.>0),
    :]

# Group data by day
df_group = DF.groupby(data, :day)

# Initialize figure
fig = Figure(size=(500, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="OD₆₂₀",
    xscale=log2
)

# Define colors for plot
colors = get(ColorSchemes.Blues_9, LinRange(0.25, 1, length(df_group)))

# Loop through days
for (i, d) in enumerate(df_group)
    # Sort data by concentration
    DF.sort!(d, :concentration_ugmL)
    # Plot scatter line
    scatterlines!(
        ax, d.concentration_ugmL, d.OD, color=colors[i], label="$(first(d.day))"
    )
end # for

# Add legend to plot
fig[1, 2] = Legend(
    fig, ax, "day", framevisible=false, nbanks=3, labelsize=10
)

fig

The functional form used by the authors to fit the data is \[ f(x) = \frac{a} {1+\exp \left[b\left(\log _2 x-\log _2 IC_{50}\right)\right]} + c \tag{1}\]

where \(a\), \(b\), and \(c\) are nuisance parameters of the model, \(IC_{50}\) is the parameter of interest, and \(x\) is the antibiotic concentration. We can define a function to compute this model.

@doc raw"""
    logistic(x, a, b, c, ic50)

Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.

This function implements the following equation:

f(x) = a / (1 + exp(b * (log₂(x) - log₂(IC₅₀)))) + c

# Arguments
- `x`: Antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `ic50`: IC₅₀ parameter (concentration at which the effect is halfway between
  the minimum and maximum)

# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.

Note: This function is vectorized and can handle array inputs for `x`.
"""
function logistic(x, a, b, c, ic50)
    return @. a / (1.0 + exp(b * (log2(x) - log2(ic50)))) + c
end

To test the function, let’s plot the model for a set of parameters.

# Define parameters
a = 1.0
b = 1.0
c = 0.0
ic50 = 1.0

# Define concentration range
x = 10 .^ LinRange(-2.5, 2.5, 50)

# Compute model
y = logistic(x, a, b, c, ic50)

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration (a.u.)",
    ylabel="optical density",
    xscale=log10
)
# Plot model
lines!(ax, x, y)

fig

The function works as expected. However, notice that in Eq. (1) the \(\mathrm{IC}_{50}\) enters through its logarithm, and this is the parameter we will fit for. Thus, let’s define a function that takes \(\log_2(\mathrm{IC}_{50})\) and \(\log_2(x)\) as inputs instead.

@doc raw"""
    logistic_log2(log2x, a, b, c, log2ic50)

Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth, using log2 inputs for concentration and
IC₅₀.

This function implements the following equation:

f(x) = a / (1 + exp(b * (log₂(x) - log₂(IC₅₀)))) + c

# Arguments
- `log2x`: log₂ of the antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `log2ic50`: log₂ of the IC₅₀ parameter

# Returns
The computed effect (e.g., optical density) for the given log₂ antibiotic
concentration and parameters.

Note: This function is vectorized and can handle array inputs for `log2x`.
"""
function logistic_log2(log2x, a, b, c, log2ic50)
    return @. a / (1.0 + exp(b * (log2x - log2ic50))) + c
end
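
As a quick sanity check (a small sketch added for illustration; x_check is just a throwaway name), the two functions should agree once the concentration and \(IC_{50}\) are log₂-transformed.

# Sanity check (sketch): the log₂ parameterization should reproduce the original
# function when given log₂-transformed concentrations and log₂(IC₅₀)
x_check = 10 .^ LinRange(-2.5, 2.5, 10)
@assert isapprox(
    logistic(x_check, 1.0, 1.0, 0.0, 2.0),
    logistic_log2(log2.(x_check), 1.0, 1.0, 0.0, log2(2.0))
)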

Bayesian model

Given the model presented in Eq. (1) and the data, our objective is to infer the values of all parameters. By Bayes’ theorem, we write

\[ \pi(IC_{50}, a, b, c \mid \text{data}) = \frac{\pi(\text{data} \mid IC_{50}, a, b, c) \pi(IC_{50}, a, b, c)} {\pi(\text{data})}, \tag{2}\]

where \(\text{data}\) consists of the pairs of antibiotic concentration and optical density.

Likelihood \(\pi(\text{data} \mid IC_{50}, a, b, c)\)

Let’s begin by defining the likelihood function. For simplicity, we assume each datum is independent and identically distributed (i.i.d.) and write

\[ \pi(\text{data} \mid IC_{50}, a, b, c) = \prod_{i=1}^n \pi(d_i \mid IC_{50}, a, b, c), \tag{3}\]

where \(d_i = (x_i, y_i)\) is the \(i\)-th pair of antibiotic concentration and optical density, respectively, and \(n\) is the total number of data points. As a first pass, we assume that our experimental measurements can be expressed as

\[ y_i = f(x_i, IC_{50}, a, b, c) + \epsilon_i, \tag{4}\]

where \(\epsilon_i\) is the experimental error. Furthermore, we assume that the experimental error is normally distributed, i.e.,

\[ \epsilon_i \sim \mathcal{N}(0, \sigma^2), \tag{5}\]

where \(\sigma^2\) is an unknown variance parameter that must be included in our inference. Notice that we assume the same variance parameter for all data points since \(\sigma^2\) is not indexed by \(i\).

Given this likelihood function, we must update our inference on the parameters as

\[ \pi(IC_{50}, a, b, c, \sigma^2 \mid \text{data}) = \frac{\pi(\text{data} \mid IC_{50}, a, b, c, \sigma^2) \pi(IC_{50}, a, b, c, \sigma^2)} {\pi(\text{data})}, \tag{6}\]

to include the new parameter \(\sigma^2\).

Our likelihood function is then of the form \[ y_i \mid IC_{50}, a, b, c, \sigma^2 \sim \mathcal{N}(f(x_i, IC_{50}, a, b, c), \sigma^2). \tag{7}\]
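
To make Equation 7 concrete, here is a minimal sketch of how the log-likelihood could be evaluated for a given set of parameters. This is purely illustrative and not used below (the function name loglike_sketch is ours); Turing performs this computation internally.

# Sketch (not used below; Turing handles this internally): evaluate the
# log-likelihood of Eq. (7) for a given set of parameters
function loglike_sketch(log2x, y, log2ic50, a, b, c, σ²)
    # Model mean for each concentration
    μ = logistic_log2(log2x, a, b, c, log2ic50)
    # Sum of independent Gaussian log-densities with common variance σ²
    return sum(@. -0.5 * log(2π * σ²) - (y - μ)^2 / (2σ²))
end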

Prior \(\pi(IC_{50}, a, b, c, \sigma^2)\)

For the prior, we assume that all parameters are independent and write \[ \pi(IC_{50}, a, b, c, \sigma^2) = \pi(IC_{50}) \pi(a) \pi(b) \pi(c) \pi(\sigma^2). \tag{8}\]

Let’s detail each prior.

  1. \(IC_{50}\): The \(IC_{50}\) is a strictly positive parameter. However, we will fit for \(\log_2(IC_{50})\). Thus, we will use a normal prior for \(\log_2(IC_{50})\). This means we have

\[ \log_2(IC_{50}) \sim \mathcal{N}( \mu_{\log_2(IC_{50})}, \sigma_{\log_2(IC_{50})}^2 ). \tag{9}\]

  2. \(a\): This nuisance parameter scales the logistic function. Again, the natural scale for this parameter is a strictly positive real number. Thus, we will use a lognormal prior for \(a\). This means we have

\[ a \sim \text{LogNormal}(\mu_a, \sigma_a^2). \tag{10}\]

  3. \(b\): This parameter controls the steepness of the logistic function. Again, the natural scale for this parameter is a strictly positive real number. Thus, we will use a lognormal prior for \(b\). This means we have

\[ b \sim \text{LogNormal}(\mu_b, \sigma_b^2). \tag{11}\]

  4. \(c\): This parameter controls the minimum value of the logistic function. Since this is a strictly positive real number that does not necessarily scale with the data, we will use a half-normal prior for \(c\). This means we have

\[ c \sim \text{Half-}\mathcal{N}(0, \sigma_c^2). \tag{12}\]

  5. \(\sigma^2\): This parameter controls the variance of the experimental error. Since this is a strictly positive real number that does not necessarily scale with the data, we will use a half-normal prior for \(\sigma^2\). This means we have

\[ \sigma^2 \sim \text{Half-}\mathcal{N}(0, \sigma_{\sigma^2}^2). \tag{13}\]

With all of this in place, we are ready to define a Turing model to perform Bayesian inference on the parameters of the model.

Turing.@model function logistic_model(
    log2x, y, prior_params::NamedTuple=NamedTuple()
)
    # Define default prior parameters
    default_params = (
        log2ic50=(0, 1),
        a=(0, 1),
        b=(0, 1),
        c=(0, 1),
        σ²=(0, 1)
    )

    # Merge default parameters with provided parameters
    params = merge(default_params, prior_params)

    # Define priors
    log2ic50 ~ Turing.Normal(params.log2ic50...)
    a ~ Turing.LogNormal(params.a...)
    b ~ Turing.LogNormal(params.b...)
    c ~ Turing.truncated(Turing.Normal(params.c...), 0, Inf)
    σ² ~ Turing.truncated(Turing.Normal(params.σ²...), 0, Inf)

    # Define likelihood
    y ~ Turing.MvNormal(
        logistic_log2(log2x, a, b, c, log2ic50),
        LinearAlgebra.Diagonal(fill(σ², length(y)))
    )
end
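
Before conditioning on data, it can be useful to check that the default priors imply reasonable titration curves. The following prior predictive sketch is illustrative and not part of the original analysis; it simply draws parameters from the default priors and pushes them through logistic_log2.

# Prior predictive sketch (illustration only): draw parameters from the default
# priors and plot the titration curves they imply
Random.seed!(42)
# Define grid of log₂ concentrations
log2x_prior = LinRange(-2.5, 2.5, 50)
# Draw 50 curves from the default priors
prior_curves = [
    logistic_log2(
        log2x_prior,
        rand(Turing.LogNormal(0, 1)),                         # a
        rand(Turing.LogNormal(0, 1)),                         # b
        rand(Turing.truncated(Turing.Normal(0, 1), 0, Inf)),  # c
        rand(Turing.Normal(0, 1))                             # log₂(IC₅₀)
    )
    for _ in 1:50
]

# Plot prior predictive curves
fig = Figure(size=(350, 300))
ax = Axis(fig[1, 1], xlabel="log₂ antibiotic concentration", ylabel="optical density")
for curve in prior_curves
    lines!(ax, log2x_prior, curve, color=(:gray, 0.2))
end # for

fig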

Having defined the model, we can now perform inference. First, let’s perform inference on simulated data. For this, we first simulate a single titration curve with known parameters.

Random.seed!(42)
# Define ground truth parameters
log2ic50_true = 0.5
a_true = 1.0
b_true = 10.0
σ²_true = 0.01
c_true = 0.1

# Simulate data
log2x = LinRange(-2.5, 2.5, 15)
# Define mean of data
ȳ = logistic_log2(log2x, a_true, b_true, c_true, log2ic50_true)
# Add noise
y = ȳ .+ randn(length(log2x)) * √(σ²_true)

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    xscale=log2
)
# Plot data
scatterlines!(ax, 2.0 .^ log2x, y)

fig

With the data simulated, let’s perform inference on the data. For this, we will use the NUTS sampler. To run multiple chains in parallel, we will use the MCMCThreads option.

Random.seed!(42)
# Perform inference
model = logistic_model(log2x, y)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

# Run sampler using NUTS with 4 parallel chains
chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.4
┌ Info: Found initial step size
└   ϵ = 0.8
Sampling (4 threads)  25%|███████▌                      |  ETA: 0:00:39
Sampling (4 threads)  50%|███████████████               |  ETA: 0:00:13
Sampling (4 threads)  75%|██████████████████████▌       |  ETA: 0:00:04
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:13
Chains MCMC chain (11000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:12000
Number of chains  = 4
Samples per chain = 11000
Wall duration     = 9.39 seconds
Compute duration  = 36.29 seconds
parameters        = log2ic50, a, b, c, σ²
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64 ⋯

    log2ic50    0.4723    0.0718    0.0005   23407.8672   14077.0993    1.0003 ⋯
           a    0.9668    0.0621    0.0005   16034.1809   16294.6985    1.0006 ⋯
           b    8.7902    4.6027    0.0449   13409.6206   13035.6999    1.0003 ⋯
           c    0.1192    0.0449    0.0004   15396.5130   11022.7335    1.0004 ⋯
          σ²    0.0111    0.0076    0.0001   12720.8496   16686.4143    1.0005 ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    log2ic50    0.3173    0.4352    0.4745    0.5141    0.6048
           a    0.8482    0.9279    0.9649    1.0042    1.0961
           b    3.0834    5.7843    7.9162   10.6434   20.2429
           c    0.0294    0.0902    0.1194    0.1475    0.2093
          σ²    0.0038    0.0065    0.0091    0.0132    0.0302

Let’s look at the posterior distribution of the parameters. For this, we will use the PairPlots.jl package.

# Plot corner plot for chains
PairPlots.pairplot(
    chain[n_burnin+1:end, :, :],
    PairPlots.Truth(
        (;
        log2ic50=log2ic50_true,
        a=a_true,
        b=b_true,
        c=c_true,
        σ²=σ²_true
    )
    )
)

All parameters are well-constrained by the data, and we are able to recover the ground truth values.
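
As a numerical complement to the corner plot, we can summarize the \(\log_2(IC_{50})\) posterior directly from the post-burn-in samples (a small illustrative sketch; the variable names are arbitrary).

# Sketch: point estimate and 95% credible interval for log₂(IC₅₀) from the
# post-burn-in posterior samples
chain_post = chain[n_burnin+1:end, :, :]
log2ic50_samples = vec(chain_post[:log2ic50])
println("median log₂(IC₅₀): ", StatsBase.median(log2ic50_samples))
println("95% credible interval: ", StatsBase.quantile(log2ic50_samples, [0.025, 0.975]))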

Let’s plot the posterior predictive checks with the data to see how the model performs.

Random.seed!(42)

# Initialize matrix to store samples
y_samples = Array{Float64}(undef, length(log2x), n_samples)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_log2(
        log2x,
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:log2ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(log2x)) * sqrt(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(ax, 2.0 .^ log2x, y_samples[:, i], color=(:gray, 0.05))
end # for

# Plot data
scatterlines!(ax, 2.0 .^ log2x, y)

fig

The fit looks excellent. Let’s now perform inference on the Iwasawa data. We will use a single example dataset and set informative priors for the parameters.

Random.seed!(42)
# Group data by antibiotic, environment, and day
df_group = DF.groupby(
    df[(.!df.blank).&(df.concentration_ugmL.>0), :],
    [:antibiotic, :env, :day]
)

# Extract data
data = df_group[2]

# Define prior parameters
prior_params = (
    a=(log(0.1), 0.1),
    b=(0, 1),
    c=(0, 1),
    σ²=(0, 0.1)
)
# Perform inference
model = logistic_model(
    log2.(data.concentration_ugmL),
    data.OD,
    prior_params
)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 0.2
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.2
┌ Info: Found initial step size
└   ϵ = 0.05
Sampling (4 threads)  25%|███████▌                      |  ETA: 0:00:16
Sampling (4 threads)  50%|███████████████               |  ETA: 0:00:05
Sampling (4 threads)  75%|██████████████████████▌       |  ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:05
Chains MCMC chain (11000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:12000
Number of chains  = 4
Samples per chain = 11000
Wall duration     = 4.78 seconds
Compute duration  = 17.88 seconds
parameters        = log2ic50, a, b, c, σ²
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64 ⋯

    log2ic50    3.1585    0.0611    0.0004   32325.5670   25247.0891    1.0001 ⋯
           a    0.1013    0.0074    0.0000   28541.0328   27631.5354    1.0001 ⋯
           b   15.8931   10.4326    0.0655   31581.5518   26387.0900    1.0000 ⋯
           c    0.0873    0.0073    0.0000   27591.1350   24819.1854    1.0002 ⋯
          σ²    0.0015    0.0002    0.0000   37231.0264   30213.0584    1.0000 ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    log2ic50    3.0427    3.1192    3.1589    3.1963    3.2817
           a    0.0872    0.0962    0.1012    0.1063    0.1162
           b    4.6324    9.2475   13.4282   19.5750   41.4435
           c    0.0726    0.0824    0.0874    0.0923    0.1013
          σ²    0.0011    0.0013    0.0014    0.0016    0.0020

Let’s look at the corner plot for the chains.

# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])

This looks good. Let’s now plot the posterior predictive check.

Random.seed!(42)

# Define unique concentrations
unique_concentrations = sort(unique(data.concentration_ugmL))

# Initialize matrix to store samples
y_samples = Array{Float64}(
    undef, length(unique_concentrations), n_samples
)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_log2(
        log2.(unique_concentrations),
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:log2ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(
        ax,
        unique_concentrations,
        y_samples[:, i],
        color=(ColorSchemes.Paired_12[1], 0.5)
    )
end # for

# Plot data
scatter!(ax, data.concentration_ugmL, data.OD)

fig

The presence of outlier measurements inflates the uncertainty of the posterior predictive checks.

Let’s look into how to detect outliers in the data.

Residual-based outlier detection

A naive approach to detecting outliers is to fit the logistic model deterministically and then flag points whose residuals are more than a chosen number of standard deviations away from the fitted curve. This is known as “residual-based outlier detection”.

# Method that unpacks a parameter vector (a, b, c, log2ic50), as required by LsqFit
function logistic_log2(log2x, params)
    return logistic_log2(log2x, params...)
end

"""
    fit_logistic_and_detect_outliers(log2x, y; threshold=3)

Fit a logistic model to the given data and detect outliers based on residuals.

This function performs the following steps:
1. Fits a logistic model to the log2-transformed x-values and y-values.
2. Calculates residuals between the fitted model and actual y-values.
3. Identifies outliers as points with residuals exceeding a specified threshold.

# Arguments
- `log2x`: Array of log2-transformed x-values (typically concentrations).
- `y`: Array of y-values (typically optical density measurements).
- `threshold`: Number of standard deviations beyond which a point is considered an outlier. Default is 3.

# Returns
- A boolean array indicating which points are outliers (true for outliers).

# Notes
- The function uses a logistic model of the form: 
    y = a / (1 + exp(b * (log2x - log2ic50))) + c
- Initial parameter guesses are made based on the input data.
- The LsqFit package is used for curve fitting.
- Outliers are determined by comparing the absolute residuals to the threshold * standard deviation of residuals.
"""
function fit_logistic_and_detect_outliers(log2x, y; threshold=3)
    # Initial parameter guess
    p0 = [0.1, 1.0, maximum(y) - minimum(y), StatsBase.median(log2x)]

    # Fit the logistic model
    fit = LsqFit.curve_fit(logistic_log2, log2x, y, p0)

    # Calculate residuals
    residuals = y - logistic_log2(log2x, fit.param)

    # Calculate standard deviation of residuals
    σ = StatsBase.std(residuals)

    # Identify outliers
    outliers_idx = abs.(residuals) .> threshold * σ

    # Return outlier indices
    return outliers_idx
end

Let’s test this function by trying to remove the outliers from the data we used in the previous section.

# Locate outliers
outliers_idx = fit_logistic_and_detect_outliers(
    log2.(data.concentration_ugmL), data.OD, threshold=2
)

# Plot the results
fig = Figure(size=(450, 300))

ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    # xscale=log2
)

# Plot the original data
scatter!(
    ax,
    data.concentration_ugmL,
    data.OD,
    color=ColorSchemes.Paired_12[1],
    label="data"
)

# Plot the cleaned data (x-values offset by 0.5 for visibility)
scatter!(
    ax,
    data.concentration_ugmL[.!outliers_idx] .+ 0.5,
    data.OD[.!outliers_idx], color=ColorSchemes.Paired_12[2],
    label="cleaned data"
)

# Add legend
Legend(fig[1, 2], ax)

fig

The outlier detection worked well. Let’s now perform inference on the cleaned data.

Random.seed!(42)
# Group data by antibiotic, environment, and day
df_group = DF.groupby(
    df[(.!df.blank).&(df.concentration_ugmL.>0), :],
    [:antibiotic, :env, :day]
)

# Extract data
data = df_group[2]

# Find outliers
outliers_idx = fit_logistic_and_detect_outliers(
    log2.(data.concentration_ugmL), data.OD, threshold=2
)
# Remove outliers
data_clean = data[.!outliers_idx, :]

# Define prior parameters
prior_params = (
    a=(log(0.1), 0.1),
    b=(0, 1),
    c=(0, 1),
    σ²=(0, 0.1)
)

# Perform inference
model = logistic_model(
    log2.(data_clean.concentration_ugmL),
    data_clean.OD,
    prior_params
)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 0.2
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.2
┌ Info: Found initial step size
└   ϵ = 0.05
Sampling (4 threads)  25%|███████▌                      |  ETA: 0:00:17
Sampling (4 threads)  50%|███████████████               |  ETA: 0:00:06
Sampling (4 threads)  75%|██████████████████████▌       |  ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:05
Chains MCMC chain (11000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:12000
Number of chains  = 4
Samples per chain = 11000
Wall duration     = 5.01 seconds
Compute duration  = 19.13 seconds
parameters        = log2ic50, a, b, c, σ²
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64 ⋯

    log2ic50    3.1001    0.0297    0.0002   25150.3685   27935.5938    1.0002 ⋯
           a    0.1163    0.0038    0.0000   16860.1510   22570.9089    1.0001 ⋯
           b    6.7350    1.0786    0.0069   26759.0743   25218.0167    1.0001 ⋯
           c    0.0712    0.0034    0.0000   16323.0552   21352.9296    1.0001 ⋯
          σ²    0.0001    0.0000    0.0000   32838.7703   28866.3125    1.0001 ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    log2ic50    3.0410    3.0804    3.1004    3.1202    3.1580
           a    0.1089    0.1138    0.1163    0.1188    0.1238
           b    5.0241    5.9818    6.5979    7.3261    9.2492
           c    0.0643    0.0689    0.0712    0.0735    0.0778
          σ²    0.0001    0.0001    0.0001    0.0001    0.0002

Let’s look again at the corner plot for the chains.

# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])

Compared to the previous example where we didn’t remove the outliers, the posterior distributions are more concentrated, especially for the \(b\) parameter.

Let’s now plot the posterior predictive check.

Random.seed!(42)

# Define unique concentrations
unique_concentrations = sort(unique(data_clean.concentration_ugmL))

# Initialize matrix to store samples
y_samples = Array{Float64}(
    undef, length(unique_concentrations), n_samples
)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_log2(
        log2.(unique_concentrations),
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:log2ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(
        ax,
        unique_concentrations,
        y_samples[:, i],
        color=(ColorSchemes.Paired_12[1], 0.5)
    )
end # for

# Plot data
scatter!(ax, data_clean.concentration_ugmL, data_clean.OD)

fig

This looks much better. Removing the outliers improved the fit significantly.

Alternative parameterization

In the data, there are several measurements at zero concentration, which are incompatible with the log-scale model. However, we can rewrite Equation 1 without the logarithm as

\[ f(x) = \frac{a} {1+ \left(\frac{x}{\mathrm{IC}_{50}}\right)^b} + c \tag{14}\]

Note that the slope parameter \(b\) in Equation 14 is not identical to the one in Equation 1: since \(\exp\left[b\left(\log_2 x-\log_2 IC_{50}\right)\right] = \left(x / IC_{50}\right)^{b / \ln 2}\), the two slopes differ by a factor of \(\ln 2\). Let’s define a new model using this parameterization.

@doc raw"""
    logistic_alt(x, a, b, c, ic50)

Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.

This function implements the following equation:

f(x) = a / (1 + (x / IC₅₀) ^ b) + c

# Arguments
- `x`: Antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `ic50`: IC₅₀ parameter (concentration at which the effect is halfway between
  the minimum and maximum)

# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.

Note: This function is vectorized and can handle array inputs for `x`.
"""
function logistic_alt(x, a, b, c, ic50)
    return @. a / (1.0 + (x / ic50)^b) + c
end

function logistic_alt(x, params)
    return logistic_alt(x, params...)
end
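
As a quick numerical check of the slope rescaling noted above (a sketch for illustration), Equation 1 with slope \(b\) traces the same curve as Equation 14 with slope \(b / \ln 2\).

# Consistency check (sketch): Eq. (1) with slope b equals Eq. (14) with slope
# b / ln(2) for the same a, c, and IC₅₀
x_check = 2 .^ LinRange(-2.5, 2.5, 10)
@assert isapprox(
    logistic(x_check, 1.0, 10.0, 0.1, 2^0.5),
    logistic_alt(x_check, 1.0, 10.0 / log(2), 0.1, 2^0.5)
)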

We will keep using the previous function to detect outliers in the data. Now, let’s re-define the Bayesian model using this parameterization.

Turing.@model function logistic_alt_model(
    x, y, prior_params::NamedTuple=NamedTuple()
)
    # Define default prior parameters
    default_params = (
        ic50=(0, 1),
        a=(0, 1),
        b=(0, 1),
        c=(0, 1),
        σ²=(0, 1)
    )

    # Merge default parameters with provided parameters
    params = merge(default_params, prior_params)

    # Define priors
    ic50 ~ Turing.LogNormal(params.ic50...)
    a ~ Turing.LogNormal(params.a...)
    b ~ Turing.LogNormal(params.b...)
    c ~ Turing.truncated(Turing.Normal(params.c...), 0, Inf)
    σ² ~ Turing.truncated(Turing.Normal(params.σ²...), 0, Inf)

    # Define likelihood
    y ~ Turing.MvNormal(
        logistic_alt(x, a, b, c, ic50),
        LinearAlgebra.Diagonal(fill(σ², length(y)))
    )
end

Let’s now perform inference on synthetic data using this new parameterization.

Random.seed!(42)
# Define ground truth parameters
ic50_true = 2^0.5
a_true = 1.0
b_true = 10.0
σ²_true = 0.01
c_true = 0.1

# Simulate data
x = 2 .^ LinRange(-2.5, 2.5, 15)
# Define mean of data
ȳ = logistic_alt(x, a_true, b_true, c_true, ic50_true)
# Add noise
y = ȳ + randn(length(x)) * sqrt(σ²_true)

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    xscale=log2
)
# Plot data
scatterlines!(ax, x, y)

fig

With the data simulated, let’s perform inference on the data.

Random.seed!(42)
# Perform inference
model = logistic_alt_model(x, y)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 1.6
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.4
┌ Info: Found initial step size
└   ϵ = 0.8
Sampling (4 threads)  25%|███████▌                      |  ETA: 0:00:19
Sampling (4 threads)  50%|███████████████               |  ETA: 0:00:06
Sampling (4 threads)  75%|██████████████████████▌       |  ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:06
Chains MCMC chain (11000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:12000
Number of chains  = 4
Samples per chain = 11000
Wall duration     = 5.49 seconds
Compute duration  = 21.43 seconds
parameters        = ic50, a, b, c, σ²
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64 ⋯

        ic50    1.3615    0.0783    0.0006   21996.9818   15834.2853    1.0001 ⋯
           a    0.9776    0.0641    0.0005   14338.2651   16755.4871    1.0000 ⋯
           b    7.5525    2.9801    0.0243   15179.8530   18757.0920    1.0000 ⋯
           c    0.1183    0.0456    0.0004   14321.6452   10118.2654    1.0002 ⋯
          σ²    0.0102    0.0066    0.0001   14265.0218   17759.7256    1.0002 ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        ic50    1.1980    1.3177    1.3639    1.4079    1.5098
           a    0.8545    0.9372    0.9761    1.0169    1.1094
           b    3.6070    5.6275    6.9973    8.7969   14.6326
           c    0.0281    0.0886    0.1180    0.1470    0.2112
          σ²    0.0037    0.0062    0.0085    0.0121    0.0270

Let’s look at the posterior distribution of the parameters.

# Plot corner plot for chains
PairPlots.pairplot(
    chain[n_burnin+1:end, :, :],
    PairPlots.Truth(
        (;
        ic50=ic50_true,
        a=a_true,
        b=b_true,
        c=c_true,
        σ²=σ²_true
    )
    )
)

The model is still able to recover the ground truth parameters. Let’s now plot the posterior predictive check.

Random.seed!(42)

# Initialize matrix to store samples
y_samples = Array{Float64}(undef, length(x), n_samples)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_alt(
        x,
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(x)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(ax, x, y_samples[:, i], color=(:gray, 0.05))
end # for

# Plot data
scatterlines!(ax, x, y)

fig

Everything looks good. Let’s again test it on the Iwasawa et al. (2022) data.

Random.seed!(42)
# Define prior parameters
prior_params = (
    a=(log(0.1), 0.1),
    b=(0, 1),
    c=(0, 1),
    σ²=(0, 0.01)
)
# Define data
data = df_group[2]
# Clean data
outliers_idx = fit_logistic_and_detect_outliers(
    log2.(data.concentration_ugmL), data.OD, threshold=2
)
data_clean = data[.!outliers_idx, :]

# Perform inference
model = logistic_alt_model(
    data_clean.concentration_ugmL,
    data_clean.OD,
    prior_params
)

# Define number of steps
n_burnin = 10_000
n_samples = 1_000

chain = Turing.sample(
    model, Turing.NUTS(), Turing.MCMCThreads(), n_burnin + n_samples, 4
)
Sampling (4 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 0.00625
┌ Info: Found initial step size
└   ϵ = 0.00625
┌ Info: Found initial step size
└   ϵ = 0.0125
┌ Info: Found initial step size
└   ϵ = 0.00625
Sampling (4 threads)  25%|███████▌                      |  ETA: 0:00:18
Sampling (4 threads)  50%|███████████████               |  ETA: 0:00:06
Sampling (4 threads)  75%|██████████████████████▌       |  ETA: 0:00:02
Sampling (4 threads) 100%|██████████████████████████████| Time: 0:00:06
Chains MCMC chain (11000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:12000
Number of chains  = 4
Samples per chain = 11000
Wall duration     = 5.55 seconds
Compute duration  = 21.2 seconds
parameters        = ic50, a, b, c, σ²
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64 ⋯

        ic50    8.5831    0.1769    0.0011   25568.1627   25963.0962    1.0003 ⋯
           a    0.1165    0.0038    0.0000   18041.1443   20901.4168    1.0002 ⋯
           b    9.6366    1.5316    0.0097   26259.3918   25025.7026    1.0000 ⋯
           c    0.0710    0.0035    0.0000   17294.7757   21118.9348    1.0002 ⋯
          σ²    0.0001    0.0000    0.0000   29855.3383   28397.5187    1.0001 ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        ic50    8.2417    8.4631    8.5828    8.6994    8.9339
           a    0.1091    0.1139    0.1164    0.1190    0.1241
           b    7.2053    8.5622    9.4488   10.4964   13.1592
           c    0.0640    0.0687    0.0711    0.0734    0.0777
          σ²    0.0001    0.0001    0.0001    0.0001    0.0002

Let’s look at the posterior distribution of the parameters.

# Plot corner plot for chains
PairPlots.pairplot(chain[n_burnin+1:end, :, :])

Nothing looks obviously wrong. Let’s plot the posterior predictive check.

Random.seed!(42)

# Define unique concentrations
unique_concentrations = sort(unique(data_clean.concentration_ugmL))

# Initialize matrix to store samples
y_samples = Array{Float64}(
    undef, length(unique_concentrations), n_samples
)

# Loop through samples
for i in 1:n_samples
    # Generate mean for sample
    y_samples[:, i] = logistic_alt(
        unique_concentrations,
        chain[:a][i],
        chain[:b][i],
        chain[:c][i],
        chain[:ic50][i]
    )
    # Add noise
    y_samples[:, i] .+= randn(length(unique_concentrations)) * √(chain[:σ²][i])
end # for

# Initialize figure
fig = Figure(size=(350, 300))
# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="antibiotic concentration",
    ylabel="optical density",
    title="Posterior Predictive Check",
    xscale=log2
)

# Plot samples
for i in 1:n_samples
    lines!(
        ax,
        unique_concentrations,
        y_samples[:, i],
        color=(ColorSchemes.Paired_12[1], 0.5)
    )
end # for

# Plot data
scatter!(ax, data_clean.concentration_ugmL, data_clean.OD)

fig

With this parameterization, the model performs as well as the previous one.
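
To compare the two parameterizations directly (a small illustrative sketch), we can convert the \(IC_{50}\) posterior to the log₂ scale; its median should land near the \(\log_2(IC_{50}) \approx 3.10\) inferred with the first parameterization on the same cleaned data.

# Sketch: convert the IC₅₀ posterior to log₂ scale for comparison with the
# first parameterization
chain_post = chain[n_burnin+1:end, :, :]
ic50_samples = vec(chain_post[:ic50])
println("median log₂(IC₅₀): ", StatsBase.median(log2.(ic50_samples)))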

Conclusion

In this notebook, we have performed Bayesian inference on the data from Iwasawa et al. (2022) using two different parameterizations of the logistic function. Both parameterizations perform equally well, and we are able to recover the ground truth parameters from simulated data. Furthermore, implementing an outlier detection scheme improved the fit significantly.

References

Iwasawa, Junichiro, Tomoya Maeda, Atsushi Shibai, Hazuki Kotani, Masako Kawada, and Chikara Furusawa. 2022. “Analysis of the Evolution of Resistance to Multiple Antibiotics Enables Prediction of the Escherichia coli Phenotype-Based Fitness Landscape.” Edited by J. Arjan G. M. De Visser. PLOS Biology 20 (12): e3001920. https://doi.org/10.1371/journal.pbio.3001920.