library(tidyverse)
library(brms)
library(tidybayes)
library(ggdag)
library(ggrepel)
library(patchwork)
# Plot stuff
# Six-color "Lakota" palette; later plots index into it as clrs[...]
clrs <- MetBrewer::met.brewer("Lakota", 6)
theme_set(theme_bw())

# Seed stuff
# BAYES_SEED is passed to every model fit below; the same value seeds R's RNG
BAYES_SEED <- 1234
set.seed(BAYES_SEED)
Video #5 code
Elemental confounds
\[ \newcommand{\ind}{\perp\!\!\!\perp} \newcommand{\notind}{\not\!\perp\!\!\!\perp} \]
The fork (confounders)
\[ X \leftarrow Z \rightarrow Y \]
\(Z\) connects \(X\) and \(Y\) so that \(Y \notind X\)
Simulated example
We can make some data to prove that they’re connected:
# Simulate a fork: Z is drawn first, then X and Y each depend only on Z
n <- 1000

fork_sim <- tibble(Z = rbinom(n, 1, prob = 0.5)) %>%
  # When Z is 0, there's a 10% chance of X or Y being 1
  # When Z is 1, there's a 90% chance of X or Y being 1
  mutate(X = rbinom(n, 1, prob = ((1 - Z) * 0.1) + (Z * 0.9)),
         Y = rbinom(n, 1, prob = ((1 - Z) * 0.1) + (Z * 0.9)))
# Cross-tabulate X and Y, ignoring Z
fork_sim %>%
  select(-Z) %>%
  table()
## Y
## X 0 1
## 0 390 101
## 1 82 427
# Correlation between X and Y across the whole sample
fork_sim %>%
  summarize(cor = cor(X, Y))
## # A tibble: 1 × 1
## cor
## <dbl>
## 1 0.634
But if we stratify by (or adjust for) \(Z\), we can see that \(Y \ind X \mid Z\):
# Cross-tabulate X and Y separately for each value of Z
fork_sim %>%
  select(X, Y, Z) %>%
  table()
## , , Z = 0
##
## Y
## X 0 1
## 0 388 56
## 1 36 2
##
## , , Z = 1
##
## Y
## X 0 1
## 0 2 45
## 1 46 425
# Correlation between X and Y within each value of Z
fork_sim %>%
  group_by(Z) %>%
  summarize(cor = cor(X, Y))
## # A tibble: 2 × 2
## Z cor
## <int> <dbl>
## 1 0 -0.0609
## 2 1 -0.0546
Here’s a continuous version too. When looking at all values of \(Z\), there’s a positive slope and relationship; when looking within each group, the relationship is 0 and flat.
# Continuous fork: X and Y are both normal with mean -1 (Z = 0) or 1 (Z = 1)
n <- 300

fork_sim_cont <- tibble(Z = rbinom(n, 1, 0.5)) %>%
  mutate(X = rnorm(n, 2 * Z - 1),
         Y = rnorm(n, 2 * Z - 1))

ggplot(fork_sim_cont, aes(x = X, y = Y, color = factor(Z))) +
  geom_point() +
  # One fitted line per value of Z...
  geom_smooth(method = "lm") +
  # ...and one pooled line that drops the color grouping
  geom_smooth(aes(color = NULL), method = "lm")
Waffle House example
data(WaffleDivorce, package = "rethinking")

# Add standardized copies of the three variables of interest:
# - *_scaled keeps the raw scale() result
# - *_z is the same standardization coerced to a plain numeric vector
WaffleDivorce <- WaffleDivorce %>%
  mutate(across(c(Marriage, Divorce, MedianAgeMarriage), ~scale(.), .names = "{col}_scaled")) %>%
  mutate(across(c(Marriage, Divorce, MedianAgeMarriage), ~as.numeric(scale(.)), .names = "{col}_z"))
What is the causal effect of marriage on divorce?
# DAG for the marriage/divorce question: age (z) affects both marriage (x) and
# divorce (y), and marriage affects divorce
height_sex_dag <- dagify(
  x ~ z,
  y ~ x + z,
  exposure = "x",
  outcome = "y",
  labels = c(x = "Marriage", y = "Divorce", z = "Age"),
  coords = list(x = c(x = 1, y = 3, z = 2),
                y = c(x = 1, y = 1, z = 2))) %>%
  tidy_dagitty() %>%
  node_status()

