In this vignette, we show how to use BCF to estimate treatment effects of a simulated intervention.
library(bcf)
library(ggplot2)
library(latex2exp)
library(rpart)
library(rpart.plot)
library(partykit)
First, we simulate some data for testing. This data set has an outcome variable y, a treatment indicator z, and three covariates x1, x2 and x3. Of the covariates, two affect the outcome y at both levels of treatment, while the third is an effect moderator.
We draw three random xs for each unit and generate each unit’s expected outcome without treatment, μ, as a function of x1 and x2. Each unit’s probability of joining the intervention, π, is also a function of μ, so that units with larger responses under control are more likely to participate in the intervention. We then assign units to treatment (z = 1) or comparison (z = 0) status as a function of π.
Then we generate the true treatment effect for each unit, τ. As noted above, τ is mainly a function of x3, with a smaller contribution from x2. The observed outcome, y, is a function of μ, τ, and a random error term whose variance is σ² divided by the unit's weight w, if weights are used.
set.seed(1)
p <- 3 # two control variables and one effect moderator
n <- 1000 # number of observations
n_burn <- 2000
n_sim <- 1000
x <- matrix(rnorm(n*(p-1)), nrow=n)
x <- cbind(x, x[,2] + rnorm(n))
weights <- abs(rnorm(n))
# create targeted selection, whereby a practice's likelihood of joining the intervention (pi)
# is related to their expected outcome (mu)
mu <- -1*(x[,1]>(x[,2])) + 1*(x[,1]<(x[,2])) - 0.1
# generate treatment variable
pi <- pnorm(mu)
z <- rbinom(n,1,pi)
# tau is the true treatment effect. It varies across practices as a function of
# X3 and X2
tau <- 1/(1 + exp(-x[,3])) + x[,2]/10
# generate the expected response using mu, tau and z
y_noiseless <- mu + tau*z
# set the noise level relative to the expected mean function of Y
sigma <- diff(range(mu + tau*pi))/8
# draw the response variable with additive error
y <- y_noiseless + sigma*rnorm(n)/sqrt(weights)
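For reference, the simulation code above corresponds to the data-generating process sketched below. The unit index i, the auxiliary draws η and u, and Φ (the standard normal CDF, `pnorm` in the code) are notation introduced here for illustration only.

```latex
\begin{aligned}
x_{i1},\, x_{i2} &\sim N(0, 1), \qquad x_{i3} = x_{i2} + \eta_i, \quad \eta_i \sim N(0, 1) \\
\mu_i &= -\mathbf{1}(x_{i1} > x_{i2}) + \mathbf{1}(x_{i1} < x_{i2}) - 0.1 \\
\pi_i &= \Phi(\mu_i), \qquad z_i \sim \mathrm{Bernoulli}(\pi_i) \\
\tau_i &= \frac{1}{1 + \exp(-x_{i3})} + \frac{x_{i2}}{10} \\
\sigma &= \frac{\max_i(\mu_i + \tau_i \pi_i) - \min_i(\mu_i + \tau_i \pi_i)}{8} \\
y_i &= \mu_i + \tau_i z_i + \frac{\sigma}{\sqrt{w_i}}\, \varepsilon_i, \qquad \varepsilon_i \sim N(0, 1), \quad w_i = |u_i|, \; u_i \sim N(0, 1)
\end{aligned}
```

Written this way, the error variance for unit i is σ²/w_i, so units with larger weights have less noisy outcomes.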
In this data set we have observed y, x, and π values to which we can fit our BCF model. With BCF, we can distinguish between control variables – which affect the outcome at both levels of treatment – and moderator variables – which affect the estimated treatment effect.
Note that we are using the n_chains argument to bcf(), which allows us to run several MCMC chains in parallel and assess whether they have converged to the posterior distribution.
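The fitting call itself does not appear in this section. A minimal sketch of what it might look like follows, assuming the argument names of the bcf 2.x interface (`x_control`, `x_moderate`, `pihat`, `w`, `nburn`, `nsim`, `n_chains`, `random_seed`); check `?bcf` against your installed version. The wrapper name `fit_bcf_sketch`, the seed, and the choice to pass all three covariates as both control and moderator variables are assumptions made for illustration.

```r
# Hypothetical wrapper around bcf::bcf(); argument names assume the bcf 2.x interface.
fit_bcf_sketch <- function(y, z, x, pihat, w, n_burn, n_sim, n_chains = 4) {
  bcf::bcf(
    y           = y,
    z           = z,
    x_control   = x,        # covariates entering the prognostic function mu(x)
    x_moderate  = x,        # covariates allowed to moderate tau(x); assumed to be all of x
    pihat       = pihat,    # propensity scores (here, the true pi from the simulation)
    w           = w,        # observation weights
    nburn       = n_burn,   # burn-in iterations discarded per chain
    nsim        = n_sim,    # posterior draws kept per chain
    n_chains    = n_chains, # number of independent MCMC chains
    random_seed = 1         # assumed seed, for reproducibility only
  )
}
```

With the simulated objects defined above, `bcf_out <- fit_bcf_sketch(y, z, x, pi, weights, n_burn, n_sim)` would produce a fit of the kind summarized in the rest of this vignette.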
After fitting the BCF model, we can compare the τ̂ estimates from BCF to the true τ from the data-generating process.
We use the summary.bcf function to obtain posterior summary statistics and MCMC diagnostics. We use those diagnostics to assess convergence of our run.
summary(bcf_out)
#> Summary statistics for each Markov Chain Monte Carlo run
#>
#> Iterations = 1:1000
#> Thinning interval = 1
#> Number of chains = 4
#> Sample size per chain = 1000
#>
#> 1. Empirical mean and standard deviation for each variable,
#> plus standard error of the mean:
#>
#> Mean SD Naive SE Time-series SE
#> sigma 0.37556 0.008577 0.0001356 0.0001689
#> tau_bar 0.49308 0.040953 0.0006475 0.0012522
#> mu_bar -0.08981 0.023568 0.0003726 0.0005772
#> yhat_bar 0.19119 0.012922 0.0002043 0.0002069
#>
#> 2. Quantiles for each variable:
#>
#> 2.5% 25% 50% 75% 97.5%
#> sigma 0.3592 0.3698 0.37533 0.38120 0.39302
#> tau_bar 0.4131 0.4656 0.49293 0.52112 0.57317
#> mu_bar -0.1362 -0.1058 -0.08954 -0.07331 -0.04391
#> yhat_bar 0.1657 0.1825 0.19128 0.20000 0.21660
#>
#>
#> ----
#> Effective sample size for summary parameters
#> Reverting to coda's default ESS calculation. See ?summary.bcf for details.
#>
#> sigma tau_bar mu_bar yhat_bar
#> 2676.224 1077.885 1696.041 3911.178
#>
#> ----
#> Gelman and Rubin's convergence diagnostic for summary parameters
#> Potential scale reduction factors:
#>
#> Point est. Upper C.I.
#> sigma 1 1.00
#> tau_bar 1 1.01
#> mu_bar 1 1.01
#> yhat_bar 1 1.00
#>
#> Multivariate psrf
#>
#> 1
#>
#> ----
Since our R̂ values (Gelman and Rubin’s convergence diagnostic, also called the Potential Scale Reduction Factor) are all between 0.9 and 1.1, we don’t see any obvious mixing issues. Note that mixing on other parameters (like the subgroup average treatment effects defined below, or evaluations of τ(x) or μ(x) at particular values of x) might be worse than for these aggregate summaries. In our experience, though, good mixing on σ tends to indicate good mixing on other quantities of interest.
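To make the diagnostic concrete, here is a rough base-R sketch of the basic (non-split) potential scale reduction factor for a single scalar parameter. It is not the exact computation behind the summary output above: it omits the degrees-of-freedom correction applied by coda, so values can differ slightly. The function name `basic_rhat` and the toy chains are hypothetical.

```r
# Basic potential scale reduction factor for one scalar parameter.
# `draws` is an iterations-by-chains matrix of posterior draws.
basic_rhat <- function(draws) {
  n_iter <- nrow(draws)
  chain_means <- colMeans(draws)
  between <- n_iter * var(chain_means)   # between-chain variance B
  within <- mean(apply(draws, 2, var))   # average within-chain variance W
  pooled <- (n_iter - 1) / n_iter * within + between / n_iter
  sqrt(pooled / within)
}

set.seed(1)
# Four well-mixed toy chains drawn from the same distribution
mixed <- matrix(rnorm(4000), ncol = 4)
# The same chains with one shifted, mimicking a chain stuck in a different region
stuck <- mixed
stuck[, 1] <- stuck[, 1] + 3

rhat_mixed <- basic_rhat(mixed)
rhat_stuck <- basic_rhat(stuck)
```

Chains that agree give a value near 1, while the shifted chain inflates the between-chain variance and pushes the value well above the 1.1 threshold.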
Now that we’ve successfully fit our model, let’s explore the output. First, since this is a simulation, let’s compare the unit-specific treatment effect estimates τ̂ to the true unit-specific treatment effects τ.
We plot the true and estimated τ versus x3, since x3 is one of our primary effect moderators (see the data creation step above).
tau_ests <- data.frame(Mean = colMeans(bcf_out$tau),
Low95 = apply(bcf_out$tau, 2, function(x) quantile(x, 0.025)),
Up95 = apply(bcf_out$tau, 2, function(x) quantile(x, 0.975)))
ggplot(NULL, aes(x = x[,3])) +
geom_pointrange(aes(y = tau_ests$Mean, ymin = tau_ests$Low95, ymax = tau_ests$Up95), color = "forestgreen", alpha = 0.5) +
geom_smooth(aes(y = tau), se = FALSE) +
xlab(TeX("$x_3$")) +
ylab(TeX("$\\hat{\\tau}$")) +
xlim(-4, 6) +
geom_segment(aes(x = 3, xend = 4, y = 0.2, yend = 0.2), color = "blue", alpha = 0.9) +
geom_text(aes(x = 4.5, y = 0.2, label = "Truth"), color = "black") +
geom_segment(aes(x = 3, xend = 4, y = 0.1, yend = 0.1), color = "forestgreen", alpha = 0.7) +
geom_text(aes(x = 5.2, y = 0.1, label = "Estimate (95% CI)"), color = "black")
#> `geom_smooth()` using method = 'gam' and formula = 'y ~ s(x, bs = "cs")'
#> Warning: Removed 6 rows containing non-finite values (`stat_smooth()`).
#> Warning: Removed 6 rows containing missing values (`geom_pointrange()`).
BCF recovers the true unit-specific treatment effects with a high degree of accuracy in this simple example, with conservative estimates at extreme values of the covariate. We can also examine coverage of 95% credible intervals:
isCovered <- function(i){
ifelse(tau_ests$Low95[i] <= tau[i] & tau[i] <= tau_ests$Up95[i], 1, 0)
}
coverage <- lapply(1:length(tau), isCovered)
perCoverage <- sum(unlist(coverage))/length(tau)
perCoverage
#> [1] 1
We find that the 95 percent credible intervals cover the true treatment effect for 100 percent of units, suggesting that our uncertainty estimates are slightly conservative in this example.
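The coverage calculation above can also be written in vectorized form. The sketch below uses a hypothetical helper, `interval_coverage`, and toy draws rather than the BCF output, so it runs on its own.

```r
# Share of units whose equal-tailed credible interval contains the true value.
# `draws` is a draws-by-units matrix and `truth` has one entry per unit.
interval_coverage <- function(draws, truth, level = 0.95) {
  alpha <- 1 - level
  lower <- apply(draws, 2, quantile, probs = alpha / 2)
  upper <- apply(draws, 2, quantile, probs = 1 - alpha / 2)
  mean(lower <= truth & truth <= upper)
}

set.seed(1)
toy_truth <- c(0, 1, 2)
# 1000 toy posterior draws per unit, centered on the truth with unit spread
toy_draws <- sapply(toy_truth, function(m) rnorm(1000, mean = m, sd = 1))
toy_coverage <- interval_coverage(toy_draws, toy_truth)
```

Applied to the fit, `interval_coverage(bcf_out$tau, tau)` should give the same value as `perCoverage`.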
In this example we see variation in the point estimates across units, although most uncertainty intervals overlap (except for points at the extremes). This is typical and shouldn’t be taken as a sign that there is no meaningful treatment effect variation overall, or that treatment effect variation isn’t correlated with x3.
First, note that these are pointwise or marginal intervals computed at each observed covariate value. However, there is substantial posterior correlation between τ(x) and τ(x′) when x and x′ are close, so comparing these marginals can be misleading. For example, let’s compare CATEs for two observations that are similar on x1 and x2 but differ by about 1.5 units on x3:
ix = c(635, 614)
print(x[ix,])
#> [,1] [,2] [,3]
#> [1,] -1.572864 -1.353550 -2.1335820
#> [2,] -1.398754 -1.347236 -0.6600108
plot(bcf_out$tau[,ix], xlab = paste0("CATE for observation ", ix[1]),
ylab = paste0("CATE for observation ", ix[2]))
par(mfrow=c(1,2))
plot(density(bcf_out$tau[,ix[1]]), main="Marginal posterior CATEs")
lines(density(bcf_out$tau[,ix[2]]), col='blue')
plot(density(bcf_out$tau[,ix[1]] - bcf_out$tau[,ix[2]]), main="Posterior for \ndifference in CATEs")
abline(v=tau[ix[1]] - tau[ix[2]], lty=2)
par(mfrow=c(1,1))
library(coda)
# These intervals are a little misleading with the bimodal posterior, but still informative
HPDinterval(mcmc(bcf_out$tau[,ix[1]] - bcf_out$tau[,ix[2]]), prob=0.5)
#> lower upper
#> var1 -0.1842529 3.01148e-14
#> attr(,"Probability")
#> [1] 0.5
HPDinterval(mcmc(bcf_out$tau[,ix[1]] - bcf_out$tau[,ix[2]]), prob=0.75)
#> lower upper
#> var1 -0.2917351 3.01148e-14
#> attr(,"Probability")
#> [1] 0.75
HPDinterval(mcmc(bcf_out$tau[,ix[1]] - bcf_out$tau[,ix[2]]), prob=0.9)
#> lower upper
#> var1 -0.4149043 0.001022315
#> attr(,"Probability")
#> [1] 0.9
HPDinterval(mcmc(bcf_out$tau[,ix[1]] - bcf_out$tau[,ix[2]]), prob=0.95)
#> lower upper
#> var1 -0.4709973 0.04536157
#> attr(,"Probability")
#> [1] 0.95
# Approximate posterior probability that the difference is at most 0 (up to a small tolerance):
mean(bcf_out$tau[,ix[1]] < bcf_out$tau[,ix[2]] + 1e-6)
#> [1] 0.956
Examining the posterior distribution over the difference tells a different story than just looking at the marginal posteriors as in the plot above, suggesting stronger evidence for treatment effect variation between these two observations.
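The same phenomenon can be reproduced with toy draws. This sketch does not use the BCF output; the variable names and numbers are made up to illustrate how a shared component makes marginal intervals overlap while the interval for the difference stays narrow.

```r
set.seed(1)
n_draws <- 4000
# A shared component induces strong positive correlation between the two CATEs
shared <- rnorm(n_draws, mean = 0, sd = 1)
cate_a <- 0.5 + shared + rnorm(n_draws, sd = 0.05)
cate_b <- 0.8 + shared + rnorm(n_draws, sd = 0.05)

# Marginal 95% intervals are wide and overlap almost entirely
ci_a <- unname(quantile(cate_a, c(0.025, 0.975)))
ci_b <- unname(quantile(cate_b, c(0.025, 0.975)))

# The interval for the difference is narrow because the shared component cancels
ci_diff <- unname(quantile(cate_a - cate_b, c(0.025, 0.975)))
prob_a_below_b <- mean(cate_a < cate_b)
```

Here the marginal intervals say little about whether the two effects differ, whereas the draws of the difference concentrate well below zero.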
The spike at zero is from posterior draws where these two observations end up in the same leaf nodes for each tree comprising τ and therefore have equal treatment effects under the model. This shows up even more clearly in a scatterplot of posterior draws:
plot(bcf_out$tau[,ix[1]],bcf_out$tau[,ix[2]], main="Posterior for difference in CATEs",
xlab = paste0("CATE for observation ", ix[1]), ylab = paste0("CATE for observation ", ix[2]))
BCF was specifically designed to impose this kind of shrinkage by preferring fewer, shallower τ trees than the BART defaults.
The larger, more fundamental issue is that it’s more difficult to estimate conditional average treatment effects – and variation between them – when the conditioning set is small. One way to gain precision is by aggregating CATEs into larger groups and comparing subgroup average treatment effects. Aggregating into subgroups is not just a mechanism for gaining precision, but an effective way to summarize and communicate the posterior distribution in its own right.
One obvious thing to do in this example is to compare the posterior distributions for observations with x3 > 0 to those with x3 ≤ 0 (above/below the population median, according to this data generating process):
lo <- x[,3] <= 0
hi <- x[,3] > 0
loTaus <- bcf_out$tau[,lo]
hiTaus <- bcf_out$tau[,hi]
wloTaus <- apply(loTaus, 1, weighted.mean, weights[lo])
whiTaus <- apply(hiTaus, 1, weighted.mean, weights[hi])
groupTaus <- data.frame(taus = c(wloTaus, whiTaus),
subgroup = c(rep("x3_lo", length(wloTaus)), rep("x3_hi", length(whiTaus))))
ggplot(groupTaus, aes(taus, fill = subgroup, color = subgroup)) +
geom_density(alpha = 0.5) +
ylab("posterior density") +
xlab("average subgroup treatment effect")
ggplot(data.frame(diff = whiTaus - wloTaus), aes(diff)) +
geom_density() +
ggtitle("Posterior distribution of difference\n between subgroup ATEs") +
xlim(0, 0.9) +
xlab("Difference") +
geom_vline(xintercept = 0, linetype = 2)
This provides additional strong evidence that treatment effect variation is correlated with x3. Of course, 0 was an arbitrary cutoff, and we picked x3 because we knew it was important. Hahn, Murray, and Carvalho (2020) propose an alternative decision-theoretic approach that uses CART to discover the cutpoints that define subgroups.
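The subgroup contrast plotted above can also be summarized numerically. The helper below, `summarize_contrast`, is a hypothetical convenience function, shown on toy draws so that it runs on its own.

```r
# Summarize posterior draws of a contrast (e.g. a difference in subgroup ATEs).
summarize_contrast <- function(diff_draws, level = 0.95) {
  alpha <- 1 - level
  c(
    mean     = mean(diff_draws),
    lower    = unname(quantile(diff_draws, alpha / 2)),
    upper    = unname(quantile(diff_draws, 1 - alpha / 2)),
    prob_pos = mean(diff_draws > 0)  # posterior probability the contrast is positive
  )
}

set.seed(1)
# Toy stand-in for draws of a subgroup difference
toy_diff <- rnorm(4000, mean = 0.4, sd = 0.1)
toy_summary <- summarize_contrast(toy_diff)
```

For the fit above, `summarize_contrast(whiTaus - wloTaus)` would report the posterior mean, an equal-tailed 95% interval, and the posterior probability that the high-x3 subgroup has the larger average effect.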
To help summarize the posterior, Hahn, Murray, and Carvalho (2020) recommend fitting a CART model where the response variable is the posterior mean unit-specific estimated treatment effect from BCF (τ̂(x_i)) and the predictor variables are the effect modifiers used to fit the BCF model. This is an attempt to find the Bayes estimate of subgroups represented by a recursive partition (under squared error loss on τ with complexity constraints/penalties); it is only an attempt because the greedy CART algorithm is not guaranteed to find the optimal partition.
The complexity constraints or penalties on the tree are up to the user, but there is no particular reason to adopt rpart’s default complexity penalty. If the goal is subgroup identification, a reasonable way to proceed is to grow the tree up to some depth that gives credible (i.e. relatively simple) definitions of subgroups and then prune back based on the size of the difference in estimated subgroup effects across end nodes, posterior probabilities of a meaningful difference, or other application-specific criteria.
xdf = data.frame(x); colnames(xdf) = c(paste0("x", 1:ncol(x)))
tree <- rpart(colMeans(bcf_out$tau) ~ ., data=xdf, weights = weights,
control=rpart.control(maxdepth=3, cp=-1))
rpart.plot(tree)
We seem to capture the most treatment effect variability with the x3 splits, although x2 makes an appearance as well with patterns that line up with our data generating process.
We forced the tree above to be the maximal depth-3 tree by specifying cp=-1. Now we consider pruning the tree, which is easiest using the partykit package:
# party trees are a little easier to prune
ptree = as.party(tree)
print(ptree)
#>
#> Model formula:
#> colMeans(bcf_out$tau) ~ x1 + x2 + x3
#>
#> Fitted party:
#> [1] root
#> | [2] x3 < -0.34703
#> | | [3] x3 < -1.48554
#> | | | [4] x3 < -2.26608: 0.000 (w = 57.8, err = 0.1)
#> | | | [5] x3 >= -2.26608: 0.091 (w = 74.2, err = 0.1)
#> | | [6] x3 >= -1.48554
#> | | | [7] x2 < -0.03596: 0.202 (w = 150.1, err = 0.9)
#> | | | [8] x2 >= -0.03596: 0.366 (w = 53.8, err = 0.4)
#> | [9] x3 >= -0.34703
#> | | [10] x3 < 0.72231
#> | | | [11] x2 < 0.05547: 0.451 (w = 98.8, err = 0.2)
#> | | | [12] x2 >= 0.05547: 0.616 (w = 123.3, err = 0.4)
#> | | [13] x3 >= 0.72231
#> | | | [14] x2 < 0.77641: 0.781 (w = 134.9, err = 0.6)
#> | | | [15] x2 >= 0.77641: 0.932 (w = 134.8, err = 0.3)
#>
#> Number of inner nodes: 7
#> Number of terminal nodes: 8
We can start by getting the posterior distributions for the subgroup ATEs defined by the terminal nodes of the tree, and then looking at the differences between subgroups defined by the last splits of the tree:
# Get terminal node for each observation
subgp = predict(ptree, type='node')
# Get posteriors for subgroup ATEs
subgp_id = sort(unique(subgp))
get_sub_post = function(ix, fit, weights) {
subtaus = fit$tau[,ix]
apply(subtaus, 1, weighted.mean, weights[ix])
}
subgp_post = lapply(subgp_id, function(x) get_sub_post(subgp==x, bcf_out, weights))
names(subgp_post) = subgp_id
# Plot the difference in subgroup effects between adjacent terminal nodes in
# the tree above
plot(density(subgp_post[['4']] - subgp_post[['5']]), main="4 vs 5")
abline(v=0, lty=2)
Splits 10 and 13 pretty clearly identify subgroups with different ATEs, but there is more uncertainty considering splits 3 and 6. If we wanted a simpler summary with coarser subgroups we could prune those splits away and re-examine:
pruned_ptree = nodeprune(ptree, c(3,6))
subgp = predict(pruned_ptree, type='node')
subgp_id = sort(unique(subgp))
get_sub_post = function(ix, fit, weights) {
subtaus = fit$tau[,ix]
apply(subtaus, 1, weighted.mean, weights[ix])
}
subgp_post = lapply(subgp_id, function(x) get_sub_post(subgp==x, bcf_out, weights))
names(subgp_post) = subgp_id
# Note that the node indices here are different than in the code block above
plot(density(subgp_post[['3']] - subgp_post[['4']]), main="3 vs 4")
abline(v=0, lty=2)
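The helper get_sub_post() used above averages draw by draw with apply(). An equivalent matrix-product form is sketched below as one possible alternative. The name `subgroup_ate_draws` and the toy inputs are hypothetical. Using `drop = FALSE` is a design choice that keeps the subset a matrix even when a subgroup contains a single unit.

```r
# Weighted average of unit-level draws within a subgroup, one value per posterior draw.
# `tau_draws` is a draws-by-units matrix, `in_group` a logical vector over units.
subgroup_ate_draws <- function(tau_draws, in_group, w) {
  w_norm <- w[in_group] / sum(w[in_group])
  # drop = FALSE keeps a matrix even when the subgroup contains a single unit
  drop(tau_draws[, in_group, drop = FALSE] %*% w_norm)
}

set.seed(1)
toy_tau <- matrix(rnorm(50), nrow = 10)  # 10 draws for 5 units
toy_w <- abs(rnorm(5))
toy_group <- c(TRUE, FALSE, TRUE, TRUE, FALSE)

fast <- subgroup_ate_draws(toy_tau, toy_group, toy_w)
slow <- apply(toy_tau[, toy_group], 1, weighted.mean, toy_w[toy_group])
```

Both versions return one weighted subgroup average per posterior draw, so either can feed the density plots above.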
We can make other comparisons too, like contrasting the subgroup with the largest CATE with its complement:
largest = subgp==11
post_largest = get_sub_post(largest, bcf_out, weights)
post_others = get_sub_post(!largest, bcf_out, weights)
plot(density(post_largest), xlim=c(0,1.2), ylim=c(0,9), main = "Largest CATE subgroup \n vs others")
lines(density(post_others), lty=2)
plot(density(post_largest - post_others), main = "Difference between largest\n CATE subgroup & others")
Note that there are other ways to think about summarizing the posterior with a tree (e.g. using the tree as a generic function approximation tool rather than a subgroup-finder, in which case we might prefer deeper trees). We will explore these in future vignettes.