Other types of models

In the following, we show how the counterfactuals workflow carries over to models fitted with various R packages, using a regression task as a concrete use case.

library("counterfactuals")
library("iml")
library("rpart")

Other types of models

The Predictor class of the iml package provides the necessary flexibility to cover classification and regression models fitted with diverse R packages. In the introduction vignette, we saw models fitted with the mlr3 and randomForest packages. In the following, we show extensions to regression trees fitted with the caret package, the tidymodels framework, the mlr package (a predecessor of mlr3), and directly with the rpart package. For each model, we generate counterfactuals for the 100th row of the plasma dataset of the gamlss.data package using the NICE method; the models themselves are trained on all remaining rows.

data("plasma", package = "gamlss.data")
# Row 100 is the observation we explain; the chunks below train on all other rows.
x_interest = plasma[100L, , drop = FALSE]

rpart - caret package

library("caret")
# Fit an rpart regression tree through caret on every row except the point of
# interest; tuneGrid supplies a single candidate value for cp (0.01).
treecaret = caret::train(
  retplasma ~ .,
  data = plasma[-100L, ],
  method = "rpart",
  tuneGrid = data.frame(cp = 0.01)
)
# Wrap the caret model in an iml Predictor so counterfactuals can query it.
predcaret = Predictor$new(
  model = treecaret,
  data = plasma[-100L, ],
  y = "retplasma"
)
predcaret$predict(x_interest)
#>   .prediction
#> 1    342.9231
# NICE for regression; return_multiple = FALSE asks for a single counterfactual.
nicecaret = NICERegr$new(
  predcaret,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
# Target: a predicted retplasma in the range [500, Inf].
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218

rpart - tidymodels package

library("tidymodels")
# Build a parsnip regression-tree specification (rpart engine) and fit it on
# every row except the point of interest.
treetm = fit(
  decision_tree(mode = "regression", engine = "rpart"),
  retplasma ~ .,
  data = plasma[-100L, ]
)
# Wrap the parsnip fit in an iml Predictor so counterfactuals can query it.
predtm = Predictor$new(
  model = treetm,
  data = plasma[-100L, ],
  y = "retplasma"
)
predtm$predict(x_interest)
#>      .pred
#> 1 342.9231
# NICE for regression; return_multiple = FALSE asks for a single counterfactual.
nicetm = NICERegr$new(
  predtm,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
# Target: a predicted retplasma in the range [500, Inf].
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218

rpart - mlr package

library("mlr")
#> Warning in fun(pkgname, pkgpath): Packages 'paradox' and 'ParamHelpers' are conflicting and should not be loaded in the same session
#> Warning in fun(pkgname, pkgpath): Packages 'mlr3' and 'mlr' are conflicting and should not be loaded in the same session
# Learner and task are independent objects; train() then combines them.
mod = mlr::makeLearner(cl = "regr.rpart")
task = mlr::makeRegrTask(data = plasma[-100L, ], target = "retplasma")
treemlr = mlr::train(learner = mod, task = task)

# Wrap the mlr model in an iml Predictor so counterfactuals can query it.
predmlr = Predictor$new(
  model = treemlr,
  data = plasma[-100L, ],
  y = "retplasma"
)
predmlr$predict(x_interest)
#>   .prediction
#> 1    342.9231
# NICE for regression; return_multiple = FALSE asks for a single counterfactual.
nicemlr = NICERegr$new(
  predmlr,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
# Target: a predicted retplasma in the range [500, Inf].
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218

Decision tree - rpart package

# Fit the tree with rpart directly, again leaving out the point of interest.
treerpart = rpart(formula = retplasma ~ ., data = plasma[-100L, ])
# Wrap the rpart model in an iml Predictor so counterfactuals can query it.
predrpart = Predictor$new(
  model = treerpart,
  data = plasma[-100L, ],
  y = "retplasma"
)
predrpart$predict(x_interest)
#>       pred
#> 1 342.9231
# NICE for regression; return_multiple = FALSE asks for a single counterfactual.
nicerpart = NICERegr$new(
  predrpart,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
# Target: a predicted retplasma in the range [500, Inf].
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218