ggplot(height_sex_dag, aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_edges() +
  geom_dag_point(aes(color = status)) +
  geom_dag_text(aes(label = label), size = 3.5, color = "black") +
  scale_color_manual(values = clrs[c(1, 4)], guide = "none") +
  theme_dag()
We can look at the relationship of all three of these arrows
# Age -> Marriage: marriage rate against median age of marriage.
# Points are filled by the South indicator and labeled with Loc.
ggplot(
  data = WaffleDivorce,
  mapping = aes(x = MedianAgeMarriage, y = Marriage)
) +
  geom_point(
    mapping = aes(fill = factor(South)),
    size = 4, pch = 21, color = "white"
  ) +
  geom_smooth(method = "lm") +
  geom_text_repel(mapping = aes(label = Loc), max.overlaps = 2) +
  scale_fill_manual(values = clrs[c(1, 3)], guide = "none") +
  labs(x = "Median age of marriage", y = "Marriage rate")

# Age -> Divorce: divorce rate against median age of marriage.
ggplot(
  data = WaffleDivorce,
  mapping = aes(x = MedianAgeMarriage, y = Divorce)
) +
  geom_point(
    mapping = aes(fill = factor(South)),
    size = 4, pch = 21, color = "white"
  ) +
  geom_smooth(method = "lm") +
  geom_text_repel(mapping = aes(label = Loc), max.overlaps = 2) +
  scale_fill_manual(values = clrs[c(1, 3)], guide = "none") +
  labs(x = "Median age of marriage", y = "Divorce rate")

# Marriage -> Divorce: divorce rate against marriage rate.
ggplot(
  data = WaffleDivorce,
  mapping = aes(x = Marriage, y = Divorce)
) +
  geom_point(
    mapping = aes(fill = factor(South)),
    size = 4, pch = 21, color = "white"
  ) +
  geom_smooth(method = "lm") +
  geom_text_repel(mapping = aes(label = Loc), max.overlaps = 2) +
  scale_fill_manual(values = clrs[c(1, 3)], guide = "none") +
  labs(x = "Marriage rate", y = "Divorce rate")
How do we stratify by a continuous variable though? Regression!
\[ \begin{aligned} D_i &\sim \mathcal{N}(\mu_i, \sigma) \\ \mu_i &= \alpha + \beta_M M_i + \beta_A A_i \end{aligned} \]
Prior predictive simulation
\[ \begin{aligned} D_i &\sim \mathcal{N}(\mu_i, \sigma) \\ \mu_i &= \alpha + \beta_M M_i + \beta_A A_i \\ \\ \alpha &\sim \mathcal{N}(0, 0.2) \\ \beta_M &\sim \mathcal{N}(0, 0.5) \\ \beta_A &\sim \mathcal{N}(0, 0.5) \\ \sigma &\sim \operatorname{Exponential}(1) \end{aligned} \]
# Priors matching the model statement: alpha ~ N(0, 0.2), both slopes
# ~ N(0, 0.5), sigma ~ Exponential(1)
priors <- c(prior(normal(0, 0.2), class = Intercept),
            prior(normal(0, 0.5), class = b, coef = "Marriage_z"),
            prior(normal(0, 0.5), class = b, coef = "MedianAgeMarriage_z"),
            prior(exponential(1), class = sigma))
# Prior predictive model: sample_prior = "only" draws from the priors alone
marriage_divorce_prior_only <- brm(
  bf(Divorce_z ~ Marriage_z + MedianAgeMarriage_z),
  data = WaffleDivorce,
  family = gaussian(),
  prior = priors,
  sample_prior = "only",
  chains = 4, cores = 4, seed = BAYES_SEED
)
## Compiling Stan program...
## Start sampling
# 100 prior draws of the expected divorce rate across a grid of standardized
# ages, holding the standardized marriage rate at 0
draws_prior <- tibble(MedianAgeMarriage_z = seq(-2, 2, length.out = 100),
                      Marriage_z = 0) %>%
  add_epred_draws(marriage_divorce_prior_only, ndraws = 100)

# One line per draw
draws_prior %>%
  ggplot(aes(x = MedianAgeMarriage_z, y = .epred)) +
  geom_line(aes(group = .draw), alpha = 0.2) +
  labs(x = "Median age of marriage (standardized)",
       y = "Divorce rate (standardized)",
       caption = "Standardized marriage rate held constant at 0")
Actual model
Based on these models,
Once we know median age at marriage for a state, there is little or no additional predictive power in also knowing the rate of marriage in that state. (p. 134)
# Same priors as the prior predictive simulation
priors <- c(prior(normal(0, 0.2), class = Intercept),
            prior(normal(0, 0.5), class = b, coef = "Marriage_z"),
            prior(normal(0, 0.5), class = b, coef = "MedianAgeMarriage_z"),
            prior(exponential(1), class = sigma))
# Fit the model to the data: divorce rate regressed on marriage rate and age
# (all standardized)
marriage_divorce_actual <- brm(
  bf(Divorce_z ~ Marriage_z + MedianAgeMarriage_z),
  data = WaffleDivorce,
  family = gaussian(),
  prior = priors,
  chains = 4, cores = 4, seed = BAYES_SEED
)
## Compiling Stan program...
## recompiling to avoid crashing R session
## Start sampling
print(marriage_divorce_actual)
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: Divorce_z ~ Marriage_z + MedianAgeMarriage_z
## Data: WaffleDivorce (Number of observations: 50)
## Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup draws = 4000
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept -0.00 0.10 -0.20 0.20 1.00 3824 2317
## Marriage_z -0.07 0.15 -0.36 0.24 1.00 3087 2922
## MedianAgeMarriage_z -0.62 0.16 -0.92 -0.31 1.00 2894 2655
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 0.83 0.09 0.68 1.01 1.00 3630 2543
##
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# get_variables(marriage_divorce_actual)

# Posterior distributions of the coefficients and sigma
marriage_divorce_actual %>%
  gather_draws(b_Intercept, b_Marriage_z, b_MedianAgeMarriage_z, sigma) %>%
  ggplot(aes(x = .value, y = fct_rev(.variable))) +
  stat_halfeye() +
  coord_cartesian(xlim = c(-1, 1))
marriage_divorce_stan.stan
// Linear regression of standardized divorce rate on standardized marriage
// rate and standardized median age at marriage
data {
  int<lower=1> n;  // Observations
  vector[n] Divorce_z;  // Outcome: divorce rate
  vector[n] Marriage_z;  // "Treatment": marriage rate
  vector[n] MedianAgeMarriage_z;  // Confounder: age
}
parameters {
  real a;
  real bM;
  real bA;
  real<lower=0> sigma;
}
transformed parameters {
  vector[n] mu;
  mu = a + bM*Marriage_z + bA*MedianAgeMarriage_z;
}
model {
  // Likelihood
  Divorce_z ~ normal(mu, sigma);

  // Priors
  a ~ normal(0, 0.2);
  bM ~ normal(0, 0.5);
  bA ~ normal(0, 0.5);
  sigma ~ exponential(1);
}
generated quantities {
  // Posterior predictive draws of the outcome, one per observation
  vector[n] Divorce_z_rep;

  for (i in 1:n) {
    Divorce_z_rep[i] = normal_rng(mu[i], sigma);
  }
}
# Keep only the columns declared in the Stan data block; compose_data()
# converts them to a list for Stan
stan_data <- WaffleDivorce %>%
  select(Divorce_z, Marriage_z, MedianAgeMarriage_z) %>%
  compose_data()

model_marriage_divorce_stan <- rstan::sampling(
  object = marriage_divorce_stan,
  data = stan_data,
  iter = 2000, warmup = 1000, seed = BAYES_SEED, chains = 4, cores = 4
)

print(model_marriage_divorce_stan,
      pars = c("a", "bM", "bA", "sigma"))
## Inference for Stan model: 8cc6e06905b678b9147ee76469c82d06.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## a 0.00 0 0.10 -0.19 -0.07 0.00 0.07 0.20 3754 1
## bM -0.06 0 0.15 -0.36 -0.17 -0.06 0.04 0.24 2518 1
## bA -0.61 0 0.15 -0.91 -0.71 -0.61 -0.51 -0.31 2568 1
## sigma 0.83 0 0.09 0.68 0.76 0.82 0.88 1.02 3049 1
##
## Samples were drawn using NUTS(diag_e) at Wed Sep 21 11:06:44 2022.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
# get_variables(model_marriage_divorce_stan)

# Posterior distributions of the Stan model's parameters, in a fixed order
model_marriage_divorce_stan %>%
  gather_draws(a, bM, bA, sigma) %>%
  mutate(.variable = factor(.variable, levels = c("a", "bM", "bA", "sigma"))) %>%
  ggplot(aes(x = .value, y = fct_rev(.variable))) +
  stat_halfeye() +
  coord_cartesian(xlim = c(-1, 1))
Simulating causal effects
We can make counterfactual plots if we model the whole system, just like the “full luxury Bayes” model from video 4.
We want to know the causal effect of the marriage rate on the divorce rate, or:
\[ E(\text{Divorce rate} \mid \operatorname{do}(\text{Marriage rate})) \]
Here’s the model for the whole system:
\[ \begin{aligned} M_i &\sim \mathcal{N}(\nu_i, \tau) \\ D_i &\sim \mathcal{N}(\mu_i, \sigma) \\ \nu_i &= \alpha_M + \beta_{AM} A_i \\ \mu_i &= \alpha + \beta_M M_i + \beta_A A_i \\ \\ \alpha_M &\sim \mathcal{N}(0, 0.2) \\ \alpha &\sim \mathcal{N}(0, 0.2) \\ \beta_{AM} &\sim \mathcal{N}(0, 0.5) \\ \beta_M &\sim \mathcal{N}(0, 0.5) \\ \beta_A &\sim \mathcal{N}(0, 0.5) \\ \tau &\sim \operatorname{Exponential}(1) \\ \sigma &\sim \operatorname{Exponential}(1) \end{aligned} \]
# Priors for the two-outcome model; resp selects which submodel each applies to
priors <- c(prior(normal(0, 0.2), class = Intercept, resp = Divorcez),
            prior(normal(0, 0.5), class = b, coef = "Marriage_z", resp = Divorcez),
            prior(normal(0, 0.5), class = b, coef = "MedianAgeMarriage_z", resp = Divorcez),
            prior(exponential(1), class = sigma, resp = Divorcez),
            prior(normal(0, 0.2), class = Intercept, resp = Marriagez),
            prior(normal(0, 0.5), class = b, coef = "MedianAgeMarriage_z", resp = Marriagez),
            prior(exponential(1), class = sigma, resp = Marriagez))
# Model the whole DAG at once: one formula for divorce, one for marriage,
# with no residual correlation between the two outcomes
model_dag_full <- brm(
  bf(Divorce_z ~ Marriage_z + MedianAgeMarriage_z) +
    bf(Marriage_z ~ MedianAgeMarriage_z) +
    set_rescor(FALSE),
  data = WaffleDivorce,
  family = gaussian(),
  prior = priors,
  chains = 4, cores = 4, seed = BAYES_SEED
)
## Compiling Stan program...
## Start sampling
print(model_dag_full)
## Family: MV(gaussian, gaussian)
## Links: mu = identity; sigma = identity
## mu = identity; sigma = identity
## Formula: Divorce_z ~ Marriage_z + MedianAgeMarriage_z
## Marriage_z ~ MedianAgeMarriage_z
## Data: WaffleDivorce (Number of observations: 50)
## Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup draws = 4000
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat
## Divorcez_Intercept 0.00 0.10 -0.20 0.19 1.00
## Marriagez_Intercept 0.00 0.09 -0.18 0.19 1.00
## Divorcez_Marriage_z -0.06 0.16 -0.36 0.25 1.00
## Divorcez_MedianAgeMarriage_z -0.61 0.16 -0.91 -0.29 1.00
## Marriagez_MedianAgeMarriage_z -0.69 0.10 -0.89 -0.48 1.00
## Bulk_ESS Tail_ESS
## Divorcez_Intercept 5471 2861
## Marriagez_Intercept 5718 2909
## Divorcez_Marriage_z 3565 3160
## Divorcez_MedianAgeMarriage_z 3484 2710
## Marriagez_MedianAgeMarriage_z 4943 2673
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma_Divorcez 0.83 0.09 0.68 1.02 1.00 4322 2680
## sigma_Marriagez 0.71 0.07 0.58 0.87 1.00 4992 3220
##
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# Predicted divorce rate across a grid of ages, with marriage rate held at 0
sim_age_divorce <- tibble(MedianAgeMarriage_z = seq(from = -2, to = 2, length.out = 40),
                          Marriage_z = 0) %>%
  add_predicted_draws(model_dag_full, resp = "Divorcez")

ggplot(sim_age_divorce, aes(x = MedianAgeMarriage_z, y = .prediction)) +
  stat_lineribbon(.width = 0.89, color = clrs[5], fill = clrs[5], alpha = 0.5) +
  labs(title = "Total counterfactual effect of age on divorce rate",
       subtitle = "A → D in the DAG")
# Predicted marriage rate across a grid of ages
sim_age_marriage <- tibble(MedianAgeMarriage_z = seq(from = -2, to = 2, length.out = 40)) %>%
  add_predicted_draws(model_dag_full, resp = "Marriagez")

ggplot(sim_age_marriage, aes(x = MedianAgeMarriage_z, y = .prediction)) +
  stat_lineribbon(.width = 0.89, color = clrs[6], fill = clrs[6], alpha = 0.5) +
  labs(title = "Counterfactual effect of age on marriage rate",
       subtitle = "A → M in the DAG")
# Predicted divorce rate across a grid of marriage rates, with age held at 0.
# The response must be the divorce submodel: the marriage submodel
# (Marriage_z ~ MedianAgeMarriage_z) does not use Marriage_z as a predictor,
# so predicting "Marriagez" here would not vary with the x-axis at all.
sim_age_marriage_divorce <- tibble(Marriage_z = seq(from = -2, to = 2, length.out = 40),
                                   MedianAgeMarriage_z = 0) %>%
  add_predicted_draws(model_dag_full, resp = "Divorcez")

ggplot(sim_age_marriage_divorce, aes(x = Marriage_z, y = .prediction)) +
  stat_lineribbon(.width = 0.89, color = clrs[3], fill = clrs[3], alpha = 0.5) +
  labs(title = "Total counterfactual effect of marriage rate on divorce rate",
       subtitle = "M → D, after adjusting for A in the DAG, or E(D | do(M))")
marriage_dag_full_stan.stan
// Joint model for the whole DAG: age -> marriage, and age + marriage -> divorce
data {
  int<lower=1> n;  // Observations
  vector[n] Divorce_z;  // Outcome: divorce rate
  vector[n] Marriage_z;  // "Treatment": marriage rate
  vector[n] MedianAgeMarriage_z;  // Confounder: age
}
parameters {
  // Age -> Marriage
  real aM;
  real bAM;
  real<lower=0> tau;

  // Age -> Divorce <- Marriage
  real a;
  real bM;
  real bA;
  real<lower=0> sigma;
}
model {
  vector[n] nu;
  vector[n] mu;

  // Age -> Marriage
  aM ~ normal(0, 0.2);
  bAM ~ normal(0, 0.5);
  tau ~ exponential(1);
  nu = aM + bAM*MedianAgeMarriage_z;
  Marriage_z ~ normal(nu, tau);

  // Age -> Divorce <- Marriage
  a ~ normal(0, 0.2);
  bM ~ normal(0, 0.5);
  bA ~ normal(0, 0.5);
  sigma ~ exponential(1);
  mu = a + bM*Marriage_z + bA*MedianAgeMarriage_z;
  Divorce_z ~ normal(mu, sigma);
}
generated quantities {
  vector[n] Divorce_z_rep;
  vector[n] Marriage_z_rep;
  vector[n] divorce_do_marriage;

  for (i in 1:n) {
    real nu_hat_n = aM + bAM*MedianAgeMarriage_z[i];
    real mu_hat_n = a + bM*Marriage_z[i] + bA*MedianAgeMarriage_z[i];

    Marriage_z_rep[i] = normal_rng(nu_hat_n, tau);
    Divorce_z_rep[i] = normal_rng(mu_hat_n, sigma);
    // Divorce given the simulated marriage rate, with the age term set to 0
    divorce_do_marriage[i] = normal_rng(a + bM*Marriage_z_rep[i] + bA*0, sigma);
  }
}
# Keep only the columns declared in the Stan data block; compose_data()
# converts them to a list for Stan
stan_data <- WaffleDivorce %>%
  select(Divorce_z, Marriage_z, MedianAgeMarriage_z) %>%
  compose_data()

model_marriage_dag_full_stan <- rstan::sampling(
  object = marriage_dag_full_stan,
  data = stan_data,
  iter = 2000, warmup = 1000, seed = BAYES_SEED, chains = 4, cores = 4
)

print(model_marriage_dag_full_stan,
      pars = c("aM", "bAM", "tau", "a", "bM", "bA", "sigma"))
## Inference for Stan model: ea32b2c9a1ab179009a8845d85ea5d42.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## aM 0.00 0 0.09 -0.18 -0.07 0.00 0.06 0.18 5552 1
## bAM -0.69 0 0.10 -0.88 -0.75 -0.69 -0.63 -0.50 5583 1
## tau 0.71 0 0.08 0.58 0.66 0.70 0.76 0.88 5465 1
## a 0.00 0 0.10 -0.20 -0.07 0.00 0.07 0.20 5541 1
## bM -0.06 0 0.16 -0.37 -0.17 -0.06 0.05 0.25 2842 1
## bA -0.61 0 0.16 -0.92 -0.71 -0.61 -0.50 -0.29 3314 1
## sigma 0.83 0 0.09 0.68 0.77 0.82 0.88 1.03 5675 1
##
## Samples were drawn using NUTS(diag_e) at Wed Sep 21 11:07:06 2022.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
# Summarize the simulated divorce rates per observation (mean + HDCI) and
# attach each observation's standardized age for plotting
stan_age_divorce <- model_marriage_dag_full_stan %>%
  spread_draws(Divorce_z_rep[i]) %>%
  mean_hdci() %>%
  mutate(age = WaffleDivorce$MedianAgeMarriage_z)

ggplot(stan_age_divorce, aes(x = age, y = Divorce_z_rep)) +
  geom_line(color = clrs[5]) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2, fill = clrs[5]) +
  coord_cartesian(xlim = c(-2, 2)) +
  labs(title = "Total counterfactual effect of age on divorce rate",
       subtitle = "A → D in the DAG")
