Estimates conditional average treatment effects (CATEs) for survival outcomes using the R-learner with random forests, which is essentially an inverse-probability-of-censoring-weighted (IPCW) causal forest (implemented via the grf package). The CATE is defined as tau(x) = P(Y(1) > t0 | X = x) - P(Y(0) > t0 | X = x), where Y(1) and Y(0) are the counterfactual survival times under the treatment and control arms, respectively.

Remark: A random survival forest is used to estimate the nuisance parameters (i.e., the nuisance outcomes and the inverse-probability-of-censoring weights), and these estimates are passed as inputs to a causal forest that estimates the target parameter, the CATE.
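To make the role of the nuisance estimates concrete, here is a minimal base-R sketch of the R-learner loss. It swaps the causal forest for a linear CATE model tau(x) = b0 + x'b, which is not what surv_rl_grf fits, and the function name r_learner_linear is hypothetical. Assuming the survival formulation in which the outcome is the indicator 1{Y > t0} and the weights are IPCW weights (zero for observations censored before t0), the R-learner regresses the outcome residual Y - Y.hat on the treatment residual W - W.hat multiplied by the CATE model's features.

```r
# Hypothetical helper: R-learner with a linear CATE model
# tau(x) = b0 + x %*% b, fitted by minimizing the weighted R-loss
#   sum_i w_i * ((Y_i - Y.hat_i) - (W_i - W.hat_i) * tau(X_i))^2
# (a linear model stands in for the causal forest used by surv_rl_grf)
r_learner_linear <- function(X, Y, W, Y.hat, W.hat, weights = rep(1, length(Y))) {
  X <- as.matrix(X)
  y.res <- Y - Y.hat               # outcome residual
  w.res <- W - W.hat               # treatment residual (W.hat may be a scalar)
  Z <- w.res * cbind(1, X)         # each CATE feature scaled by the treatment residual
  fit <- lm.wfit(x = Z, y = y.res, w = weights)
  beta <- unname(fit$coefficients)
  list(
    beta = beta,
    predict = function(newX) drop(cbind(1, as.matrix(newX)) %*% beta)
  )
}
```

Replacing the weighted least-squares step with a flexible weighted learner gives the general R-learner; in surv_rl_grf that role is played by the causal forest.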

surv_rl_grf(
  X,
  Y,
  W,
  D,
  t0 = NULL,
  k.folds = NULL,
  W.hat = NULL,
  Y.hat = NULL,
  C.hat = NULL,
  new.args.grf.nuisance = list(),
  new.args.cf = list(),
  cen.fit = "Kaplan-Meier"
)

Arguments

X

The baseline covariates

Y

The follow-up time

W

The treatment variable (0 or 1)

D

The event indicator (1 = event observed, 0 = censored)

t0

The prediction time of interest

k.folds

Number of folds for cross-validation

W.hat

Propensity score, P(W = 1 | X)

Y.hat

Conditional mean outcome E(Y|X)

C.hat

Censoring weights

new.args.grf.nuisance

Additional arguments passed to the grf model that estimates the nuisance parameters

new.args.cf

Additional arguments passed to the causal_forest model that estimates the CATE

cen.fit

The model used to estimate the censoring distribution (default "Kaplan-Meier")
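The C.hat argument and the default cen.fit = "Kaplan-Meier" refer to inverse-probability-of-censoring weights. The sketch below shows one common way to build such weights in base R, assuming the standard IPCW construction for a survival probability at t0: estimate the censoring survival function G by Kaplan-Meier, treating censoring (D = 0) as the event, then give each observation whose status at t0 is known the weight 1 / G(min(Y, t0)-) and every observation censored before t0 the weight zero. The helper names km_fit and ipcw_weights are hypothetical, and surv_rl_grf may compute its weights differently.

```r
# Hypothetical helper: Kaplan-Meier estimate of a survival function,
# returned as a step function of t (the left limit S(t-) when left = TRUE).
km_fit <- function(time, event) {
  ut <- sort(unique(time[event == 1]))           # distinct event times
  if (length(ut) == 0) return(function(t, left = FALSE) rep(1, length(t)))
  n.risk <- vapply(ut, function(s) sum(time >= s), integer(1))
  n.event <- vapply(ut, function(s) sum(time == s & event == 1), integer(1))
  surv <- cumprod(1 - n.event / n.risk)          # product-limit estimate
  function(t, left = FALSE) {
    k <- findInterval(t, ut, left.open = left)   # number of event times <= t (< t if left)
    c(1, surv)[k + 1]
  }
}

# Hypothetical helper: IPCW weights for estimating P(T > t0 | X).
ipcw_weights <- function(Y, D, t0) {
  G <- km_fit(Y, 1 - D)                # censoring survival: censoring is the "event"
  known <- D == 1 | Y >= t0            # status of 1{T > t0} is observed
  G.left <- G(pmin(Y, t0), left = TRUE)
  ifelse(known, 1 / G.left, 0)         # censored before t0 => weight 0
}
```

Observations censored before t0 carry no information about 1{T > t0}, so their weight is redistributed to comparable observations whose status is known.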

Value

A surv_rl_grf object

Examples

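# \donttest{
library(survlearners)  # assumed: the package that provides surv_rl_grf()
set.seed(123)          # for reproducibility
n <- 1000; p <- 25
t0 <- 0.2              # prediction time of interest
Y.max <- 2             # maximum follow-up time
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 0.5) # randomized treatment, so the propensity score is 0.5
# Weibull event times (shape 1/2) with log-hazard X1 + (-0.5 - X2) * W
numeratorT <- -log(runif(n))
T <- (numeratorT / exp(1 * X[ ,1, drop = FALSE] + (-0.5 - 1 * X[ ,2, drop = FALSE]) * W)) ^ 2
failure.time <- pmin(T, Y.max)
# Independent Weibull censoring times (shape 2)
numeratorC <- -log(runif(n))
censor.time <- (numeratorC / (4 ^ 2)) ^ (1 / 2)
# Observed follow-up time and event indicator (1 = event, 0 = censored)
Y <- pmin(failure.time, censor.time)
D <- as.integer(failure.time <= censor.time)
n.test <- 500
X.test <- matrix(rnorm(n.test * p), n.test, p)

# Fit the model, passing the known propensity score
surv.rl.grf.fit <- surv_rl_grf(X, Y, W, D, t0, W.hat = 0.5)
cate <- predict(surv.rl.grf.fit)               # CATE estimates for the training sample
cate.test <- predict(surv.rl.grf.fit, X.test)  # CATE estimates for new covariates
# }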
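Because the simulation above specifies the event-time distribution exactly, the true CATE has a closed form, which is useful for checking the estimates. Writing eta_w(x) = x1 + (-0.5 - x2) * w, the simulated event time satisfies sqrt(T) ~ Exponential(rate = exp(eta_w(x))), so P(T > t0 | X = x, W = w) = exp(-sqrt(t0) * exp(eta_w(x))). Since t0 = 0.2 is below Y.max = 2, truncation at Y.max does not affect this probability, and tau(x) = exp(-sqrt(t0) * exp(eta_1(x))) - exp(-sqrt(t0) * exp(eta_0(x))). The function name true_cate is hypothetical, and the formula applies only to this particular simulation with t0 < Y.max.

```r
# Hypothetical helper: true CATE at t0 for the simulation in the Examples,
# where sqrt(T) is exponential with rate exp(eta) and P(T > t0) = exp(-sqrt(t0) * exp(eta)).
true_cate <- function(X, t0) {
  X <- as.matrix(X)
  eta0 <- X[, 1]                     # linear predictor under control (W = 0)
  eta1 <- X[, 1] - 0.5 - X[, 2]      # linear predictor under treatment (W = 1)
  exp(-sqrt(t0) * exp(eta1)) - exp(-sqrt(t0) * exp(eta0))
}
```

Assuming predict() returns a numeric vector, a summary such as mean((cate.test - true_cate(X.test, t0))^2) would measure the accuracy of the estimated CATEs on the test covariates.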