Setting priors on hierarchical multivariate models

Keep it realistic.

UdeS
stan
Author

Andrew MacDonald

Published

November 11, 2022

library(targets)
library(ggplot2)
library(tidyverse)
library(tidybayes)

What is this post?

# Small demonstration dataset: a 3-row tibble (see `row.names = c(NA, -3L)`
# and the "tbl_df" class at the end of the call), written as a literal
# `structure()` call -- presumably pasted from `dput()` output; confirm
# against the real data source before changing any value.
# Columns referenced by the model formulas in this file: `dponte`, `csize`,
# `noisenvol` (responses), `age_morpho_indic`, `general_mean_dponte`,
# `difference_general_dponte`, `general_mean_csize`,
# `difference_general_csize` (predictors) and `annee`, `ferme`, `idF1`
# (grouping factors).
demo_data <- structure(list(csize = c(5, 5, 5), 
                            age_morpho_indic = c(1, 1, 
1), mass = c(20.64, 25.58, 20.2), log_mass = c(3.02723094061336, 
3.2418107961507, 3.00568260440716), dponte = c(134, 135, 136), 
    annee = c(2005, 2006, 2005), ferme = c(9, 9, 9), idF1 = c(188197003, 
    188197004, 188197004), general_csize = c(8.972, 9.344, 8.972
    ), coldsnap_csize = c(3.9, 7.15, 3.9), general_dponte = c(8.704651163, 
    10.48604651, 8.704651163), coldsnap_dponte = c(2.3, 4.7, 
    2.3), density_TRSW = c(8, 9, 8), density_HOSP = c(0, 0, 0
    ), paysa_ext = c(0.224174146, 0.223943415, 0.224166189), 
    general_mean_csize = c(-0.482520728753587, -0.343044978530785, 
    -0.343044978530785), difference_general_csize = c(0, 0.186, 
    -0.186), coldsnap_mean_csize = c(-1.09248466640163, -0.156587649057671, 
    -0.156587649057671), difference_coldsnap_csize = c(0, 1.625, 
    -1.625), general_mean_dponte = c(-2.54912170305933, -1.57828050221117, 
    -1.57828050221117), difference_general_dponte = c(0, 0.890698, 
    -0.890698), coldsnap_mean_dponte = c(-1.63910623622377, -0.951828417972457, 
    -0.951828417972457), difference_coldsnap_dponte = c(0, 1.2, 
    -1.2), density_TRSW_mean = c(0.248353096875404, 0.484967912059555, 
    0.484967912059555), difference_density_TRSW = c(0, 0.5, -0.5
    ), paysa_ext_mean = c(-0.306307397416911, -0.308050462721618, 
    -0.308050462721618), difference_paysa_ext = c(-0.000913607068423686, 
    -0.0152243604543345, 0.0133971463174872), density_HOSP_mean = c(-0.525930668145508, 
    -0.525930668145508, -0.525930668145508), difference_density_HOSP = c(0, 
    0, 0), noisenvol = c(5, 5, 0)), row.names = c(NA, -3L), class = c("tbl_df", 
"tbl", "data.frame"))
library(brms)
# Three sub-models that are combined into one multivariate brms formula.
# Every sub-model has varying intercepts for year (`annee`) and farm
# (`ferme`). The `|f|` tag on the `idF1` terms uses the same ID in all three
# formulas, which is how brms is asked to model the female-level effects as
# correlated across responses.
# `center = FALSE` turns off centering of the population-level design matrix;
# presumably this is so the intercept is handled like any other `b`
# coefficient when setting priors -- confirm against `get_prior()` output.

# Laying date: Gaussian, with a female-specific slope for the within-female
# deviation `difference_general_dponte`.
dponte_model_bf_2 <- bf(
  dponte ~ 1 + age_morpho_indic + general_mean_dponte +
    difference_general_dponte + (1 | annee) + (1 | ferme) +
    (1 + difference_general_dponte | f | idF1),
  family = gaussian(),
  center = FALSE
)

# Clutch size: Poisson, with a female-specific slope for
# `difference_general_csize`.
csize_model_bf_2 <- bf(
  csize ~ 1 + age_morpho_indic + general_mean_csize +
    difference_general_csize + (1 | annee) + (1 | ferme) +
    (1 + difference_general_csize | f | idF1),
  family = poisson(),
  center = FALSE
)

# `noisenvol`: Poisson, intercept-only apart from the grouping terms.
sucess_model_bf_2 <- bf(
  noisenvol ~ 1 + (1 | f | idF1) + (1 | annee) + (1 | ferme),
  family = poisson(),
  center = FALSE
)

# combine all three into one model
full_model_bf <- dponte_model_bf_2 + csize_model_bf_2 + sucess_model_bf_2

# List every parameter class that can receive a prior for this model/data.
get_prior(full_model_bf, data = demo_data)

# First attempt at a prior set for the multivariate model: the same generic
# priors are used for every response. No `class = "Intercept"` priors are set
# for any response.
full_model_prior <- c(
  # correlations among the female-level (idF1) effects
  prior(lkj(3), group = "idF1", class = "cor"),

  # clutch size (csize)
  prior(normal(0, 1), resp = "csize", class = "b"),
  prior(exponential(1), resp = "csize", class = "sd", lb = 0),

  # laying date (dponte)
  prior(normal(0, 1), resp = "dponte", class = "b"),
  prior(exponential(1), resp = "dponte", class = "sd", lb = 0),
  prior(exponential(1), resp = "dponte", class = "sigma", lb = 0),

  # noisenvol: the formula has no slopes, so only group-level sds get a prior
  prior(exponential(1), resp = "noisenvol", class = "sd", lb = 0)
)