# Summarize the simulated marriage rates per observation (mean + HDCI) and
# attach each observation's standardized age for plotting
stan_age_marriage <- model_marriage_dag_full_stan %>%
  spread_draws(Marriage_z_rep[i]) %>%
  mean_hdci() %>%
  mutate(age = WaffleDivorce$MedianAgeMarriage_z)

ggplot(stan_age_marriage, aes(x = age, y = Marriage_z_rep)) +
  geom_line(color = clrs[6]) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2, fill = clrs[6]) +
  coord_cartesian(xlim = c(-2, 2)) +
  labs(title = "Counterfactual effect of age on marriage rate",
       subtitle = "A → M in the DAG")
# Summarize the divorce_do_marriage draws per observation (mean + HDCI)
# NOTE(review): the x-axis is each observation's standardized age, while the
# title describes the effect of marriage rate -- confirm this is intended
stan_age_marriage_divorce <- model_marriage_dag_full_stan %>%
  spread_draws(divorce_do_marriage[i]) %>%
  mean_hdci() %>%
  mutate(age = WaffleDivorce$MedianAgeMarriage_z)

ggplot(stan_age_marriage_divorce, aes(x = age, y = divorce_do_marriage)) +
  geom_line(color = clrs[3]) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2, fill = clrs[3]) +
  coord_cartesian(xlim = c(-2, 2)) +
  labs(title = "Total counterfactual effect of marriage rate on divorce rate",
       subtitle = "M → D, after adjusting for A in the DAG, or E(D | do(M))")
