This appendix shows how the generalized Stein-Fojo model can be implemented in a Bayesian framework using the brms package in R.
Setup and load data
First we need to load the necessary packages and set some default options for the MCMC sampling. We also set the theme for the plots to theme_bw with a base size of 12.
```r
library(bayesplot)
library(brms)
library(ggplot2)
library(gt)
library(here)
library(janitor)
library(jmpost)
library(modelr)
library(posterior)
library(readxl)
library(rstan)
library(tidybayes)
library(tidyverse)
library(truncnorm)

if (require(cmdstanr)) {
  # If cmdstanr is available, instruct brms to use cmdstanr as backend
  # and cache all Stan binaries
  options(
    brms.backend = "cmdstanr",
    cmdstanr_write_stan_file_dir = here("_brms-cache")
  )
  # Create the cache directory if not yet available
  dir.create(here("_brms-cache"), showWarnings = FALSE)
} else {
  rstan::rstan_options(auto_write = TRUE)
}

# MCMC options
options(mc.cores = 4)
ITER <- 1000 # number of sampling iterations after warm up
WARMUP <- 2000 # number of warm up iterations
CHAINS <- 4
BAYES.SEED <- 878
REFRESH <- 500

theme_set(theme_bw(base_size = 12))
```
We also need a small helper function definition, which is not yet available in brms.
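The helper's definition is not shown in this copy of the page. A minimal sketch, under the assumption that the missing piece is an R version of Stan's int_step(), which the model formula uses and which brms presumably needs on the R side to evaluate the nonlinear formula for predictions such as add_epred_draws(). This is a hypothetical stand-in, not necessarily the authors' definition:

```r
# Hypothetical R counterpart of Stan's int_step(): returns 1 if x > 0 and
# 0 otherwise (including x = 0). Logical input also works, because
# TRUE > 0 is TRUE and FALSE > 0 is FALSE.
int_step <- function(x) {
  as.integer(x > 0)
}

# Example times (years relative to treatment start)
t_example <- c(-0.3, 0, 0.4)
int_step(t_example > 0)   # 0 0 1: weight of the post-treatment branch
int_step(t_example <= 0)  # 1 1 0: weight of the pre-treatment branch
```

The two indicator terms always sum to 1, so exactly one branch of the mean function is active at every time point.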
We will use the publicly available tumor size data from the OAK study. In particular, we use the S1 data set, which is the fully anonymized data set used in the publication. For simplicity, we have copied the data set into this GitHub repository.
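Generalized Stein-Fojo model
Here we have an additional parameter \(\phi\), which is the weight for the shrinkage in the double exponential model. The model is then:
\[
y^{*}(t_{ij}) = \psi_{b_{0}i} \{ \psi_{\phi i} \exp(- \psi_{k_{s}i} \cdot t_{ij}) + (1 - \psi_{\phi i}) \exp(\psi_{k_{g}i} \cdot t_{ij}) \}
\]
for positive times \(t_{ij}\). Again, if the time \(t_{ij}\) is negative, i.e., the treatment has not started yet, then it is reasonable to assume that the tumor cannot shrink yet, so we set \(\psi_{\phi i} = 0\). Therefore, the final model for the mean SLD is:
\[
y^{*}(t_{ij}) = \begin{cases}
\psi_{b_{0}i} \exp(\psi_{k_{g}i} \cdot t_{ij}) & \text{if } t_{ij} < 0 \\
\psi_{b_{0}i} \{ \psi_{\phi i} \exp(- \psi_{k_{s}i} \cdot t_{ij}) + (1 - \psi_{\phi i}) \exp(\psi_{k_{g}i} \cdot t_{ij}) \} & \text{if } t_{ij} \geq 0
\end{cases}
\]
Both cases give \(\psi_{b_{0}i}\) at \(t_{ij} = 0\), so the mean curve is continuous at the start of treatment.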
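To make the piecewise mean SLD concrete, here is a small plain-R sketch of it for a single patient. The function name gsf_mean() and the parameter values are hypothetical and chosen only for illustration; they are not estimates from this analysis:

```r
# Hypothetical helper: mean SLD under the generalized Stein-Fojo model for
# one patient. Times are in years relative to treatment start; before
# treatment (year < 0) the shrinkage weight phi is effectively 0.
gsf_mean <- function(year, b0, ks, kg, phi) {
  stopifnot(b0 > 0, ks > 0, kg > 0, phi >= 0, phi <= 1)
  treated <- b0 * (phi * exp(-ks * year) + (1 - phi) * exp(kg * year))
  untreated <- b0 * exp(kg * year)
  ifelse(year < 0, untreated, treated)
}

# Illustrative parameter values (made up, not posterior estimates)
t_example <- c(-0.5, 0, 0.5, 1, 2)
round(gsf_mean(t_example, b0 = 50, ks = 2, kg = 0.5, phi = 0.6), 2)
```

Setting phi = 1 gives pure exponential shrinkage and phi = 0 pure exponential growth, which is why \(\phi\) can be read as the weight of the shrinkage component.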
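In terms of likelihood and priors, we can use the same assumptions as in the previous model. The only difference is that we have to model the \(\phi\) parameter. We use a logit-normal distribution for this parameter, i.e., a normal distribution on the logit scale, which is then transformed to the unit interval. As for the other parameters, the patient-level values on the logit scale (tphi in the brms model) vary around a population mean \(\mu_{\phi}\) with between-patient standard deviation \(\omega_{\phi}\), and \(\mu_{\phi}\) receives a normal prior centered at \(\text{logit}(0.5) = 0\) with standard deviation 0.5: \[
\text{logit}(\psi_{\phi i}) \sim \text{N}(\mu_{\phi}, \omega_{\phi}^{2}), \qquad \mu_{\phi} \sim \text{N}(\text{logit}(0.5) = 0, 0.5^{2})
\]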
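To see what a normal distribution with mean logit(0.5) = 0 and standard deviation 0.5 on the logit scale implies for \(\phi\) itself, a small simulation compares simulated quantiles of the back-transformed values with the exact ones. Quantiles carry over directly because the inverse logit, plogis() in base R, is monotone:

```r
# Implied distribution on the unit interval of N(logit(0.5) = 0, sd = 0.5)
# on the logit scale
set.seed(878)
prior_logit_phi <- rnorm(1e5, mean = qlogis(0.5), sd = 0.5)
prior_phi <- plogis(prior_logit_phi) # inverse logit maps onto (0, 1)

# Simulated vs. exact quantiles (exact: transform normal quantiles)
probs <- c(0.05, 0.5, 0.95)
rbind(
  simulated = quantile(prior_phi, probs),
  exact = plogis(qnorm(probs, mean = 0, sd = 0.5))
)
```

The exact central 90% range is roughly 0.31 to 0.69, so this prior keeps the population shrinkage weight away from the extremes without pinning it at 0.5.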
Fit model
We can now fit the model using brms. The structure is determined by the model formula, which is echoed at the top of the fit summary below:
Warning: There were 41 divergent transitions after warmup. Increasing
adapt_delta above 0.8 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
Family: gaussian
Links: mu = identity; sigma = log
Formula: sld ~ eta
eta ~ int_step(year > 0) * (b0 * (phi * exp(-ks * year) + (1 - phi) * exp(kg * year))) + int_step(year <= 0) * (b0 * exp(kg * year))
sigma ~ log(tau) + log(eta)
tau ~ 1
b0 ~ exp(lb0)
phi ~ inv_logit(tphi)
ks ~ exp(lks)
kg ~ exp(lkg)
lb0 ~ 1 + (1 | id)
tphi ~ 1 + (1 | id)
lks ~ 1 + (1 | id)
lkg ~ 1 + (1 | id)
Data: df (Number of observations: 4099)
Draws: 4 chains, each with iter = 3000; warmup = 2000; thin = 1;
total post-warmup draws = 4000
Multilevel Hyperparameters:
~id (Number of levels: 701)
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(lb0_Intercept) 0.58 0.02 0.55 0.61 1.01 382 606
sd(tphi_Intercept) 2.12 0.18 1.78 2.49 1.01 611 1216
sd(lks_Intercept) 2.16 0.14 1.90 2.46 1.00 520 1282
sd(lkg_Intercept) 1.18 0.09 1.01 1.36 1.00 899 1766
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
tau_Intercept 0.15 0.00 0.14 0.15 1.00 1611 2351
lb0_Intercept 3.63 0.02 3.59 3.68 1.03 221 547
tphi_Intercept -0.14 0.22 -0.57 0.27 1.00 424 839
lks_Intercept -0.62 0.10 -0.81 -0.43 1.00 1170 1868
lkg_Intercept -1.15 0.14 -1.43 -0.89 1.00 699 1488
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
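The population-level parameters theta_b0, theta_ks and theta_kg are computed in the post-processing code (see the source code at the end of this page) as exp(mu + omega^2 / 2), the mean of a log-normal distribution with log-scale mean mu (the intercept) and log-scale standard deviation omega (the random-effect SD); the coefficients of variation use sqrt(exp(omega^2) - 1). A small self-contained check of both formulas by simulation, plugging in the printed posterior means of lb0_Intercept (3.63) and sd(lb0_Intercept) (0.58) as illustrative inputs:

```r
# Mean and coefficient of variation (CV) of a log-normal distribution with
# log-scale mean `mu` and log-scale standard deviation `omega`
lognormal_mean <- function(mu, omega) exp(mu + omega^2 / 2)
lognormal_cv <- function(omega) sqrt(exp(omega^2) - 1)

# Illustrative inputs: printed posterior means of lb0_Intercept and
# sd(lb0_Intercept); not a substitute for transforming each posterior draw
mu_lb0 <- 3.63
omega_lb0 <- 0.58

# Monte Carlo check of both formulas
set.seed(878)
sim_b0 <- rlnorm(1e6, meanlog = mu_lb0, sdlog = omega_lb0)

rbind(
  mean = c(formula = lognormal_mean(mu_lb0, omega_lb0), simulated = mean(sim_b0)),
  cv = c(formula = lognormal_cv(omega_lb0), simulated = sd(sim_b0) / mean(sim_b0))
)
```

Plugging in posterior means gives about 44.6, close to but not identical with the posterior mean of theta_b0 in the table below, because the table averages the transformation over posterior draws. Note also that theta_phi is computed as plogis(b_tphi_Intercept), which is the median rather than the mean of the logit-normal distribution across patients.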
Let’s first look at the population level parameters:
```r
gsf_pop_params <- c("theta_b0", "theta_ks", "theta_kg", "theta_phi", "sigma")
mcmc_trace(post_df, pars = gsf_pop_params)
```
```r
mcmc_pairs(
  post_df,
  pars = gsf_pop_params,
  off_diag_args = list(size = 1, alpha = 0.1)
)
```
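The trace plots look reasonable: the chains seem to have converged, and the pairs plot shows no strong correlations between the parameters. Note, however, the 41 divergent transitions reported by the sampler and the comparatively low bulk effective sample size for lb0_Intercept (221, with an R-hat of 1.03) in the summary above; increasing adapt_delta, as the warning suggests, may help. Let's check the table: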
Warning: Dropping 'draws_df' class as required metadata was removed.
post_sum
| variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
|:---|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| theta_b0 | 44.766 | 44.746 | 1.057 | 1.073 | 43.074 | 46.528 | 1.007 | 179.361 | 378.906 |
| theta_ks | 5.857 | 5.444 | 1.924 | 1.582 | 3.589 | 9.560 | 1.003 | 384.622 | 905.388 |
| theta_kg | 0.640 | 0.637 | 0.061 | 0.062 | 0.541 | 0.745 | 1.000 | 685.379 | 1,885.264 |
| theta_phi | 0.467 | 0.466 | 0.055 | 0.058 | 0.376 | 0.555 | 1.003 | 419.761 | 831.991 |
| omega_0 | 0.578 | 0.578 | 0.016 | 0.017 | 0.553 | 0.606 | 1.001 | 376.483 | 628.807 |
| omega_s | 2.159 | 2.155 | 0.145 | 0.145 | 1.929 | 2.412 | 1.002 | 521.866 | 1,265.639 |
| omega_g | 1.177 | 1.174 | 0.090 | 0.090 | 1.037 | 1.333 | 1.002 | 868.075 | 1,746.640 |
| omega_phi | 2.117 | 2.108 | 0.182 | 0.184 | 1.833 | 2.423 | 1.003 | 603.806 | 1,215.066 |
| sigma | 0.147 | 0.147 | 0.002 | 0.002 | 0.143 | 0.150 | 1.000 | 1,575.916 | 2,168.010 |
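So \(\theta_{\phi}\) is estimated at about 0.47 (90% credible interval 0.376 to 0.555), i.e., close to 0.5. The other parameters are similar to those of the previous Stein-Fojo model, but we see, for example, a larger \(\theta_{k_s}\) (posterior mean 5.857).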
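\(\theta_{\phi}\) describes a typical patient, but the summary also shows a large between-patient standard deviation on the logit scale (sd(tphi_Intercept) of about 2.12). A plug-in sketch of what that implies for the patient-level weights, ignoring the posterior uncertainty in both parameters:

```r
# Plug-in values from the brms summary (logit scale): population
# intercept tphi_Intercept and between-patient SD sd(tphi_Intercept)
mu_tphi <- -0.14
omega_tphi <- 2.12

# Simulate the shrinkage weight for many hypothetical new patients
set.seed(878)
patient_phi <- plogis(rnorm(1e5, mean = mu_tphi, sd = omega_tphi))

quantile(patient_phi, c(0.05, 0.25, 0.5, 0.75, 0.95))
plogis(mu_tphi) # theta_phi as defined in this appendix: the median
```

Under this plug-in approximation, the central 90% range of patient-level \(\phi\) spans roughly 0.03 to 0.97, so the population value near 0.5 summarizes very heterogeneous individual trajectories, from almost pure growth to almost pure shrinkage.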
Observation vs model fit
We can now compare the model fit to the observations. Let’s do this for the first 20 patients again:
```r
pt_subset <- as.character(1:20)
df_subset <- df |>
  filter(id %in% pt_subset)

df_sim <- df_subset |>
  data_grid(
    id = pt_subset,
    year = seq_range(year, 101)
  ) |>
  add_epred_draws(fit) |>
  median_qi()

df_sim |>
  ggplot(aes(x = year, y = sld)) +
  facet_wrap(~ id) +
  geom_ribbon(
    aes(y = .epred, ymin = .lower, ymax = .upper),
    alpha = 0.3,
    fill = "deepskyblue"
  ) +
  geom_line(aes(y = .epred), color = "deepskyblue") +
  geom_point(data = df_subset, color = "tomato") +
  coord_cartesian(ylim = range(df_subset$sld)) +
  scale_fill_brewer(palette = "Greys") +
  labs(title = "GSF model fit")
```
This also looks good: the posterior median curves and their 95% credible bands appear to capture the observed data well for these patients.
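In the plotting code, add_epred_draws() produces one expected SLD value per posterior draw and grid point, and median_qi() collapses these to the median with a 95% quantile interval (its default width). A base-R sketch of that summary step, using made-up draws for a single hypothetical patient rather than draws from fit:

```r
# Hypothetical posterior draws for a single patient (made-up values,
# not draws from `fit`), one element per draw
set.seed(878)
n_draws <- 4000
draw_b0 <- rlnorm(n_draws, meanlog = log(45), sdlog = 0.05)
draw_ks <- rlnorm(n_draws, meanlog = log(1.5), sdlog = 0.1)
draw_kg <- rlnorm(n_draws, meanlog = log(0.5), sdlog = 0.1)
draw_phi <- plogis(rnorm(n_draws, mean = 0, sd = 0.2))

grid_year <- seq(-0.2, 2, length.out = 101)

# Expected SLD: one row per draw, one column per grid time point.
# Same branching as the brms formula: growth only for year <= 0.
sld_draws <- sapply(grid_year, function(t) {
  if (t <= 0) {
    draw_b0 * exp(draw_kg * t)
  } else {
    draw_b0 * (draw_phi * exp(-draw_ks * t) + (1 - draw_phi) * exp(draw_kg * t))
  }
})

# Median and 95% quantile interval per time point, analogous to median_qi()
epred_summary <- data.frame(
  year = grid_year,
  .epred = apply(sld_draws, 2, median),
  .lower = apply(sld_draws, 2, quantile, probs = 0.025),
  .upper = apply(sld_draws, 2, quantile, probs = 0.975)
)
head(epred_summary)
```

Summarizing per time point after computing the curve for every draw, rather than plugging posterior summaries into the curve, is what keeps the bands faithful to the nonlinear model.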
With jmpost
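This model can also be fit with the jmpost package. The corresponding function is LongitudinalGSF. The statistical model is specified in the jmpost vignette on the statistical specification (https://genentech.github.io/jmpost/main/articles/statistical-specification.html#generalized-stein-fojo-gsf-model).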
Homework: Implement the generalized Stein-Fojo model with jmpost and compare the results with the brms implementation.
Source Code
---
title: "3. Generalized Stein-Fojo model"
author:
  - Daniel Sabanés Bové
  - Francois Mercier
date: last-modified
editor_options:
  chunk_output_type: inline
format:
  html:
    code-fold: show
    html-math-method: mathjax
cache: true
---

This appendix shows how the generalized Stein-Fojo model can be implemented in a Bayesian framework using the `brms` package in R.

## Setup and load data

{{< include _setup_and_load.qmd >}}

{{< include _load_data.qmd >}}

## Generalized Stein-Fojo model

Here we have an additional parameter $\phi$, which is the weight for the shrinkage in the double exponential model. The model is then:

$$
y^{*}(t_{ij}) = \psi_{b_{0}i} \{ \psi_{\phi i} \exp(- \psi_{k_{s}i} \cdot t_{ij}) + (1 - \psi_{\phi i}) \exp(\psi_{k_{g}i} \cdot t_{ij}) \}
$$

for positive times $t_{ij}$. Again, if the time $t$ is negative, i.e. the treatment has not started yet, then it is reasonable to assume that the tumor cannot shrink yet. That is, we have then $\phi_i = 0$. Therefore, the final model for the mean SLD is:

$$
y^{*}(t_{ij}) = \begin{cases}
\psi_{b_{0}i} \exp(\psi_{k_{g}i} \cdot t_{ij}) & \text{if } t_{ij} < 0 \\
\psi_{b_{0}i} \{ \psi_{\phi i} \exp(- \psi_{k_{s}i} \cdot t_{ij}) + (1 - \psi_{\phi i}) \exp(\psi_{k_{g}i} \cdot t_{ij}) \} & \text{if } t_{ij} \geq 0
\end{cases}
$$

In terms of likelihood and priors, we can use the same assumptions as in the previous model. The only difference is that we have to model the $\phi$ parameter. We can use a logit-normal distribution for this parameter. This is a normal distribution on the logit scale, which is then transformed to the unit interval.

$$
\psi_{\phi i} \sim \text{LogitNormal}(\text{logit}(0.5) = 0, 0.5)
$$

## Fit model

We can now fit the model using `brms`.
The structure is determined by the model formula:

```{r}
#| label: specify_gsf_model
formula <- bf(sld ~ ystar, nl = TRUE) +
  # Define the mean for the likelihood
  nlf(
    ystar ~
      int_step(year > 0) * (b0 * (phi * exp(-ks * year) + (1 - phi) * exp(kg * year))) +
      int_step(year <= 0) * (b0 * exp(kg * year))
  ) +
  # As before:
  nlf(sigma ~ log(tau) + log(ystar)) +
  lf(tau ~ 1) +
  # Define nonlinear parameter transformations:
  nlf(b0 ~ exp(lb0)) +
  nlf(phi ~ inv_logit(tphi)) +
  nlf(ks ~ exp(lks)) +
  nlf(kg ~ exp(lkg)) +
  # Define random effect structure:
  lf(lb0 ~ 1 + (1 | id)) +
  lf(tphi ~ 1 + (1 | id)) +
  lf(lks ~ 1 + (1 | id)) +
  lf(lkg ~ 1 + (1 | id))

# Define the priors
priors <- c(
  prior(normal(log(65), 1), nlpar = "lb0"),
  prior(normal(log(0.52), 0.1), nlpar = "lks"),
  prior(normal(log(1.04), 1), nlpar = "lkg"),
  prior(normal(0, 0.5), nlpar = "tphi"),
  prior(normal(0, 3), lb = 0, nlpar = "lb0", class = "sd"),
  prior(normal(0, 3), lb = 0, nlpar = "lks", class = "sd"),
  prior(normal(0, 3), lb = 0, nlpar = "lkg", class = "sd"),
  prior(student_t(3, 0, 22.2), lb = 0, nlpar = "tphi", class = "sd"),
  prior(normal(0, 3), lb = 0, nlpar = "tau")
)

# Initial values to avoid problems at the beginning
n_patients <- nlevels(df$id)
inits <- list(
  b_lb0 = array(3.61),
  b_lks = array(-1.25),
  b_lkg = array(-1.33),
  b_tphi = array(0),
  sd_1 = array(0.58),
  sd_2 = array(1.6),
  sd_3 = array(0.994),
  sd_4 = array(0.1),
  b_tau = array(0.161),
  z_1 = matrix(0, nrow = 1, ncol = n_patients),
  z_2 = matrix(0, nrow = 1, ncol = n_patients),
  z_3 = matrix(0, nrow = 1, ncol = n_patients),
  z_4 = matrix(0, nrow = 1, ncol = n_patients)
)

# Fit the model
save_file <- here("session-tgi/gsf1.RData")
if (file.exists(save_file)) {
  load(save_file)
} else {
  fit <- brm(
    formula = formula,
    data = df,
    prior = priors,
    family = gaussian(),
    init = rep(list(inits), CHAINS),
    chains = CHAINS,
    iter = WARMUP + ITER,
    warmup = WARMUP,
    seed = BAYES.SEED,
    refresh = REFRESH
  )
  save(fit, file = save_file)
}

# Summarize the fit
summary(fit)
```

In total this took 76 minutes on my laptop.
## Parameter estimates

```{r}
#| label: gsf_post_processing
post_df <- as_draws_df(fit)
head(names(post_df), 10)

post_df <- post_df |>
  mutate(
    theta_b0 = exp(b_lb0_Intercept + sd_id__lb0_Intercept^2 / 2),
    theta_ks = exp(b_lks_Intercept + sd_id__lks_Intercept^2 / 2),
    theta_kg = exp(b_lkg_Intercept + sd_id__lkg_Intercept^2 / 2),
    theta_phi = plogis(b_tphi_Intercept),
    omega_0 = sd_id__lb0_Intercept,
    omega_s = sd_id__lks_Intercept,
    omega_g = sd_id__lkg_Intercept,
    omega_phi = sd_id__tphi_Intercept,
    cv_0 = sqrt(exp(sd_id__lb0_Intercept^2) - 1),
    cv_s = sqrt(exp(sd_id__lks_Intercept^2) - 1),
    cv_g = sqrt(exp(sd_id__lkg_Intercept^2) - 1),
    sigma = b_tau_Intercept
  )
```

Let's first look at the population level parameters:

```{r}
#| label: gsf_pop_params
gsf_pop_params <- c("theta_b0", "theta_ks", "theta_kg", "theta_phi", "sigma")

mcmc_trace(post_df, pars = gsf_pop_params)

mcmc_pairs(
  post_df,
  pars = gsf_pop_params,
  off_diag_args = list(size = 1, alpha = 0.1)
)
```

The trace plots look good. The chains seem to have converged and the pairs plot shows no strong correlations between the parameters. Let's check the table:

```{r}
#| label: gsf_post_summary
post_sum <- post_df |>
  select(theta_b0, theta_ks, theta_kg, theta_phi, omega_0, omega_s, omega_g, omega_phi, sigma) |>
  summarize_draws() |>
  gt() |>
  fmt_number(decimals = 3)
post_sum
```

So $\theta_{\phi}$ is estimated around 0.5. The other parameters are similar to the previous Stein-Fojo model, but we see a larger $\theta_{k_s}$ e.g.

## Observation vs model fit

We can now compare the model fit to the observations.
Let's do this for the first 20 patients again:

```{r}
#| label: gsf_model_fit
pt_subset <- as.character(1:20)
df_subset <- df |>
  filter(id %in% pt_subset)

df_sim <- df_subset |>
  data_grid(
    id = pt_subset,
    year = seq_range(year, 101)
  ) |>
  add_epred_draws(fit) |>
  median_qi()

df_sim |>
  ggplot(aes(x = year, y = sld)) +
  facet_wrap(~ id) +
  geom_ribbon(
    aes(y = .epred, ymin = .lower, ymax = .upper),
    alpha = 0.3,
    fill = "deepskyblue"
  ) +
  geom_line(aes(y = .epred), color = "deepskyblue") +
  geom_point(data = df_subset, color = "tomato") +
  coord_cartesian(ylim = range(df_subset$sld)) +
  scale_fill_brewer(palette = "Greys") +
  labs(title = "GSF model fit")
```

This also looks good. The model seems to capture the data well.

## With `jmpost`

This model can also be fit with the `jmpost` package. The corresponding function is `LongitudinalGSF`. The statistical model is specified in the vignette [here](https://genentech.github.io/jmpost/main/articles/statistical-specification.html#generalized-stein-fojo-gsf-model).

Homework: Implement the generalized Stein-Fojo model with `jmpost` and compare the results with the `brms` implementation.