1. TGI-OS joint model minimal workflow with jmpost

Authors

Daniel Sabanés Bové

Francois Mercier

Published

2025-06-25

The purpose of this document is to show a minimal workflow for fitting a joint Stein-Fojo TGI + Weibull OS model using the jmpost package.

Setup and load data

Here we execute the R code from the setup and data preparation chapter; see that chapter for the full code.

Data object preparation

First we again prepare the data objects, starting with the subject level data:

subj_df <- os_data |>
    mutate(study = "OAK") |>
    select(study, id, arm)

subj_data <- DataSubject(
    data = subj_df,
    subject = "id",
    arm = "arm",
    study = "study"
)

Next we prepare the longitudinal data object.

long_df <- tumor_data |>
    select(id, year, sld)
long_data <- DataLongitudinal(
    data = long_df,
    formula = sld ~ year
)

Next we prepare the DataSurvival object, where we specify the covariates for the survival model:

surv_data <- DataSurvival(
    data = os_data,
    formula = Surv(os_time, os_event) ~ ecog + age + race + sex
)

Note that we don't include the log growth rate estimate log_kg_est here, because we are fitting a joint model: the log growth rate will instead enter later via the link in the JointModel specification. For now we also don't include the treatment arm covariate, because we assume Working Assumption (1) from the previous session, i.e. that the treatment effect is fully captured by the mediator (or link).

Now we can create the DataJoint object for the joint model:

joint_data <- DataJoint(
    subject = subj_data,
    longitudinal = long_data,
    survival = surv_data
)

Model specification

We specify the Stein-Fojo model for the TGI data, the Weibull model for the OS data, as well as the link, i.e. the log growth rate from the TGI model which shall influence the hazard in the OS model:

joint_mod <- JointModel(
    longitudinal = LongitudinalSteinFojo(
        mu_bsld = prior_normal(log(65), 1),
        mu_ks = prior_normal(log(0.52), 1),
        mu_kg = prior_normal(log(1.04), 1),
        omega_bsld = prior_normal(0, 3) |> set_limits(0, Inf),
        omega_ks = prior_normal(0, 3) |> set_limits(0, Inf),
        omega_kg = prior_normal(0, 3) |> set_limits(0, Inf),
        sigma = prior_normal(0, 3) |> set_limits(0, Inf)
    ),
    survival = SurvivalWeibullPH(
        lambda = prior_gamma(0.7, 1),
        gamma = prior_gamma(1.5, 1),
        beta = prior_normal(0, 20)
    ),
    link = linkGrowth(
        prior = prior_normal(0, 20)
    )
)

Here we use a normal prior with mean 0 and standard deviation 20 for the link coefficient, which is a very uninformative prior. This corresponds to the same prior used for the regression coefficients of the “fixed” covariates in SurvivalWeibullPH.
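
As a quick illustration of how uninformative this prior is, we can compute the implied 90% prior interval for the hazard ratio associated with a one-unit covariate change (a small sketch, plain R arithmetic):

# 90% prior interval on the hazard ratio scale under a normal(0, 20) prior
# for the log-hazard coefficient
exp(qnorm(c(0.05, 0.95), mean = 0, sd = 20))

This interval spans many orders of magnitude, so the prior puts essentially no constraint on the hazard ratio.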

For the other parameters, we keep the same prior as in the previous sessions on separate TGI and OS models.

Model fitting

Again we need to be careful with the automatic selection of initial values due to the large standard deviation on beta and the link coefficient. We therefore set the shrinkage option and check the initial values:

options("jmpost.prior_shrinkage" = 0.99)

initialValues(joint_mod, n_chains = CHAINS)
[[1]]
[[1]]$lm_sf_mu_bsld
[1] 4.150467

[[1]]$lm_sf_mu_ks
[1] -0.6527837

[[1]]$lm_sf_mu_kg
[1] 0.0296721

[[1]]$lm_sf_omega_bsld
[1] 0.001132921

[[1]]$lm_sf_omega_ks
[1] 0.01131233

[[1]]$lm_sf_omega_kg
[1] 0.03935367

[[1]]$lm_sf_sigma
[1] 0.006231285

[[1]]$lm_sf_eta_tilde_bsld
[1] -0.01215586

[[1]]$lm_sf_eta_tilde_ks
[1] -0.001196457

[[1]]$lm_sf_eta_tilde_kg
[1] -8.78613e-05

[[1]]$sm_weibull_ph_lambda
[1] 0.6964901

[[1]]$sm_weibull_ph_gamma
[1] 1.494504

[[1]]$beta_os_cov
[1] -0.1833369

[[1]]$link_growth
[1] 0.03838646


[[2]]
[[2]]$lm_sf_mu_bsld
[1] 4.176525

[[2]]$lm_sf_mu_ks
[1] -0.6589843

[[2]]$lm_sf_mu_kg
[1] 0.04359252

[[2]]$lm_sf_omega_bsld
[1] 0.006504328

[[2]]$lm_sf_omega_ks
[1] 0.01123342

[[2]]$lm_sf_omega_kg
[1] 0.002242244

[[2]]$lm_sf_sigma
[1] 0.02596012

[[2]]$lm_sf_eta_tilde_bsld
[1] 0.003312313

[[2]]$lm_sf_eta_tilde_ks
[1] 0.03004617

[[2]]$lm_sf_eta_tilde_kg
[1] 0.0009214603

[[2]]$sm_weibull_ph_lambda
[1] 0.6934962

[[2]]$sm_weibull_ph_gamma
[1] 1.495563

[[2]]$beta_os_cov
[1] -0.08520562

[[2]]$link_growth
[1] 0.2842043


[[3]]
[[3]]$lm_sf_mu_bsld
[1] 4.179447

[[3]]$lm_sf_mu_ks
[1] -0.6655289

[[3]]$lm_sf_mu_kg
[1] 0.05239981

[[3]]$lm_sf_omega_bsld
[1] 0.0142334

[[3]]$lm_sf_omega_ks
[1] 0.06144598

[[3]]$lm_sf_omega_kg
[1] 0.009828384

[[3]]$lm_sf_sigma
[1] 0.03141703

[[3]]$lm_sf_eta_tilde_bsld
[1] 0.007005259

[[3]]$lm_sf_eta_tilde_ks
[1] 0.01321017

[[3]]$lm_sf_eta_tilde_kg
[1] 0.01108995

[[3]]$sm_weibull_ph_lambda
[1] 0.701657

[[3]]$sm_weibull_ph_gamma
[1] 1.485266

[[3]]$beta_os_cov
[1] -0.3623778

[[3]]$link_growth
[1] -0.04616349


[[4]]
[[4]]$lm_sf_mu_bsld
[1] 4.182144

[[4]]$lm_sf_mu_ks
[1] -0.647393

[[4]]$lm_sf_mu_kg
[1] 0.04806538

[[4]]$lm_sf_omega_bsld
[1] 0.0006164391

[[4]]$lm_sf_omega_ks
[1] 0.001151999

[[4]]$lm_sf_omega_kg
[1] 0.02038285

[[4]]$lm_sf_sigma
[1] 0.02773427

[[4]]$lm_sf_eta_tilde_bsld
[1] 0.003735449

[[4]]$lm_sf_eta_tilde_ks
[1] -0.008627102

[[4]]$lm_sf_eta_tilde_kg
[1] -0.008762467

[[4]]$sm_weibull_ph_lambda
[1] 0.7051286

[[4]]$sm_weibull_ph_gamma
[1] 1.503295

[[4]]$beta_os_cov
[1] -0.1125564

[[4]]$link_growth
[1] -0.1922085

If we don’t do this, then it is very likely that some chains will diverge because of very unrealistic initial values for beta and/or the link coefficient.
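
Conceptually, the jmpost.prior_shrinkage option pulls the randomly drawn initial values towards the prior means. The following is only a rough sketch of that idea (not the exact jmpost internals), which explains why with a shrinkage of 0.99 the initial values above all sit very close to the prior means:

s <- 0.99                                  # shrinkage factor
prior_mean <- 0                            # e.g. the normal(0, 20) prior on beta
prior_draw <- rnorm(1, mean = prior_mean, sd = 20)
init <- s * prior_mean + (1 - s) * prior_draw
init                                       # close to 0, like the beta/link initial values above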

Now we can fit the model:

save_file <- here("session-jm/jm1.rds")
if (file.exists(save_file)) {
    joint_results <- readRDS(save_file)
} else {
    joint_results <- sampleStanModel(
        joint_mod,
        data = joint_data,
        iter_sampling = ITER,
        iter_warmup = WARMUP,
        chains = CHAINS,
        parallel_chains = CHAINS,
        thin = CHAINS,
        seed = BAYES.SEED,
        refresh = REFRESH
    )
    saveObject(joint_results, file = save_file)
}

As mentioned before, we can also get warnings here at the beginning of the chains' sampling process ("The current Metropolis proposal is about to be rejected …"). As long as this only happens at the beginning, and not later during sampling, it is not a cause for concern.

We note that the MCMC sampling process takes much longer here (roughly a factor of 10) than fitting the TGI or the OS data separately. This is due to the more complex likelihood calculations for the joint TGI-OS model.
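
If we want to quantify this, one option is to look at the per-chain warmup and sampling times. A small sketch, using the CmdStanMCMC conversion that we also use further below (note that for a fit reloaded from disk the timing information may not always be retained):

cmdstanr::as.CmdStanMCMC(joint_results)$time()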

Let’s check the convergence of the population parameters. If we don’t remember their names, we can query them as follows:

joint_results

   JointModelSamples Object with:
  
      # of samples per chain = 1000
      # of chains            = 4
  
      Variables:
          beta_os_cov[5]
          link_coefficients
          link_function_inputs[203, 3]
          link_growth
          lm_sf_eta_tilde_bsld[203]
          lm_sf_eta_tilde_kg[203]
          lm_sf_eta_tilde_ks[203]
          lm_sf_mu_bsld
          lm_sf_mu_kg[2]
          lm_sf_mu_ks[2]
          lm_sf_omega_bsld
          lm_sf_omega_kg[2]
          lm_sf_omega_ks[2]
          lm_sf_psi_bsld[203]
          lm_sf_psi_kg[203]
          lm_sf_psi_ks[203]
          lm_sf_sigma
          log_surv_fit_at_obs_times[203]
          long_obvs_log_lik[1093]
          lp__
          os_cov_contribution[203]
          os_subj_log_lik[203]
          pars_os[2]
          sm_weibull_ph_gamma
          sm_weibull_ph_lambda
          Ypred[1093] 
vars <- c(
    "lm_sf_mu_bsld",
    "lm_sf_mu_ks",
    "lm_sf_mu_kg",
    "lm_sf_sigma",
    "lm_sf_omega_bsld",
    "lm_sf_omega_ks",
    "lm_sf_omega_kg",
    "beta_os_cov",
    "sm_weibull_ph_gamma",
    "sm_weibull_ph_lambda",
    "link_growth"
)

mcmc_joint_results <- cmdstanr::as.CmdStanMCMC(joint_results)
mcmc_joint_results$summary(vars)
# A tibble: 19 × 10
   variable         mean   median      sd     mad      q5     q95  rhat ess_bulk
   <chr>           <dbl>    <dbl>   <dbl>   <dbl>   <dbl>   <dbl> <dbl>    <dbl>
 1 lm_sf_mu_bs…  3.76     3.76    0.0371  0.0355   3.69    3.82   1.02      192.
 2 lm_sf_mu_ks… -0.0188   0.00943 0.235   0.228   -0.445   0.339  1.00      462.
 3 lm_sf_mu_ks… -1.06    -1.03    0.330   0.314   -1.66   -0.591  1.01      559.
 4 lm_sf_mu_kg… -0.585   -0.576   0.139   0.146   -0.837  -0.375  0.999     500.
 5 lm_sf_mu_kg… -0.892   -0.890   0.139   0.138   -1.14   -0.656  1.00      552.
 6 lm_sf_sigma   0.129    0.129   0.00371 0.00366  0.123   0.135  1.00      840.
 7 lm_sf_omega…  0.530    0.529   0.0265  0.0256   0.488   0.576  1.01      340.
 8 lm_sf_omega…  0.936    0.913   0.178   0.169    0.680   1.27   1.01      409.
 9 lm_sf_omega…  1.36     1.33    0.252   0.247    0.989   1.81   1.00      566.
10 lm_sf_omega…  0.684    0.677   0.0812  0.0744   0.561   0.831  1.00      781.
11 lm_sf_omega…  0.956    0.952   0.0953  0.0973   0.810   1.12   1.01      585.
12 beta_os_cov…  0.842    0.829   0.234   0.229    0.464   1.24   1.00      812.
13 beta_os_cov…  0.00316  0.00296 0.00981 0.00996 -0.0128  0.0197 0.998     975.
14 beta_os_cov…  0.658    0.656   0.431   0.428   -0.0678  1.36   1.00      901.
15 beta_os_cov… -0.00443 -0.00251 0.237   0.247   -0.371   0.371  1.00      760.
16 beta_os_cov…  0.313    0.305   0.215   0.217   -0.0363  0.677  1.00      735.
17 sm_weibull_…  1.81     1.81    0.167   0.170    1.55    2.09   1.00      794.
18 sm_weibull_…  0.261    0.213   0.176   0.132    0.0733  0.635  0.998     805.
19 link_growth   0.880    0.861   0.235   0.242    0.525   1.29   1.00      571.
# ℹ 1 more variable: ess_tail <dbl>
draws_joint_results <- mcmc_joint_results$draws(vars)
mcmc_trace(draws_joint_results)

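Besides the visual check of the trace plots, we can also confirm convergence programmatically from the summary table, for example (a small sketch):

fit_summary <- mcmc_joint_results$summary(vars)
max(fit_summary$rhat)      # should be close to 1
min(fit_summary$ess_bulk)  # should be reasonably large
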
So this looks good.

Interpret covariate effects

In order to better see which of the survival model coefficients relate to which covariates, we can again rename them as follows:

os_cov_name_mapping <- function(surv_data) {
    surv_data_design <- as_stan_list(surv_data)$os_cov_design
    os_cov_names <- colnames(surv_data_design)
    old_coef_names <- as.character(glue::glue("beta_os_cov[{seq_along(os_cov_names)}]"))
    setNames(old_coef_names, os_cov_names)
}
os_cov_renaming <- os_cov_name_mapping(surv_data)
draws_joint_results <- do.call(
    rename_variables,
    c(list(draws_joint_results), os_cov_renaming)
)
mcmc_dens_overlay(draws_joint_results) +
    geom_vline(xintercept = 0, linetype = "dashed", color = "red")

If we compare this with the covariate effect estimates from the 2-stage OS model we did in the last session, then we can see:

  • the log growth rate effect (log_kg_est in the 2-stage model, link_growth here) is stronger, but also has larger uncertainty
  • the ecog effect is similar (clearly higher risk with ECOG 1)
  • the age effect is similar (no effect)
  • the race effect is similar (almost no effect)
  • the sex effect is similar (higher risk for males)

It is also interesting to look at the shrinkage and growth rate estimates from the Stein-Fojo model part: we see, for example, essentially no shrinkage in arm 1, which is the control arm, while we see strong shrinkage in arm 2, which is the Atezo arm.
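
To read the survival covariate effects directly on the hazard ratio scale, we can exponentiate the posterior draws of the renamed coefficients. A small sketch, assuming the posterior package (which provides rename_variables() used above) is available:

library(posterior)

# select the renamed covariate coefficients and transform to the hazard ratio scale
hr_draws <- exp(as_draws_matrix(
    subset_draws(draws_joint_results, variable = names(os_cov_renaming))
))
# posterior medians and 90% credible intervals of the hazard ratios
round(t(apply(hr_draws, 2, quantile, probs = c(0.05, 0.5, 0.95))), 2)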

SLD vs longitudinal model fit

Let’s first check the fit of the Stein-Fojo model to the SLD data.

The first step is to generate the predictions at the subject level. We can do this using the LongitudinalQuantities() function, which takes the MCMC results and the grid at which the predictions should be made. Here we use the GridObserved() function, which takes the IDs of the subjects for which the predictions should be made. Since each patient gets its own panel in the plot, we sample a small subset of patient IDs here as an example only. In a real application we could write a simple loop that processes batches of patients in sequence (see the sketch at the end of this subsection).

set.seed(521)
pt_subset <- as.character(sample(subj_df$id, 20))

tgi_fit_pred <- LongitudinalQuantities(
    joint_results,
    grid = GridObserved(subjects = pt_subset)
)

Note that here again a Stan program needs to be compiled, which can take some time (but only the first time, because the executable is cached). This is because we pass the posterior samples to a Stan program which then generates the quantities of interest, here the Stein-Fojo model fit for each patient at the observed time points.

Now we can plot the predictions:

autoplot(tgi_fit_pred) +
    labs(x = "Time (years)", y = "SLD (mm)")

We can see that the model fits the data well, with the estimated Stein-Fojo model curves closely following the observed values.
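
As mentioned above, in a real application we could process all patients in batches instead of a small sample. A minimal sketch of such a loop (hypothetical batch size of 20, reusing the same jmpost calls as above):

all_ids <- as.character(subj_df$id)
id_batches <- split(all_ids, ceiling(seq_along(all_ids) / 20))

# one fitted-curve plot per batch of 20 patients
batch_plots <- lapply(id_batches, function(ids) {
    pred <- LongitudinalQuantities(
        joint_results,
        grid = GridObserved(subjects = ids)
    )
    autoplot(pred) +
        labs(x = "Time (years)", y = "SLD (mm)")
})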

Kaplan-Meier vs survival model fit

Another useful plot displays the model-predicted survival function and overlays the non-parametric Kaplan-Meier estimate on it.

The first step consists of generating the survival predictions at the group level with the SurvivalQuantities() function. For this we now use the GridGrouped() function, which takes the time points at which the predictions should be made and the groups for which they should be made (as a list, where each element contains the IDs defining one group). This works the same way as in the previous session with the OS model.

time_grid <- seq(from = 0, to = max(os_data$os_time), length = 100)
os_surv_group_grid <- GridGrouped(
    times = time_grid,
    groups = with(
        subj_df,
        split(as.character(id), arm)
    )
)
os_surv_pred <- SurvivalQuantities(
    object = joint_results,
    grid = os_surv_group_grid,
    type = "surv"
)

Now we can use the autoplot() method:

autoplot(os_surv_pred, add_km = TRUE, add_wrap = FALSE)

We can see that the fit looks adequate, with the modelled survival functions closely following the Kaplan-Meier curves in each treatment group. Of note, this fit looks better than in the 2-stage TGI-OS model from the previous session.

Hazard and hazard ratio estimation

Similarly to the survival function estimation, we can also estimate the hazard function by treatment group.

os_hazard_pred <- SurvivalQuantities(
    object = joint_results,
    grid = os_surv_group_grid,
    type = "haz"
)

This can also be plotted using the autoplot() method:

autoplot(os_hazard_pred, add_wrap = FALSE)

We can already see here that this looks slightly different from the same plot for the 2-stage TGI-OS model:

  • The hazards are higher
  • The uncertainty is considerably larger

Now let’s look at the estimated hazard ratio:

os_hr_est <- os_hazard_pred |>
    as.data.frame() |>
    group_by(group, time) |>
    mutate(sample = row_number()) |>
    pivot_wider(names_from = group, values_from = values) |>
    mutate(hr = MPDL3280A / Docetaxel) |>
    group_by(time) |>
    summarize(
        mean = mean(hr, na.rm = TRUE),
        lower = quantile(hr, 0.05, na.rm = TRUE),
        upper = quantile(hr, 0.95, na.rm = TRUE)
    ) |>
    na.omit() # Omit the time = 0 which has NA
summary(os_hr_est)
      time              mean            lower            upper      
 Min.   :0.02276   Min.   :0.8703   Min.   :0.6966   Min.   :1.063  
 1st Qu.:0.58038   1st Qu.:0.8703   1st Qu.:0.6966   1st Qu.:1.063  
 Median :1.13801   Median :0.8703   Median :0.6966   Median :1.063  
 Mean   :1.13801   Mean   :0.8703   Mean   :0.6966   Mean   :1.063  
 3rd Qu.:1.69563   3rd Qu.:0.8703   3rd Qu.:0.6966   3rd Qu.:1.063  
 Max.   :2.25325   Max.   :0.8703   Max.   :0.6966   Max.   :1.063  

Also here the hazard ratio is indeed constant over time, as it was in the 2-stage TGI-OS model. This is because the link between the longitudinal and the survival model is the log growth rate of the Stein-Fojo model, which is constant over time. For time-varying link functions, e.g. the derivative of the longitudinal model, the hazard ratio could change over time.

So here the estimated hazard ratio is 0.87 with a 90% credible interval of 0.70 to 1.06. This estimate is slightly lower than in the 2-stage TGI-OS model, representing a slightly stronger effect. However, because of the larger uncertainty, which comes from propagating the uncertainty of the log growth rate into the OS model, the 90% credible interval now includes 1, in contrast to the 2-stage TGI-OS model.
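
To visualize this, we can plot the estimated hazard ratio over time together with its 90% credible band (a small sketch using the os_hr_est data frame from above); here it is just a flat line because of the constant growth link:

ggplot(os_hr_est, aes(x = time, y = mean)) +
    geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.2) +
    geom_line() +
    geom_hline(yintercept = 1, linetype = "dashed", color = "red") +
    labs(x = "Time (years)", y = "Hazard ratio (MPDL3280A vs Docetaxel)")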

Alternative covariate specification

In the previous OS session we tried to include the treatment arm as a direct covariate. We can also try this here.

surv_data_with_arm <- DataSurvival(
    data = os_data,
    # Here we add the arm covariate:
    formula = update(surv_data@formula, . ~ . + arm)
)
joint_data_with_arm <- DataJoint(
    subject = subj_data,
    longitudinal = long_data,
    survival = surv_data_with_arm
)

save_file <- here("session-jm/jm2.rds")
if (file.exists(save_file)) {
    joint_results_with_arm <- readRDS(save_file)
} else {
    joint_results_with_arm <- sampleStanModel(
        joint_mod,
        data = joint_data_with_arm,
        iter_sampling = ITER,
        iter_warmup = WARMUP,
        chains = CHAINS,
        parallel_chains = CHAINS,
        thin = CHAINS,
        seed = BAYES.SEED,
        refresh = REFRESH
    )
    saveObject(joint_results_with_arm, file = save_file)
}

mcmc_joint_arm_results <- cmdstanr::as.CmdStanMCMC(joint_results_with_arm)
mcmc_joint_arm_results$summary(vars)
# A tibble: 20 × 10
   variable         mean   median      sd     mad      q5     q95  rhat ess_bulk
   <chr>           <dbl>    <dbl>   <dbl>   <dbl>   <dbl>   <dbl> <dbl>    <dbl>
 1 lm_sf_mu_bs…  3.75     3.76    0.0373  0.0377   3.69    3.82   1.02      196.
 2 lm_sf_mu_ks… -0.120   -0.0763  0.274   0.269   -0.590   0.271  1.01      455.
 3 lm_sf_mu_ks… -1.00    -0.972   0.328   0.315   -1.59   -0.525  1.00      538.
 4 lm_sf_mu_kg… -0.643   -0.638   0.154   0.153   -0.908  -0.402  1.01      500.
 5 lm_sf_mu_kg… -0.868   -0.862   0.146   0.146   -1.12   -0.640  1.01      522.
 6 lm_sf_sigma   0.129    0.129   0.00356 0.00356  0.124   0.136  1.00      944.
 7 lm_sf_omega…  0.531    0.530   0.0279  0.0268   0.489   0.583  1.01      342.
 8 lm_sf_omega…  0.979    0.959   0.191   0.181    0.713   1.34   1.00      481.
 9 lm_sf_omega…  1.32     1.29    0.240   0.232    0.972   1.74   1.00      580.
10 lm_sf_omega…  0.687    0.680   0.0871  0.0896   0.560   0.844  1.00      911.
11 lm_sf_omega…  0.957    0.951   0.0951  0.0898   0.820   1.13   1.00      647.
12 beta_os_cov…  0.864    0.858   0.235   0.237    0.477   1.26   1.00     1028.
13 beta_os_cov…  0.00372  0.00329 0.00996 0.0104  -0.0125  0.0197 0.999     964.
14 beta_os_cov…  0.690    0.694   0.445   0.447   -0.0367  1.44   1.00      930.
15 beta_os_cov…  0.0296   0.0232  0.248   0.239   -0.371   0.439  0.999     906.
16 beta_os_cov…  0.310    0.307   0.216   0.218   -0.0299  0.681  0.999     958.
17 beta_os_cov… -0.207   -0.204   0.239   0.229   -0.606   0.176  0.999     767.
18 sm_weibull_…  1.82     1.81    0.163   0.160    1.56    2.10   1.01      836.
19 sm_weibull_…  0.272    0.228   0.193   0.148    0.0755  0.608  0.999    1000.
20 link_growth   0.870    0.862   0.239   0.239    0.499   1.27   1.01      617.
# ℹ 1 more variable: ess_tail <dbl>
draws_joint_arm_results <- mcmc_joint_arm_results$draws(vars)
mcmc_trace(draws_joint_arm_results)

We can easily plot the survival functions and compare them with the Kaplan-Meier curves of the treatment arms, because we can reuse the above os_surv_group_grid:

joint_mod_with_arm_os_pred <- SurvivalQuantities(
    object = joint_results_with_arm,
    grid = os_surv_group_grid,
    type = "surv"
)
autoplot(joint_mod_with_arm_os_pred, add_km = TRUE, add_wrap = FALSE)

Model comparison

We can again use the Brier score to compare the survival parts of the three different joint models. The Brier score is a measure of the mean squared difference between the predicted survival probability and the actual survival status: the lower the Brier score, the better the model.
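
To make the definition concrete, here is a toy illustration (not the jmpost implementation): at a single time point t, the Brier score is the mean squared difference between the predicted survival probability at t and the observed survival status at t (1 = still alive, 0 = dead), with censoring typically handled via inverse-probability-of-censoring weights (omitted in this sketch):

toy_pred_surv  <- c(0.9, 0.7, 0.4)  # predicted P(T > t) for three patients
toy_alive_at_t <- c(1, 0, 0)        # observed status at time t
mean((toy_alive_at_t - toy_pred_surv)^2)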

To calculate it, we need to use the GridFixed input for SurvivalQuantities(). This is because the Brier score is calculated at fixed time points across all patients, and not at the observed time points of specific patients. Because we don’t specify patient IDs, the quantities are generated for all patients.

os_surv_fixed <- SurvivalQuantities(
    object = joint_results,
    grid = GridFixed(times = time_grid),
    type = "surv"
)
os_bs <- brierScore(os_surv_fixed)

Let’s first compare this with the second model, which includes the treatment arm as a direct covariate in the survival model:

os_surv_with_arm_fixed <- SurvivalQuantities(
    object = joint_results_with_arm,
    grid = GridFixed(times = time_grid),
    type = "surv"
)
os_with_arm_bs <- brierScore(os_surv_with_arm_fixed)

Now let's compare these with the third model, which uses the derivative link. Based on the hazard ratio results we suspect that the third model will perform worse, but let's have a look.

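The derivative-link model is assumed to have been fitted elsewhere in the workflow, producing the joint_results_dsld object used below. As a minimal sketch (assuming jmpost's linkDSLD() derivative link and the same priors and sampling settings as above), it could have been created like this:

# same longitudinal and survival specification as joint_mod, but with the derivative link
joint_mod_dsld <- JointModel(
    longitudinal = LongitudinalSteinFojo(
        mu_bsld = prior_normal(log(65), 1),
        mu_ks = prior_normal(log(0.52), 1),
        mu_kg = prior_normal(log(1.04), 1),
        omega_bsld = prior_normal(0, 3) |> set_limits(0, Inf),
        omega_ks = prior_normal(0, 3) |> set_limits(0, Inf),
        omega_kg = prior_normal(0, 3) |> set_limits(0, Inf),
        sigma = prior_normal(0, 3) |> set_limits(0, Inf)
    ),
    survival = SurvivalWeibullPH(
        lambda = prior_gamma(0.7, 1),
        gamma = prior_gamma(1.5, 1),
        beta = prior_normal(0, 20)
    ),
    link = linkDSLD(
        prior = prior_normal(0, 20)
    )
)
joint_results_dsld <- sampleStanModel(
    joint_mod_dsld,
    data = joint_data,
    iter_sampling = ITER,
    iter_warmup = WARMUP,
    chains = CHAINS,
    parallel_chains = CHAINS,
    thin = CHAINS,
    seed = BAYES.SEED,
    refresh = REFRESH
)
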
os_surv_dsld_fixed <- SurvivalQuantities(
    object = joint_results_dsld,
    grid = GridFixed(times = time_grid),
    type = "surv"
)
os_dsld_bs <- brierScore(os_surv_dsld_fixed)

We can then plot the pairwise differences between the three Brier scores over time to compare the models visually:

data.frame(
    time = time_grid,
    `1 - 2` = os_bs - os_with_arm_bs,
    `1 - 3` = os_bs - os_dsld_bs,
    `2 - 3` = os_with_arm_bs - os_dsld_bs,
    check.names = FALSE
) |>
    pivot_longer(
        cols = c("1 - 2", "1 - 3", "2 - 3"),
        names_to = "diff",
        values_to = "brier_score_diff"
    ) |>
    ggplot(aes(x = time, y = brier_score_diff, color = diff, group = diff)) +
    geom_line() +
    labs(y = "Brier score difference") +
    geom_hline(yintercept = 0, linetype = "dashed", color = "red")

As expected, we see that especially for times between 0.5 and 2 years the first and second models with the growth link perform better than the third model with the derivative link. On the other hand, there is almost no difference between the first model and the second model, which additionally includes the treatment arm covariate.

Currently, LOOIC does not yet work for joint models, because the log_lik is not included in the samples. There is an ongoing discussion about this on GitHub.
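
As a rough workaround (a sketch only, not a full LOOIC for the joint model), one could apply the loo package to the subject-level OS log-likelihood that is exposed in the samples (os_subj_log_lik above), keeping in mind that this ignores the longitudinal part of the likelihood:

# loo accepts an iterations x chains x observations array of log-likelihood values
os_loglik <- mcmc_joint_results$draws("os_subj_log_lik")
loo::loo(os_loglik)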