The pipe (mediators)
\[ X \rightarrow Z \rightarrow Y \]
\(X\) and \(Y\) are associated (\(Y \notind X\)) because influence of \(X\) is passed to \(Y\) through \(Z\). After adjusting for \(Z\), though, there’s no association, or \(Y \ind X \mid Z\).
Simulated example
# Simulate a pipe: X drives Z, and Z drives Y
n <- 1000

pipe_sim <- tibble(X = rbinom(n, 1, prob = 0.5)) %>%
  # When X is 0, there's a 10% chance of Z being 1
  # When X is 1, there's a 90% chance of Z being 1
  # When Z is 0, there's a 10% chance of Y being 1
  # When Z is 1, there's a 90% chance of Y being 1
  mutate(Z = rbinom(n, 1, prob = ((1 - X) * 0.1) + (X * 0.9)),
         Y = rbinom(n, 1, prob = ((1 - Z) * 0.1) + (Z * 0.9)))
# Cross-tabulate X and Y, ignoring Z
pipe_sim %>%
  select(-Z) %>%
  table()
## Y
## X 0 1
## 0 403 92
## 1 73 432
# Correlation between X and Y across the whole sample
pipe_sim %>%
  summarize(cor = cor(X, Y))
## # A tibble: 1 × 1
## cor
## <dbl>
## 1 0.670
But if we adjust for \(Z\), \(Y \ind X \mid Z\):
# Cross-tabulate X and Y separately for each value of Z
pipe_sim %>%
  select(X, Y, Z) %>%
  table()
