This vignette is a guide to policy_learn() and some of
the associated S3 methods. The purpose of policy_learn() is
to specify a policy learning algorithm and estimate an optimal policy.
For details on the methodology, see the associated paper (Nordland and Holst 2023).
As a general setup we consider a fixed two-stage problem, simulating
data with sim_two_stage() and creating a
policy_data object with policy_data():
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
#> Policy data with n = 2000 observations and maximal K = 2 stages.
#> 
#>      action
#> stage    0    1    n
#>     1 1017  983 2000
#>     2  819 1181 2000
#> 
#> Baseline covariates: B, BB
#> State covariates: L, C
#> Average utility: 0.84
policy_learn() specifies a policy learning algorithm via
the type argument: Q-learning (ql), doubly
robust Q-learning (drql), doubly robust blip learning
(blip), policy tree learning (ptl), and
outcome weighted learning (owl).
Because each policy learning type has its own set of control arguments,
these are passed as a list via the control argument. To
help the user set the required control arguments and to provide
documentation, each type has a helper function
control_type() which sets the default control arguments and
overrides them with any values supplied by the user.
As an example we specify a doubly robust blip learner:
pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)
For details on the implementation, see Algorithm 3 in (Nordland and Holst 2023). The only required
control argument for blip learning is a model input. The
blip_models argument expects a q_model. In
this case we input a simple linear model as implemented in
q_glm.
The output of policy_learn() is again a function:
pl_blip
#> Policy learner with arguments:
#> policy_data, g_models=NULL, g_functions=NULL,
#> g_full_history=FALSE, q_models, q_full_history=FALSE
In order to apply the policy learner we need to input a
policy_data object and nuisance models
g_models and q_models for computing the doubly
robust score.
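As a sketch, the learner can be applied with simple GLM-based nuisance
models, mirroring the cross-fitted example below; the resulting policy
object is referred to as po_blip later on:
po_blip <- pl_blip(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )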
As with policy_eval(), it is possible to cross-fit the
doubly robust score used as input to the policy model. The number of
folds for the cross-fitting procedure is set via the L
argument. By default, the cross-fitted nuisance models are not saved,
but they can be saved via the
save_cross_fit_models argument:
pl_blip_cross <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  L = 2,
  save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )
From a user perspective, nothing has changed. However, the policy
object now contains each of the cross-fitted nuisance models:
po_blip_cross$g_functions_cf
#> $`1`
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    -0.18321      0.15191      0.90737     -0.03865      0.18927      0.15088  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1384 
#> Residual Deviance: 1086  AIC: 1098
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.24410      0.13150      0.99426     -0.02289     -0.41777     -0.17383  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1349 
#> Residual Deviance: 1082  AIC: 1094
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
#> 
#> $`2`
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    0.113952    -0.240397     1.142507    -0.094362    -0.009235    -0.101783  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1386 
#> Residual Deviance: 1065  AIC: 1077
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.15426      0.01307      0.96485     -0.08554     -0.33532     -0.12597  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1357 
#> Residual Deviance: 1102  AIC: 1114
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
Realistic policy learning is implemented for types ql,
drql, blip and ptl (for a binary
action set). The alpha argument sets the probability
threshold for defining the realistic action set. For implementation
details, see Algorithm 5 in (Nordland and Holst
2023). Here we set a 5% restriction:
pl_blip_alpha <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  alpha = 0.05,
  L = 2
)
po_blip_alpha <- pl_blip_alpha(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )
The policy object now lists the alpha level as well as
the g-model used to define the realistic action set:
po_blip_alpha$g_functions
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    -0.03295     -0.05107      1.02271     -0.06478      0.09582      0.02370  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1994 Residual
#> Null Deviance:       2772 
#> Residual Deviance: 2161  AIC: 2173
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.19814      0.07355      0.97991     -0.05280     -0.37163     -0.14598  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1994 Residual
#> Null Deviance:       2707 
#> Residual Deviance: 2186  AIC: 2198
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
get_policy_functions()
A policy function is great for evaluating a given policy
or even implementing or simulating from a single-stage policy. However,
the function is not useful for implementing or simulating from a learned
multi-stage policy. To access the policy function for each stage we use
get_policy_functions(). In this case we extract the second-stage
policy function:
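As a sketch, and assuming the po_blip object from above, the
second-stage policy function can be extracted as follows:
pf2 <- get_policy_functions(po_blip, stage = 2)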
The stage-specific policy function requires a data.table with
named columns as input and returns a character vector with the
recommended actions:
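A minimal illustration, assuming the returned function accepts the
history as a data.table argument H and using made-up covariate values:
pf2_actions <- pf2(H = data.table(BB = c("group1", "group2"),
                                  L = c(0.5, -1),
                                  C = c(1, -0.5)))
pf2_actions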
get_policy()
Applying the policy learner returns a policy_object
containing all of the components needed to specify the learned policy.
In this case the only component of the policy is a model for the blip
function:
po_blip$blip_functions$stage_1$blip_model
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)     BBgroup2     BBgroup3            L            C  
#>      0.4076       0.2585       0.2231       0.1765       0.8624  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1995 Residual
#> Null Deviance:       56820 
#> Residual Deviance: 53220     AIC: 12250
#> 
#> attr(,"class")
#> [1] "q_glm"To access and apply the policy itself use get_policy(),
which behaves as a policy meaning that we can apply to any
(suitable) policy_data object to get the policy
actions:
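As a sketch, the learned policy can be extracted and applied to the
training data (or any other compatible policy_data object):
head(get_policy(po_blip)(pd))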
sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: aarch64-apple-darwin23.5.0
#> Running under: macOS Sonoma 14.6.1
#> 
#> Matrix products: default
#> BLAS:   /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib 
#> LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib;  LAPACK version 3.12.0
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> time zone: Europe/Copenhagen
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] splines   stats     graphics  grDevices utils     datasets  methods  
#> [8] base     
#> 
#> other attached packages:
#> [1] ggplot2_3.5.1       data.table_1.15.4   polle_1.5          
#> [4] SuperLearner_2.0-29 gam_1.22-4          foreach_1.5.2      
#> [7] nnls_1.5           
#> 
#> loaded via a namespace (and not attached):
#>  [1] sass_0.4.9          utf8_1.2.4          future_1.33.2      
#>  [4] lattice_0.22-6      listenv_0.9.1       digest_0.6.36      
#>  [7] magrittr_2.0.3      evaluate_0.24.0     grid_4.4.1         
#> [10] iterators_1.0.14    mvtnorm_1.2-5       policytree_1.2.3   
#> [13] fastmap_1.2.0       jsonlite_1.8.8      Matrix_1.7-0       
#> [16] survival_3.6-4      fansi_1.0.6         scales_1.3.0       
#> [19] numDeriv_2016.8-1.1 codetools_0.2-20    jquerylib_0.1.4    
#> [22] lava_1.8.0          cli_3.6.3           rlang_1.1.4        
#> [25] mets_1.3.4          parallelly_1.37.1   future.apply_1.11.2
#> [28] munsell_0.5.1       withr_3.0.0         cachem_1.1.0       
#> [31] yaml_2.3.8          tools_4.4.1         parallel_4.4.1     
#> [34] colorspace_2.1-0    ranger_0.16.0       globals_0.16.3     
#> [37] vctrs_0.6.5         R6_2.5.1            lifecycle_1.0.4    
#> [40] pkgconfig_2.0.3     timereg_2.0.5       progressr_0.14.0   
#> [43] bslib_0.7.0         pillar_1.9.0        gtable_0.3.5       
#> [46] Rcpp_1.0.13         glue_1.7.0          xfun_0.45          
#> [49] tibble_3.2.1        highr_0.11          knitr_1.47         
#> [52] farver_2.1.2        htmltools_0.5.8.1   rmarkdown_2.27     
#> [55] labeling_0.4.3      compiler_4.4.1