---
title: "Illustration with Nonlinear Effects"
package: multimedia
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Illustration with Nonlinear Effects}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  fig.align = "center",
  collapse = TRUE,
  comment = "#>",
  warning = FALSE,
  message = FALSE,
  cache = FALSE,
  dev.args = list(bg = "transparent"),
  out.width = 600,
  crop = NULL
)

knitr::knit_hooks$set(output = multimedia::ansi_aware_handler)
suppressPackageStartupMessages(library(ggplot2))
options(
  ggplot2.discrete.colour = c(
    "#9491D9", "#F24405", "#3F8C61", "#8C2E62", "#F2B705", "#11A0D9"
  ),
  ggplot2.discrete.fill = c(
    "#9491D9", "#F24405", "#3F8C61", "#8C2E62", "#F2B705", "#11A0D9"
  ),
  ggplot2.continuous.colour = function(...) {
    scale_color_distiller(palette = "Spectral", ...)
  },
  ggplot2.continuous.fill = function(...) {
    scale_fill_distiller(palette = "Spectral", ...)
  },
  crayon.enabled = TRUE
)

th <- theme_classic() +
  theme(
    panel.background = element_rect(fill = "transparent"),
    strip.background = element_rect(fill = "transparent"),
    plot.background = element_rect(fill = "transparent", color = NA),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    legend.background = element_rect(fill = "transparent"),
    legend.box.background = element_rect(fill = "transparent"),
    legend.position = "bottom"
  )
theme_set(th)
```

This vignette gives a simplified example similar to the "Quick Start with Random Data," except in this simulated dataset, there are true direct and indirect effects. The data are low dimensional, which also makes it easy to visualize our model's behavior and the effect of various alterations.

Before discussing data and models, let's load libraries and define helper functions that will be used for reorganizing the data later in this vignette.

```{r setup}
library(glue)
library(tidyverse)
library(multimedia)
set.seed(20240111)

#' Helper to plot the real data
plot_exper <- function(fit, profile) {
  pivot_samples(fit, profile) |>
    ggplot() +
    geom_point(aes(mediator, value, col = factor(treatment))) +
    facet_grid(. ~ outcome) +
    theme(legend.position = "bottom")
}

#' Convert a SummarizedExperiment to long form
pivot_samples <- function(fit, profile) {
  exper_sample <- sample(fit, profile = profile)
  # `t_outcome` is a global defined later in this vignette, before this
  # function is first called.
  outcomes(exper_sample) |>
    bind_cols(mediators(exper_sample), t_outcome) |>
    pivot_longer(starts_with("outcome"), names_to = "outcome")
}
```

We will consider a mediation analysis with two outcomes and a single mediator. The toy dataset below is included in the package; you can vary the number of samples and the true direct effect size by passing in arguments. The resulting data frame contains the treatment, mediator, and outcomes. We use random spline functions for the direct effect.

```{r prepare_data}
xy_data <- demo_spline()
xy_long <- list(
  true = pivot_longer(xy_data, starts_with("outcome"), names_to = "outcome")
)
```

The block below estimates a random forest model on these randomly generated data (after first converting them into an object of class `mediation_data`). We can vary `ranger` parameters using the arguments to `rf_model()`.

```{r fit_model}
exper <- mediation_data(
  xy_data, starts_with("outcome"), "treatment", "mediator"
)

fit <- multimedia(
  exper,
  outcome_estimator = rf_model(num.trees = 1e3)
) |>
  estimate(exper)
```

Let's look at the estimated effects.

```{r effect_estimates}
direct_effect(fit) |>
  effect_summary()

indirect_overall(fit) |>
  effect_summary()
```

Mediation analysis depends on untestable assumptions about the absence of confounders that influence both the mediators and the outcomes. For this reason, we can run a sensitivity analysis: how strong would a confounder have to be before we would observe a flipped indirect effect estimate?

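
To see why an unobserved confounder matters, here is a minimal base R sketch. It is not how `multimedia` implements its sensitivity analysis; it is only an illustrative simulation under assumed linear models, and every name in it (`indirect_estimate`, `confounder_strength`, the coefficient values) is hypothetical. The true indirect effect is zero by construction, yet a product-of-coefficients estimate drifts away from zero as the hidden confounder gets stronger.

```r
# Illustrative simulation only (assumed linear models, hypothetical names).
# An unobserved variable `u` influences both the mediator and the outcome.
set.seed(1)
n <- 20000
treatment <- rbinom(n, 1, 0.5)

indirect_estimate <- function(confounder_strength) {
  u <- rnorm(n)
  # Treatment shifts the mediator by 0.5
  mediator <- 0.5 * treatment + confounder_strength * u + rnorm(n)
  # The mediator has NO effect on the outcome, so the true indirect effect is 0
  outcome <- 1.0 * treatment + confounder_strength * u + rnorm(n)
  # Product-of-coefficients estimate that ignores `u`
  a <- coef(lm(mediator ~ treatment))["treatment"]
  b <- coef(lm(outcome ~ treatment + mediator))["mediator"]
  unname(a * b)
}

strengths <- c(0, 0.5, 1, 2)
estimates <- sapply(strengths, indirect_estimate)
data.frame(confounder_strength = strengths, indirect_estimate = estimates)
```

In this particular setup, the large-sample value of the estimate is `0.5 * c^2 / (1 + c^2)` for confounder strength `c`, so the spurious indirect effect grows with `c` even though no mediation is present. A sensitivity analysis asks the reverse question: how much confounding of this kind would be needed to change our conclusions.
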
```{r sensitivity_analysis}
confound_ix <- expand.grid(mediator = 1, outcome = 1:2)
# The commented call computes the curve directly; here we read a saved copy.
# sensitivity_curve <- sensitivity(fit, exper, confound_ix)
sensitivity_curve <- read_csv("https://go.wisc.edu/0xcyr1")
plot_sensitivity(sensitivity_curve)
```

```{r}
perturb <- matrix(
  c(
    0, 3, 0,
    3, 0, 0,
    0, 0, 0
  ),
  nrow = 3, byrow = TRUE
)
# As above, the commented call is replaced by a saved copy of its result.
# sensitivity_curve <- sensitivity_perturb(fit, exper, perturb)
sensitivity_curve <- read_csv("https://go.wisc.edu/75mz1b")
plot_sensitivity(sensitivity_curve, x_var = "nu")
```

If we want to look at samples at new treatment assignments, we first have to construct a new treatment profile. We can accomplish this using the `setup_profile()` function.

```{r}
t_mediator <- tibble(treatment = factor(rep(0:1, each = nrow(exper) / 2)))
t_outcome <- tibble(treatment = factor(rep(0:1, each = nrow(exper) / 2)))
profile <- setup_profile(fit, t_mediator, t_outcome)
xy_long[["fitted"]] <- pivot_samples(fit, profile)
```

Let's estimate three altered models. The first two nullify the treatment effect in the mediation model (`T->M`) and the outcome model (`T->Y`), respectively. The third assumes a linear, rather than random forest, outcome model.

```{r alteration}
altered_m <- nullify(fit, "T->M") |>
  estimate(exper)
altered_ty <- nullify(fit, "T->Y") |>
  estimate(exper)

fit_lm <- multimedia(
  exper,
  outcome_estimator = lm_model()
) |>
  estimate(exper)
```

We can sample from these models and organize the results into one long data.frame. The function `pivot_samples()` is defined at the top of this vignette.

```{r reshape_altered}
pretty_labels <- c(
  "Original Data", "RF (Full)", "RF (T-\\->M)", "RF (T-\\->Y)",
  "Linear Model"
)

xy_long <- c(
  xy_long,
  list(
    altered_m = pivot_samples(altered_m, profile),
    altered_ty = pivot_samples(altered_ty, profile),
    linear = pivot_samples(fit_lm, profile)
  )
) |>
  bind_rows(.id = "fit_type") |>
  mutate(
    fit_type = case_when(
      fit_type == "linear" ~ "Linear Model",
      fit_type == "true" ~ "Original Data",
      fit_type == "fitted" ~ "RF (Full)",
      fit_type == "altered_ty" ~ "RF (T-\\->Y)",
      fit_type == "altered_m" ~ "RF (T-\\->M)"
    ),
    fit_type = factor(fit_type, levels = pretty_labels),
    outcome = case_when(
      outcome == "outcome_1" ~ "Y[1]",
      outcome == "outcome_2" ~ "Y[2]"
    )
  )
```

The figure below compares the original data with samples from the full, nullified, and linear models, using the combined `xy_long` dataset constructed above.

```{r visualize_long, fig.width = 8, fig.height = 3}
xy_long |>
  sample_frac(size = 0.1) |>
  ggplot(aes(mediator, col = treatment)) +
  geom_point(aes(y = value), size = 0.4, alpha = 0.9) +
  geom_rug(alpha = 0.99, linewidth = 0.1) +
  facet_grid(outcome ~ fit_type) +
  labs(x = expression("Mediator M"), y = "Outcome", col = "Treatment") +
  theme(
    strip.text = element_text(size = 11),
    axis.title = element_text(size = 14),
    legend.title = element_text(size = 12)
  )
```

We can use the bootstrap to gauge uncertainty in the direct and overall indirect effects. The block below plots histograms of the bootstrap estimates for each outcome.

```{r bootstraps}
fs <- list(direct_effect = direct_effect, indirect_overall = indirect_overall)
# bootstraps <- bootstrap(fit, exper, fs = fs)
bootstraps <- readRDS(url("https://go.wisc.edu/977l04"))

ggplot(bootstraps$direct_effect) +
  geom_histogram(aes(direct_effect)) +
  facet_wrap(~outcome)

ggplot(bootstraps$indirect_overall) +
  geom_histogram(aes(indirect_effect)) +
  facet_wrap(~outcome)
```

```{r}
sessionInfo()
```