## , , Z = 0
##
## Y
## X 0 1
## 0 401 58
## 1 33 4
##
## , , Z = 1
##
## Y
## X 0 1
## 0 2 34
## 1 40 428
# Correlation between X and Y within each value of Z
pipe_sim %>%
  group_by(Z) %>%
  summarize(cor = cor(X, Y))
## # A tibble: 2 × 2
## Z cor
## <int> <dbl>
## 1 0 -0.0145
## 2 1 -0.0279
This also works with continuous data. When looking at all values of \(Z\), there’s a positive slope and relationship; when looking within each group, the relationship is 0 and flat.
# Continuous pipe: X sets the probability of Z (via plogis), and Z shifts the
# mean of Y to -1 or 1
n <- 300

pipe_sim_cont <- tibble(X = rnorm(n, 0, 1)) %>%
  mutate(Z = rbinom(n, 1, plogis(X)),
         Y = rnorm(n, (2 * Z - 1), 1))

ggplot(pipe_sim_cont, aes(x = X, y = Y, color = factor(Z))) +
  geom_point() +
  # One fitted line per value of Z...
  geom_smooth(method = "lm") +
  # ...and one pooled line that drops the color grouping
  geom_smooth(aes(color = NULL), method = "lm")
Fungus experiment example
With this DAG, we shouldn’t adjust for \(F\), since that would block the effect of the fungus, which in this case is super important since the causal mechanism pretty much only flows through \(F\). If we adjust for \(F\), we’ll get the causal effect of the treatment on height without the effect of the fungus, which is weird and probably 0.
# DAG for the fungus experiment: treatment (t) affects fungus (f); final
# height (h1) depends on treatment, fungus, and initial height (h0)
plant_fungus_dag <- dagify(
  h1 ~ t + f + h0,
  f ~ t,
  exposure = "t",
  outcome = "h1",
  labels = c(t = "Treatment", h1 = "Height, t=1", f = "Fungus", h0 = "Height, t=0"),
  coords = list(x = c(t = 1, h1 = 3, f = 2, h0 = 3),
                y = c(t = 1, h1 = 1, f = 2, h0 = 2))) %>%
  tidy_dagitty() %>%
  node_status()