# Run the full model. `sample_prior = "only"` asks brms to draw from the
# priors alone, so the result is a prior predictive simulation rather than a
# fit to `demo_data`.
full_model_prior_predict <- brm(
  full_model_bf,
  data = demo_data,
  prior = full_model_prior,
  cores = 4,
  chains = 4,
  sample_prior = "only"
)

summary(full_model_prior_predict)
library(tidybayes)

# Prior predictive draws for the `noisenvol` response, attached to each row of
# `demo_data` as an rvar column named `.prediction`.
prior_predictions <- demo_data |>
  add_predicted_rvars(object = full_model_prior_predict, resp = "noisenvol")

prior_predictions |> glimpse()

# One half-eye per female (idF1). The x axis is clipped to [0, 1e5] with
# coord_cartesian(), which zooms the view without dropping draws.
prior_predictions |>
  select(idF1, noisenvol, .prediction) |>
  ungroup() |>
  ggplot(aes(y = idF1, xdist = .prediction)) +
  stat_halfeye() +
  coord_cartesian(xlim = c(0, 1e5))

Already we can see that this prior is probably far too wide!

# Second prior set: tighter priors for `noisenvol` and a laying-date
# intercept centred on day 138.
#
# The model formulas use `center = FALSE`, so each intercept is an ordinary
# population-level coefficient (class "b", coef "Intercept"). A prior given
# for the whole class "b" therefore applies to the intercept AND every slope.
# For `dponte` the day-of-year prior must be limited to the intercept;
# otherwise the slopes would also be centred on 138.
full_model_smaller_noisenvol_prior <- c(
  ## individual level correlations
  prior(lkj(3), class = "cor", group = "idF1"),
  ## clutch size model
  prior(normal(0, 1), class = "b", resp = "csize"),
  prior(exponential(1), lb = 0, class = "sd", resp = "csize"),
  ## laying date model
  # intercept only: laying date around day 138
  prior(normal(138, 5), class = "b", coef = "Intercept", resp = "dponte"),
  # remaining population-level slopes
  prior(normal(0, 1), class = "b", resp = "dponte"),
  prior(exponential(1), lb = 0, class = "sd", resp = "dponte"),
  prior(exponential(1), lb = 0, class = "sigma", resp = "dponte"),
  # fitness -- no slopes here, so class "b" is just the intercept
  prior(normal(1, 0.1), class = "b", resp = "noisenvol"),
  prior(exponential(4), lb = 0, class = "sd", resp = "noisenvol")
)



# Run the full model again with the revised prior set. As before,
# `sample_prior = "only"` asks brms to draw from the priors alone.
full_model_noisenvol_prior_predict <- brm(
  full_model_bf,
  data = demo_data,
  prior = full_model_smaller_noisenvol_prior,
  cores = 4,
  chains = 4,
  sample_prior = "only"
)

 

# Same per-female half-eye plot of `noisenvol` prior predictions, now from the
# model run with the revised priors; the x axis is zoomed to [0, 12].
demo_data |>
  add_predicted_rvars(
    resp = "noisenvol",
    object = full_model_noisenvol_prior_predict
  ) |>
  select(idF1, noisenvol, .prediction) |>
  ungroup() |>
  ggplot(mapping = aes(xdist = .prediction, y = idF1)) +
  stat_halfeye() +
  coord_cartesian(xlim = c(0, 12))
# Synthetic prediction grid for the laying-date sub-model, whose predictors
# are age_morpho_indic + general_mean_dponte + difference_general_dponte.

# Three observations of a single female: one value of the within-female
# deviation each, with age indicator 0 for the first and 1 afterwards.
one_female <- tibble(
  age_morpho_indic = c(0, 1, 1),
  difference_general_dponte = c(-3, 0, 3)
)

# Cross those observations with three females (a, b, c), each paired with her
# own mean value (-2, 0, 2), and give every row the same farm and year labels.
fake_data <- expand_grid(
  one_female,
  nesting(
    general_mean_dponte = c(-2, 0, 2),
    idF1 = letters[1:3]
  )
) |>
  arrange(general_mean_dponte) |>
  mutate(
    ferme = "f",
    annee = "y"
  )

# Draw three prior predictive datasets of `dponte` for the synthetic grid.
# `allow_new_levels = TRUE` is required because the grouping labels in
# `fake_data` ("a"/"b"/"c", "f", "y") do not occur in `demo_data`.
dponte_prior_pred <- add_predicted_draws(
  fake_data,
  object = full_model_noisenvol_prior_predict,
  ndraws = 3,
  resp = "dponte",
  allow_new_levels = TRUE
)

# One line per draw, one panel per female-level mean.
dponte_prior_pred |>
  ggplot(
    aes(
      x = difference_general_dponte,
      y = .prediction,
      group = .draw
    )
  ) +
  geom_line() +
  facet_wrap(~general_mean_dponte)