--- title: "Continous Treatment" output: rmarkdown::html_vignette vignette: > %\VignetteIndexEntry{Continous Treatment} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- ```{r, include = FALSE} knitr::opts_chunk$set( collapse = TRUE, comment = "#>" ) knitr::opts_chunk$set(fig.width=7, fig.height=4) set.seed(123) ``` ```{r setup} library(RcausalEGM) ``` ```{r check_pkg} if (!(reticulate::py_module_available('CausalEGM'))){ cat("## Please install the CausalEGM package using the function: install_causalegm()") knitr::knit_exit() } ``` ## Data Generation Let's first generate a simulation dataset with continuous treatment. ```{r simulation} n <- 1000 p <- 200 v <- matrix(rexp(n * p), n, p) rate <- v[,1] + v[,2] scale = 1/rate x = rexp(n=n, rate=rate) y = rnorm(n=n, mean = x + (v[,1] + v[,3])*exp(-x * (v[,1] + v[,3])), sd=1) ``` Let's take a look at the simulation data. ```{r look} oldpar <-par(mfrow=c(1,3)) hist(x, breaks="FD", xlim=c(0,7), col="blue",xlab="x values") hist(y, breaks="FD", xlim=c(0,7), col="red",xlab="y values") boxplot(v[,1:5],main="First five covariates", xlab="Covariate index", ylab="v values") par(oldpar) ``` ## Model training Start training a CausalEGM model. Users can refer to the core API "causalegm" by help(causalegm) for detailed usage. Note that the parameters for x, y, v are required. Besides, users can also specify the z_dims as a integer list with four elements. (x_min, x_max) is the interval of treatment on which we evaluate the causal effect. The key binary_treatment is set to be FALSE as we are dealing with continuous treatment. ```{r train} #help(causalegm) model <- causalegm(x=x,y=y,v=v, n_iter=2000, binary_treatment = FALSE, use_v_gan = FALSE, x_min=0, x_max=3) ``` ## Treatment Effect Estimation The average causal effect evaluated at the 200 uniform points in the interval (x_min, x_max) can be directly obtained from the trained model. ```{r ace} ACE <- model$causal_pre boxplot(ACE, main="ACE distribution", ylab="Values") ``` ## Plot of the average dose-response function (ADRF) We now compare plot of the true ADRF and the ADRF estimated by our model. ```{r adrf} grid_val = seq(0, 3, length.out=200) true_effect = grid_val + 2 * (1+grid_val)**(-3) plot(true_effect, col=2, xlab="Treatment", ylab="Average Dose Response") lines(ACE, col=3) legend(x = "topleft", # Position legend = c("True ADRF", "Estimated ADRF"), lty = c(1, 1), # Line types col = c(2, 3), # Line colors lwd = 2) ```