ggplot(plant_fungus_dag, aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_edges() +
  geom_dag_point(aes(color = status)) +
  geom_dag_text(aes(label = label), size = 3.5, color = "black") +
  scale_color_manual(values = clrs[c(1, 4)], guide = "none") +
  theme_dag()
In general this is called post-treatment bias and it is bad.
The collider (colliders, obvs)
\[ X \rightarrow Z \leftarrow Y \]
\(X\) and \(Y\) are not associated (\(Y \ind X\)), but they both influence \(Z\). Once you adjust for \(Z\), \(X\) and \(Y\) become associated and \(Y \notind X \mid Z\).
When we learn about \(Z\) (or stratify by \(Z\), or only look at specific values of \(Z\)), we necessarily learn something about \(X\) and \(Y\), since they helped generate \(Z\).
Simulated example
# Simulate a collider: X and Y are drawn independently, and both feed into Z
n <- 1000

collider_sim <- tibble(X = rbinom(n, 1, prob = 0.5),
                       Y = rbinom(n, 1, prob = 0.5)) %>%
  # If either X or Y is 1, there's a 90% chance that Z will be 1;
  # otherwise there's a 20% chance
  mutate(Z = rbinom(n, 1, prob = ifelse(X + Y > 0, 0.9, 0.2)))
# These are independent
collider_sim %>%
  select(-Z) %>%
  table()
## Y
## X 0 1
## 0 248 253
## 1 240 259
# No correlation
collider_sim %>%
  summarize(cor = cor(X, Y))
## # A tibble: 1 × 1
## cor
## <dbl>
## 1 0.0141
When we adjust for \(Z\), though, \(Y \notind X \mid Z\):
# Cross-tabulate X and Y separately for each value of Z
collider_sim %>%
  select(X, Y, Z) %>%
  table()
## , , Z = 0
##
## Y
## X 0 1
## 0 206 25
## 1 27 17
##
## , , Z = 1
##
## Y
## X 0 1
## 0 42 228
## 1 213 242
# They're correlated!
collider_sim %>%
  group_by(Z) %>%
  summarize(cor = cor(X, Y))
## # A tibble: 2 × 2
## Z cor
## <int> <dbl>
## 1 0 0.283
## 2 1 -0.316
As with the others, this works with continuous data too. When ignoring values of \(Z\), there’s no relationship between \(X\) and \(Y\). But once we adjust for or stratify by \(Z\), there’s a relationship within each group.
# Continuous collider: X and Y are independent normals, and the probability
# that Z is 1 rises with both of them
n <- 300

collider_sim_cont <- tibble(X = rnorm(n, 0, 1),
                            Y = rnorm(n, 0, 1)) %>%
  mutate(Z = rbinom(n, 1, plogis(2*X + 2*Y - 2)))

ggplot(collider_sim_cont, aes(x = X, y = Y, color = factor(Z))) +
  geom_point() +
  # One fitted line per value of Z...
  geom_smooth(method = "lm") +
  # ...and one pooled line that drops the color grouping
  geom_smooth(aes(color = NULL), method = "lm")
Grant selection example
set.seed(1914)
n <- 200

# Two independent scores per proposal; selection depends on their sum
grants <- tibble(newsworthiness = rnorm(n, 0, 1),
                 trustworthiness = rnorm(n, 0, 1)) %>%
  mutate(total = newsworthiness + trustworthiness) %>%
  # Select just the top 10%
  mutate(q = quantile(total, 1 - 0.1)) %>%
  mutate(selected = total >= q)
# No relationship
grants %>%
  summarize(cor = cor(newsworthiness, trustworthiness))
## # A tibble: 1 × 1
## cor
## <dbl>
## 1 -0.0672
# Relationship!
grants %>%
  group_by(selected) %>%
  summarize(cor = cor(newsworthiness, trustworthiness))
## # A tibble: 2 × 2
## selected cor
## <lgl> <dbl>
## 1 FALSE -0.274
## 2 TRUE -0.768
# Scatter of the two scores, colored by selection status. One trend line is
# fit to the selected proposals only, the other to every proposal.
ggplot(
  data = grants,
  mapping = aes(x = newsworthiness, y = trustworthiness, color = selected)
) +
  geom_point() +
  geom_smooth(data = filter(grants, selected), method = "lm") +
  geom_smooth(mapping = aes(color = "Full sample"), method = "lm")
The descendant
Like a confounder if it comes from a confounder; like a mediator if it comes from a mediator; like a collider if it comes from a collider.
\(X\) and \(Y\) are causally associated through \(Z\), which implies that \(Y \notind X\). \(A\) contains information about \(Z\), so once we stratify by or adjust for \(A\), \(X\) and \(Y\) become less associated (if \(A\) is strong enough), implying \(Y \ind X \mid A\).
That can be good (if \(A\) is confounder-flavored) or bad (if \(A\) is mediator- or collider-flavored).
# In all three DAGs, A is a descendant of Z and sits directly below it

# Z is a confounder: Z -> X and Z -> Y
desc_confounder_dag <- dagify(
  Y ~ Z,
  X ~ Z,
  A ~ Z,
  coords = list(x = c(X = 1, Y = 3, Z = 2, A = 2),
                y = c(X = 1, Y = 1, Z = 1, A = 0))) %>%
  tidy_dagitty()

# Z is a mediator: X -> Z -> Y
desc_mediator_dag <- dagify(
  Y ~ Z,
  Z ~ X,
  A ~ Z,
  coords = list(x = c(X = 1, Y = 3, Z = 2, A = 2),
                y = c(X = 1, Y = 1, Z = 1, A = 0))) %>%
  tidy_dagitty()

# Z is a collider: X -> Z <- Y
desc_collider_dag <- dagify(
  Z ~ X + Y,
  A ~ Z,
  coords = list(x = c(X = 1, Y = 3, Z = 2, A = 2),
                y = c(X = 1, Y = 1, Z = 1, A = 0))) %>%
  tidy_dagitty()
# One plot per descendant DAG, identical apart from the data and subtitle
plot_desc_confounder <- ggplot(desc_confounder_dag,
                               aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_edges() +
  geom_dag_point() +
  geom_dag_text(aes(label = name), size = 3.5, color = "white") +
  ylim(c(-0.25, 1.25)) +
  labs(subtitle = "Confounder-flavored descendant") +
  theme_dag() +
  theme(plot.subtitle = element_text(hjust = 0.5, face = "bold"))

plot_desc_mediator <- ggplot(desc_mediator_dag,
                             aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_edges() +
  geom_dag_point() +
  geom_dag_text(aes(label = name), size = 3.5, color = "white") +
  ylim(c(-0.25, 1.25)) +
  labs(subtitle = "Mediator-flavored descendant") +
  theme_dag() +
  theme(plot.subtitle = element_text(hjust = 0.5, face = "bold"))

plot_desc_collider <- ggplot(desc_collider_dag,
                             aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_edges() +
  geom_dag_point() +
  geom_dag_text(aes(label = name), size = 3.5, color = "white") +
  ylim(c(-0.25, 1.25)) +
  labs(subtitle = "Collider-flavored descendant") +
  theme_dag() +
  theme(plot.subtitle = element_text(hjust = 0.5, face = "bold"))

# Combine the three plots with patchwork
plot_desc_confounder + plot_desc_mediator + plot_desc_collider