Install devtools or remotes if not already installed:
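For example:

install.packages("remotes") ## or install.packages("devtools")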
The rCISSVAE package can be installed with:
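For example, via remotes (the GitHub path below is a placeholder; substitute the package's actual repository):

remotes::install_github("<github-user>/rCISSVAE") ## placeholder repository path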
This package uses reticulate to interface with the Python version of the package, ciss_vae. You therefore need a venv or conda environment that has the ciss_vae package installed. If you are comfortable creating an environment and installing the package yourself, great! All you need to do is tell reticulate where to point.
For venv:
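A sketch, assuming the environment lives at ./cissvae_environment:

reticulate::use_virtualenv("./cissvae_environment", required = TRUE)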
For conda:
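A sketch, assuming a conda environment named "cissvae":

reticulate::use_condaenv("cissvae", required = TRUE)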
If you do not want to manually create the virtual environment, you
can use the helper function create_cissvae_env() to create
a virtual environment (venv) in your current working directory.
create_cissvae_env(
envname = "./cissvae_environment", ## name of environment
path = NULL, ## path to wherever you want the virtual environment to be
install_python = FALSE, ## set to TRUE if you want create_cissvae_env() to install Python for you
python_version = "3.10" ## any version >= 3.10; Python 3.10 or 3.11 recommended
)

Note: If you run into issues with create_cissvae_env(), you can create the virtual environment manually by following this tutorial.
Once the environment is created, activate it using:
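For example, with the environment created by create_cissvae_env() above:

reticulate::use_virtualenv("./cissvae_environment", required = TRUE)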
If you want to install other Python packages (e.g., seaborn) into your environment, you can use reticulate::virtualenv_install().
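For example:

reticulate::virtualenv_install("./cissvae_environment", packages = "seaborn")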
check_devices() to see available GPU devices

Use check_devices() to see what CPU/GPU devices are available for model training. Optionally, pass the path to your virtual environment via the env_path parameter to make sure that the correct environment is activated.
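For example (assuming the environment path used above):

check_devices(env_path = "./cissvae_environment")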
Available Devices:
• cpu (Main system processor)
[1] "cpu"
Once reticulate is pointing to the virtual environment containing the ciss_vae Python package, you can use either the run_cissvae() function or the autotune_cissvae() function. For details on using the CISS-VAE model with binary/categorical values, see the binary variables tutorial.
If you already know what hyperparameters you want to use for the model, use the run_cissvae() function. Your data should be in DataFrame format with an optional index column. If you already have clusters you want to use, they should be in a separate vector from the dataframe. If you do not have clusters to begin with, set clusters in run_cissvae() to NULL.
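For illustration, a minimal sketch of the expected inputs (hypothetical toy data):

toy_data <- data.frame(
index = 1:4,                  ## optional index column
x1 = c(1.2, NA, 3.4, 4.1),    ## features may contain NAs to be imputed
x2 = c(NA, 2.0, 2.2, 1.8)
)
toy_clusters <- c(0, 0, 1, 1)   ## one cluster label per row, kept in a separate vector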
The rCISSVAE package includes a sample dataset with predetermined clusters, which we will use for this tutorial. The dataset, df_missing, contains an index column as well as the following:
| Characteristic | N = 8,000 |
|---|---|
| Age | 9.99 (8.73, 11.44) |
| Salary | 5.70 (5.34, 6.17) |
| ZipCode10001 | 2,628 (33%) |
| ZipCode20002 | 2,697 (34%) |
| ZipCode30003 | 2,675 (33%) |
| Y11 | 0 (-11, 7) |
| Unknown | 3,122 |
| Y12 | 49 (-22, 58) |
| Unknown | 3,118 |
| Y13 | 72 (-17, 95) |
| Unknown | 3,110 |
| Y14 | 69 (-12, 119) |
| Unknown | 3,129 |
| Y15 | 72 (-12, 134) |
| Unknown | 3,141 |
| Y21 | -9 (-22, 0) |
| Unknown | 3,135 |
| Y22 | 47 (-35, 58) |
| Unknown | 3,094 |
| Y23 | 74 (-29, 100) |
| Unknown | 3,098 |
| Y24 | 71 (-22, 128) |
| Unknown | 3,146 |
| Y25 | 73 (-22, 145) |
| Unknown | 3,106 |
| Y31 | 0 (-14, 11) |
| Unknown | 2,067 |
| Y32 | 59 (-17, 69) |
| Unknown | 2,056 |
| Y33 | 81 (-13, 101) |
| Unknown | 2,013 |
| Y34 | 79 (-6, 124) |
| Unknown | 2,051 |
| Y35 | 81 (-6, 139) |
| Unknown | 2,054 |
| Y41 | 0 (-6, 5) |
| Unknown | 2,032 |
| Y42 | 27 (-8, 32) |
| Unknown | 2,022 |
| Y43 | 37 (-5, 46) |
| Unknown | 2,013 |
| Y44 | 36 (-3, 56) |
| Unknown | 2,023 |
| Y45 | 37 (-3, 63) |
| Unknown | 2,086 |
| Y51 | 1.8 (-3.6, 5.8) |
| Unknown | 2,077 |
| Y52 | 25 (-5, 29) |
| Unknown | 2,034 |
| Y53 | 34 (-3, 41) |
| Unknown | 1,976 |
| Y54 | 33 (-1, 50) |
| Unknown | 2,047 |
| Y55 | 34 (-1, 56) |
| Unknown | 2,050 |
The Age, Salary, and ZipCode columns represent demographic data with no missingness; columns Y1t-Y5t represent biomarker data obtained at different timepoints t.
run_cissvae() Parameters

The run_cissvae() function is a comprehensive wrapper for all basic steps in running the CISS-VAE model, including dataset preparation, optional clustering, and running the imputation model.
Dataset Parameters

- data: A DataFrame containing the dataset to be imputed. Contains an optional index column.
- index_col: Name of the index column to be preserved when imputing. The index column will not have any values held out for validation.
- val_proportion: Fraction of non-missing entries to hold out for validation during training. To use different proportions for each cluster, pass a vector.
- replacement_value: Fill value for masked entries during training. Default is 0.0.
- columns_ignore: Character or integer vector of columns to exclude when selecting validation data. These columns will still be used during training.
- print_dataset: Set TRUE to print dataset summary information during processing.
Clustering Parameters (optional)

- clusters: Vector of one cluster label per row of data. If NULL, clusters are determined automatically using Leiden clustering or KMeans.
- n_clusters: Number of clusters for KMeans clustering when clusters is NULL. If n_clusters is NULL, Leiden clustering is used instead.
- leiden_resolution: Resolution parameter for Leiden clustering. Defaults to 0.5.
- k_neighbors: Number of nearest neighbors for the Leiden KNN graph construction. Defaults to 15.
- leiden_objective: Objective function for Leiden clustering. One of {"CPM", "RB", "Modularity"}. Defaults to "CPM".
- missingness_proportion_matrix: Optional pre-computed missingness proportion matrix for feature-based clustering. If provided, clustering will be based on these proportions instead of the direct 0/1 missingness pattern.
- scale_features: Set TRUE to scale features when using missingness proportion matrix clustering.
Model Parameters

- hidden_dims: A vector containing the sizes of the hidden layers in the encoder/decoder. The length of this vector determines the number of hidden layers.
- latent_dim: The dimension of the latent space representation.
- layer_order_enc: A vector stating the pattern of "shared" and "unshared" layers for the encoder. The length must match length(hidden_dims). Default c("unshared", "unshared", "unshared").
- layer_order_dec: A vector stating the pattern of "shared" and "unshared" layers for the decoder. The length must match length(hidden_dims). Default c("shared", "shared", "shared").
- latent_shared: Whether latent space weights are shared across clusters. If FALSE, each cluster gets its own latent weights.
- output_shared: If FALSE, each cluster gets its own output layer.
- batch_size: Integer. Mini-batch size for training. Larger values may improve training stability but require more memory.
- return_model: If TRUE, returns the model object. Set TRUE to use plot_vae_architecture() after running.
- epochs: Number of epochs for the initial training phase.
- initial_lr: Initial learning rate for the optimizer.
- decay_factor: Exponential decay factor for the learning rate.
- beta: Weight for the KL divergence term in the VAE loss function.
- device: Device specification for computation ("cpu" or "cuda"). If NULL, automatically selects the best available device.
- max_loops: Maximum number of impute-refit loops to perform.
- patience: Training stops if validation loss doesn't improve for this many consecutive impute-refit loops.
- epochs_per_loop: Number of epochs per refit loop. If NULL, uses the same value as epochs. Default NULL.
- decay_factor_refit: Decay factor for refit loops. If NULL, uses the same value as decay_factor. Default NULL.
- beta_refit: KL weight for refit loops. If NULL, uses the same value as beta. Default NULL.
Optional Parameters

- verbose: Set TRUE to print the MSE for each loop as it runs.
- return_silhouettes: If clusters are not given, returns silhouette scores for the automatic clustering.
- return_history: If TRUE, returns the training history as a data.frame. Good for checking for overfitting.
- return_dataset: If TRUE, returns the ClusterDataset object.
To run the imputation model, first load your data. You can use the cluster_summary() function to visualize the missingness by cluster. The cluster_summary() function builds on {gtsummary}.
library(tidyverse)
library(reticulate)
library(rCISSVAE)
library(gtsummary)
## Set correct virtualenv
reticulate::use_virtualenv("./cissvae_environment", required = TRUE)
## Load the data
data(df_missing)
data(clusters) ## actual cluster labels in clusters$clusters (other column is index)
cluster_summary(
data = df_missing,
clusters = clusters$clusters,
include = setdiff(names(df_missing), "index"),
statistic = list(
all_continuous() ~ "{mean} ({sd})",
all_categorical() ~ "{n} / {N}\n ({p}%)"),
missing = "always")| Characteristic | N | 0 N = 2,0001 |
1 N = 2,0001 |
2 N = 2,0001 |
3 N = 2,0001 |
|---|---|---|---|---|---|
| Age | 8,000 | 10.10 (2.04) | 10.19 (2.08) | 10.21 (2.14) | 10.29 (2.06) |
| Unknown | 0 | 0 | 0 | 0 | |
| Salary | 8,000 | 5.81 (0.61) | 5.83 (0.62) | 5.83 (0.61) | 5.81 (0.60) |
| Unknown | 0 | 0 | 0 | 0 | |
| ZipCode10001 | 8,000 | 646 / 2,000 (32%) | 674 / 2,000 (34%) | 663 / 2,000 (33%) | 645 / 2,000 (32%) |
| Unknown | 0 | 0 | 0 | 0 | |
| ZipCode20002 | 8,000 | 703 / 2,000 (35%) | 652 / 2,000 (33%) | 655 / 2,000 (33%) | 687 / 2,000 (34%) |
| Unknown | 0 | 0 | 0 | 0 | |
| ZipCode30003 | 8,000 | 651 / 2,000 (33%) | 674 / 2,000 (34%) | 682 / 2,000 (34%) | 668 / 2,000 (33%) |
| Unknown | 0 | 0 | 0 | 0 | |
| Y11 | 4,878 | -21 (10) | -16 (9) | 8 (5) | -3 (6) |
| Unknown | 1,281 | 1,288 | 0 | 553 | |
| Y12 | 4,882 | 69 (11) | -26 (9) | 55 (6) | -24 (8) |
| Unknown | 1,264 | 1,283 | 0 | 571 | |
| Y13 | 4,890 | 77 (12) | -25 (9) | 98 (12) | -17 (7) |
| Unknown | 1,289 | 1,264 | 0 | 557 | |
| Y14 | 4,871 | 73 (12) | -21 (8) | 125 (16) | -11 (6) |
| Unknown | 1,300 | 1,283 | 0 | 546 | |
| Y15 | 4,859 | 76 (12) | -12 (6) | 141 (19) | -14 (6) |
| Unknown | 1,273 | 1,293 | 0 | 575 | |
| Y21 | 4,865 | -33 (12) | -28 (11) | 1 (7) | -12 (7) |
| Unknown | 1,266 | 1,292 | 0 | 577 | |
| Y22 | 4,906 | 69 (12) | -40 (12) | 54 (6) | -36 (10) |
| Unknown | 1,266 | 1,276 | 0 | 552 | |
| Y23 | 4,902 | 79 (13) | -38 (11) | 104 (13) | -29 (9) |
| Unknown | 1,273 | 1,275 | 0 | 550 | |
| Y24 | 4,854 | 75 (12) | -32 (10) | 135 (18) | -22 (7) |
| Unknown | 1,302 | 1,287 | 0 | 557 | |
| Y25 | 4,894 | 78 (13) | -22 (8) | 153 (21) | -25 (8) |
| Unknown | 1,257 | 1,294 | 0 | 555 | |
| Y31 | 5,933 | -18 (10) | -13 (9) | 13 (5) | 1 (6) |
| Unknown | 192 | 1,285 | 0 | 590 | |
| Y32 | 5,944 | 74 (11) | -24 (10) | 62 (7) | -21 (8) |
| Unknown | 206 | 1,287 | 0 | 563 | |
| Y33 | 5,987 | 84 (13) | -23 (10) | 108 (13) | -14 (7) |
| Unknown | 203 | 1,267 | 0 | 543 | |
| Y34 | 5,949 | 81 (13) | -17 (8) | 136 (17) | -7 (6) |
| Unknown | 195 | 1,275 | 0 | 581 | |
| Y35 | 5,946 | 83 (13) | -8 (6) | 153 (20) | -10 (7) |
| Unknown | 204 | 1,285 | 0 | 565 | |
| Y41 | 5,968 | -8 (4) | -5 (3) | 6 (2) | 1 (2) |
| Unknown | 184 | 1,279 | 0 | 569 | |
| Y42 | 5,978 | 35 (6) | -11 (4) | 29 (4) | -9 (3) |
| Unknown | 199 | 1,282 | 0 | 541 | |
| Y43 | 5,987 | 39 (7) | -10 (3) | 49 (6) | -6 (3) |
| Unknown | 217 | 1,242 | 0 | 554 | |
| Y44 | 5,977 | 37 (7) | -8 (3) | 62 (9) | -3 (2) |
| Unknown | 186 | 1,280 | 0 | 557 | |
| Y45 | 5,914 | 39 (7) | -4 (3) | 70 (10) | -5 (2) |
| Unknown | 204 | 1,305 | 0 | 577 | |
| Y51 | 5,923 | -5.4 (3.6) | -2.9 (3.0) | 6.9 (1.9) | 2.5 (2.0) |
| Unknown | 222 | 1,279 | 0 | 576 | |
| Y52 | 5,966 | 32 (5) | -8 (3) | 26 (3) | -6 (3) |
| Unknown | 209 | 1,283 | 0 | 542 | |
| Y53 | 6,024 | 35 (6) | -6 (3) | 44 (6) | -3 (2) |
| Unknown | 184 | 1,243 | 0 | 549 | |
| Y54 | 5,953 | 34 (6) | -5 (3) | 55 (7) | -1 (2) |
| Unknown | 217 | 1,281 | 0 | 549 | |
| Y55 | 5,950 | 35 (6) | -2 (2) | 62 (9) | -2 (2) |
| Unknown | 207 | 1,292 | 0 | 551 | |
| ¹ Mean (SD); n / N (%) | | | | | |
Then, plug your data and clusters into the run_cissvae()
function.
## Run the imputation model.
dat = run_cissvae(
data = df_missing,
index_col = "index",
val_proportion = 0.1, ## pass a vector for different proportions by cluster
columns_ignore = c("Age", "Salary", "ZipCode10001", "ZipCode20002", "ZipCode30003"),
## Columns (besides the index) to ignore when selecting the validation set.
## Here we ignore the demographic columns because we do not want to hold out
## any of their values for validation.
clusters = clusters$clusters, ## we have precomputed cluster labels so we pass them here
epochs = 500,
return_silhouettes = FALSE,
return_history = TRUE, # Get detailed training history
verbose = FALSE,
return_model = TRUE, ## Allows for plotting model schematic
device = "cpu", # Explicit device selection
layer_order_enc = c("unshared", "shared", "unshared"),
layer_order_dec = c("shared", "unshared", "shared")
)
## Retrieve results
imputed_df <- dat$imputed
silhouette <- dat$silhouettes
training_history <- dat$history # Detailed training progress
## Plot training progress
if (!is.null(training_history)) {
plot(training_history$epoch, training_history$loss,
type = "l", main = "Training Loss Over Time",
xlab = "Epoch", ylab = "Loss")
}
plot_vae_architecture(model = dat$model, save_path = "test_plot_arch.png")
print(head(dat$imputed_dataset))

#> Cluster dataset:
#> ClusterDataset(n_samples=8000, n_features=30, n_clusters=4)
#> • Original missing: 61800 / 200000 (30.90%)
#> • Validation held-out: 13783 (9.97% of non-missing)
#> • .data shape: (8000, 30)
#> • .masks shape: (8000, 30)
#> • .val_data shape: (8000, 30)
#> index Age Salary ZipCode10001 ZipCode20002 ZipCode30003 Y11
#> 0 0 11.044449 6.366204 0 1 0 -4.0495372
#> 1 1 9.727260 5.912558 1 0 0 0.5461677
#> 2 2 11.383020 6.636472 0 1 0 -1.2134339
#> 3 3 13.560905 5.896255 0 0 1 -10.6082144
#> 4 4 9.542490 6.128326 1 0 0 0.3575883
#> 5 5 9.542521 6.393217 1 0 0 4.7617960
#> Y12 Y13 Y14 Y15 Y21 Y22 Y23
#> 0 -14.39339 -0.5147629 -14.369148 -17.564449 -10.736261 -18.53887 -35.77263
#> 1 -19.02272 -12.1895180 -7.722473 3.319988 -7.470250 -20.88160 -25.92436
#> 2 -19.03144 -20.3589058 -15.126495 -17.251385 -18.448421 -21.01385 -34.40086
#> 3 -22.24773 -7.1759834 -14.207619 -21.339748 -21.971752 -24.32133 -40.18794
#> 4 -16.48769 -11.3127708 7.535458 10.184475 -7.576005 -27.63498 -15.74972
#> 5 -18.96558 -12.2694435 5.667511 -9.094982 -4.120708 -29.16836 -25.07210
#> Y24 Y25 Y31 Y32 Y33 Y34 Y35
#> 0 -28.098907 -30.24259 -1.627203 -6.557133 -16.769653 -10.6946259 -13.900185
#> 1 -17.231422 -18.69529 4.355482 -9.500225 -10.927032 -5.8868866 -6.088921
#> 2 -27.250603 -28.83982 -2.169048 -10.735180 -17.193771 -10.4940796 -12.286621
#> 3 -26.304344 -33.35109 -7.478534 -13.641594 -25.310104 -0.5022125 -15.429726
#> 4 -1.223648 -18.55935 8.103012 -13.563183 -9.832379 11.2815590 -2.962196
#> 5 -4.624435 -19.35896 7.710739 -8.304111 -10.381554 7.6654510 -4.992691
#> Y41 Y42 Y43 Y44 Y45 Y51 Y52
#> 0 -0.9049605 -2.512167 3.527925 -3.694853 -5.680294 2.587588 1.84293461
#> 1 2.6245856 -4.218277 -5.776196 -1.379498 -2.329605 6.080512 0.04265404
#> 2 0.1033239 -4.733656 -7.215717 -3.350798 -6.895340 2.531148 0.31131554
#> 3 -2.8251922 -6.273870 -8.293898 -2.398297 3.204130 0.138556 -0.79631424
#> 4 3.6173897 -4.619453 -3.856878 7.120106 -1.495659 2.406390 -3.34118080
#> 5 2.1276770 -4.269014 -4.685389 5.764248 -2.990555 2.464692 -3.37204909
#> Y53 Y54 Y55
#> 0 -4.681194 -2.2484055 -2.6790791
#> 1 -2.290062 -0.8873978 0.5625324
#> 2 -5.427431 -1.3301620 -2.3243809
#> 3 -5.729860 -1.6395874 -4.4457073
#> 4 -1.915190 6.5168343 0.1034508
#> 5 -2.237415 5.5294094 -1.0737858
Before running CISS-VAE, you can cluster samples based on their missingness patterns. This groups rows whose features tend to be missing together systematically, which can improve imputation quality.
library(rCISSVAE)
data(df_missing)
cluster_result <- cluster_on_missing(
data = df_missing,
cols_ignore = c("index", "Age", "Salary", "ZipCode10001", "ZipCode20002", "ZipCode30003"),
n_clusters = 4, # Use KMeans with 4 clusters
seed = 42
)
cluster_summary(df_missing, factor(cluster_result$clusters), include = setdiff(names(df_missing), "index"),
statistic = list(
gtsummary::all_continuous() ~ "{mean} ({sd})",
gtsummary::all_categorical() ~ "{n} / {N}\n ({p}%)"),
missing = "always")
cat(paste("Clustering quality (silhouette):", round(cluster_result$silhouette, 3)))
result <- run_cissvae(
data = df_missing,
index_col = "index",
clusters = cluster_result$clusters,
return_history = TRUE,
verbose = FALSE,
device = "cpu"
)

| Characteristic | N | 0, N = 2,000¹ | 1, N = 2,000¹ | 2, N = 2,000¹ | 3, N = 2,000¹ |
|---|---|---|---|---|---|
| Age | 8,000 | 10.10 (2.04) | 10.19 (2.08) | 10.21 (2.14) | 10.29 (2.06) |
| Unknown | 0 | 0 | 0 | 0 | |
| Salary | 8,000 | 5.81 (0.61) | 5.83 (0.62) | 5.83 (0.61) | 5.81 (0.60) |
| Unknown | 0 | 0 | 0 | 0 | |
| ZipCode10001 | 8,000 | 646 / 2,000 (32%) | 674 / 2,000 (34%) | 663 / 2,000 (33%) | 645 / 2,000 (32%) |
| Unknown | 0 | 0 | 0 | 0 | |
| ZipCode20002 | 8,000 | 703 / 2,000 (35%) | 652 / 2,000 (33%) | 655 / 2,000 (33%) | 687 / 2,000 (34%) |
| Unknown | 0 | 0 | 0 | 0 | |
| ZipCode30003 | 8,000 | 651 / 2,000 (33%) | 674 / 2,000 (34%) | 682 / 2,000 (34%) | 668 / 2,000 (33%) |
| Unknown | 0 | 0 | 0 | 0 | |
| Y11 | 4,878 | -21 (10) | -16 (9) | 8 (5) | -3 (6) |
| Unknown | 1,281 | 1,288 | 0 | 553 | |
| Y12 | 4,882 | 69 (11) | -26 (9) | 55 (6) | -24 (8) |
| Unknown | 1,264 | 1,283 | 0 | 571 | |
| Y13 | 4,890 | 77 (12) | -25 (9) | 98 (12) | -17 (7) |
| Unknown | 1,289 | 1,264 | 0 | 557 | |
| Y14 | 4,871 | 73 (12) | -21 (8) | 125 (16) | -11 (6) |
| Unknown | 1,300 | 1,283 | 0 | 546 | |
| Y15 | 4,859 | 76 (12) | -12 (6) | 141 (19) | -14 (6) |
| Unknown | 1,273 | 1,293 | 0 | 575 | |
| Y21 | 4,865 | -33 (12) | -28 (11) | 1 (7) | -12 (7) |
| Unknown | 1,266 | 1,292 | 0 | 577 | |
| Y22 | 4,906 | 69 (12) | -40 (12) | 54 (6) | -36 (10) |
| Unknown | 1,266 | 1,276 | 0 | 552 | |
| Y23 | 4,902 | 79 (13) | -38 (11) | 104 (13) | -29 (9) |
| Unknown | 1,273 | 1,275 | 0 | 550 | |
| Y24 | 4,854 | 75 (12) | -32 (10) | 135 (18) | -22 (7) |
| Unknown | 1,302 | 1,287 | 0 | 557 | |
| Y25 | 4,894 | 78 (13) | -22 (8) | 153 (21) | -25 (8) |
| Unknown | 1,257 | 1,294 | 0 | 555 | |
| Y31 | 5,933 | -18 (10) | -13 (9) | 13 (5) | 1 (6) |
| Unknown | 192 | 1,285 | 0 | 590 | |
| Y32 | 5,944 | 74 (11) | -24 (10) | 62 (7) | -21 (8) |
| Unknown | 206 | 1,287 | 0 | 563 | |
| Y33 | 5,987 | 84 (13) | -23 (10) | 108 (13) | -14 (7) |
| Unknown | 203 | 1,267 | 0 | 543 | |
| Y34 | 5,949 | 81 (13) | -17 (8) | 136 (17) | -7 (6) |
| Unknown | 195 | 1,275 | 0 | 581 | |
| Y35 | 5,946 | 83 (13) | -8 (6) | 153 (20) | -10 (7) |
| Unknown | 204 | 1,285 | 0 | 565 | |
| Y41 | 5,968 | -8 (4) | -5 (3) | 6 (2) | 1 (2) |
| Unknown | 184 | 1,279 | 0 | 569 | |
| Y42 | 5,978 | 35 (6) | -11 (4) | 29 (4) | -9 (3) |
| Unknown | 199 | 1,282 | 0 | 541 | |
| Y43 | 5,987 | 39 (7) | -10 (3) | 49 (6) | -6 (3) |
| Unknown | 217 | 1,242 | 0 | 554 | |
| Y44 | 5,977 | 37 (7) | -8 (3) | 62 (9) | -3 (2) |
| Unknown | 186 | 1,280 | 0 | 557 | |
| Y45 | 5,914 | 39 (7) | -4 (3) | 70 (10) | -5 (2) |
| Unknown | 204 | 1,305 | 0 | 577 | |
| Y51 | 5,923 | -5.4 (3.6) | -2.9 (3.0) | 6.9 (1.9) | 2.5 (2.0) |
| Unknown | 222 | 1,279 | 0 | 576 | |
| Y52 | 5,966 | 32 (5) | -8 (3) | 26 (3) | -6 (3) |
| Unknown | 209 | 1,283 | 0 | 542 | |
| Y53 | 6,024 | 35 (6) | -6 (3) | 44 (6) | -3 (2) |
| Unknown | 184 | 1,243 | 0 | 549 | |
| Y54 | 5,953 | 34 (6) | -5 (3) | 55 (7) | -1 (2) |
| Unknown | 217 | 1,281 | 0 | 549 | |
| Y55 | 5,950 | 35 (6) | -2 (2) | 62 (9) | -2 (2) |
| Unknown | 207 | 1,292 | 0 | 551 | |
| ¹ Mean (SD); n / N (%) | | | | | |
#> Clustering quality (silhouette): 0.135
To create clusters based on proportion of missingness across all
timepoints for a given feature, you can provide a pre-computed
missingness proportion matrix (by using
create_missingness_prop_matrix() or manually) directly to
run_cissvae():
## Standardize df_missing column names to feature_timepoint format
colnames(df_missing) = c('index', 'Age', 'Salary', 'ZipCode10001', 'ZipCode20002', 'ZipCode30003', 'Y1_1', 'Y1_2', 'Y1_3', 'Y1_4', 'Y1_5', 'Y2_1', 'Y2_2', 'Y2_3', 'Y2_4', 'Y2_5', 'Y3_1', 'Y3_2', 'Y3_3', 'Y3_4', 'Y3_5', 'Y4_1', 'Y4_2', 'Y4_3', 'Y4_4', 'Y4_5', 'Y5_1', 'Y5_2', 'Y5_3', 'Y5_4', 'Y5_5')
# Create and examine missingness proportion matrix
prop_matrix <- create_missingness_prop_matrix(df_missing,
index_col = "index",
cols_ignore = c('Age', 'Salary', 'ZipCode10001', 'ZipCode20002', 'ZipCode30003'),
repeat_feature_names = c("Y1", "Y2", "Y3", "Y4", "Y5"))
cat("Missingness proportion matrix dimensions:\n")
cat(dim(prop_matrix), "\n")
cat("Sample of proportion matrix:\n")
print(head(prop_matrix[, 1:5]))
# Use proportion matrix with scaling for better clustering
advanced_result <- run_cissvae(
data = df_missing,
index_col = "index",
clusters = NULL, # Let function cluster using prop_matrix
columns_ignore = c('Age', 'Salary', 'ZipCode10001', 'ZipCode20002', 'ZipCode30003'),
missingness_proportion_matrix = prop_matrix,
scale_features = TRUE, # Standardize features before clustering
n_clusters = 4,
leiden_resolution = 0.1,
epochs = 5,
return_history = TRUE,
return_silhouettes = TRUE,
device = "cpu",
verbose = FALSE,
return_clusters = TRUE
)
print("Clustering quality:")
print(paste("Silhouette score:", round(advanced_result$silhouette_width, 3)))
## Plotting imputation loss by epoch
ggplot2::ggplot(data = advanced_result$training_history,
ggplot2::aes(x = epoch, y = imputation_error)) +
ggplot2::geom_point() +
ggplot2::labs(y = "Imputation Loss", x = "Epoch") +
ggplot2::theme_classic()

#> Missingness proportion matrix dimensions:
#> 8000 5
#> Sample of proportion matrix:
#> Y1 Y2 Y3 Y4 Y5
#> 1 0.4 0.4 0.2 0.4 0.2
#> 2 0.4 0.2 0.2 0.2 0.2
#> 3 0.4 0.2 0.2 0.4 0.2
#> 4 0.4 0.2 0.4 0.4 0.2
#> 5 0.4 0.4 0.2 0.2 0.4
#> 6 0.2 0.2 0.4 0.2 0.4
#> Clustering quality:
#> Silhouette score: 0.656
For hyperparameter optimization in autotune_cissvae(), parameters can be specified as:

- beta = 0.01 → the parameter remains constant across trials
- c(64, 128, 256) → Optuna selects from the provided options
- reticulate::tuple(1e-4, 1e-3) → Optuna suggests floats in the specified range

The layer arrangement strategies control how shared and unshared layers are positioned:

- "at_end": places shared layers at the end of the encoder or the start of the decoder
- "at_start": places shared layers at the start of the encoder or the end of the decoder
- "alternating": distributes shared layers evenly throughout the architecture
- "random": uses random placement of shared layers (with a reproducible seed)

library(tidyverse)
library(reticulate)
library(rCISSVAE)
reticulate::use_virtualenv("./cissvae_environment", required = TRUE)
data(df_missing)
data(clusters)
aut <- autotune_cissvae(
data = df_missing,
index_col = "index",
clusters = clusters$clusters,
save_model_path = NULL,
save_search_space_path = NULL,
n_trials = 3, ## Using low number of trials for demo
study_name = "comprehensive_vae_autotune",
device_preference = "cpu",
show_progress = FALSE, # Set true for Rich progress bars with training visualization
optuna_dashboard_db = "sqlite:///optuna_study.db", # Save results to database
load_if_exists = FALSE, ## Set true to load and continue study if it exists
seed = 42,
verbose = FALSE,
# Search strategy options
constant_layer_size = FALSE, # Allow different sizes per layer
evaluate_all_orders = FALSE, # Sample layer arrangements efficiently
max_exhaustive_orders = 100, # Limit for exhaustive search
## Hyperparameter search space
num_hidden_layers = c(2, 5), # Try 2-5 hidden layers
hidden_dims = c(64, 512), # Layer sizes from 64 to 512
latent_dim = c(10, 100), # Latent dimension range
latent_shared = c(TRUE, FALSE),
output_shared = c(TRUE, FALSE),
lr = 0.01, # Constant learning rate (pass reticulate::tuple() for a range)
decay_factor = 0.99,
beta = 0.01, # Constant KL weight (pass reticulate::tuple() for a range)
num_epochs = 500, # Fixed epochs for demo
batch_size = c(1000, 4000), # Batch size options
num_shared_encode = c(0, 1, 2, 3),
num_shared_decode = c(0, 1, 2, 3),
# Layer placement strategies - try different arrangements
encoder_shared_placement = c("at_end", "at_start", "alternating", "random"),
decoder_shared_placement = c("at_start", "at_end", "alternating", "random"),
refit_patience = 2, # Early stopping patience
refit_loops = 100, # Fixed refit loops
epochs_per_loop = 100, # Epochs per refit loop
reset_lr_refit = c(TRUE, FALSE)
)
# Analyze results
imputed <- aut$imputed
best_model <- aut$model
study <- aut$study
results <- aut$results
# View best hyperparameters
print("Trial results:")
results %>% kable() %>%
kable_styling(font_size=12)
# Plot model architecture
plot_vae_architecture(best_model, title = "Optimized CISSVAE Architecture")

#> [1] "Trial results:"
| trial_number | imputation_error | num_hidden_layers | hidden_dim_0 | hidden_dim_1 | hidden_dim_2 | hidden_dim_3 | hidden_dim_4 | latent_dim | latent_shared | output_shared | batch_size | num_shared_encode | num_shared_decode | encoder_shared_placement | decoder_shared_placement | layer_order_enc_used | layer_order_dec_used |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 29.62110 | 5 | 512 | 64 | 512 | 64 | 64 | 100 | FALSE | FALSE | 1000 | 1 | 1 | alternating | at_start | S,U,U,U,U | S,U,U,U,U |
| 1 | 33.00819 | 5 | 64 | 64 | 64 | 512 | 64 | 10 | TRUE | TRUE | 1000 | 1 | 0 | alternating | alternating | S,U,U,U,U | U,U,U,U,U |
| 2 | 58.28946 | 2 | 512 | 64 | NaN | NaN | NaN | 10 | TRUE | FALSE | 4000 | 0 | 3 | alternating | random | U,U | S,S |
For more information on using the Optuna dashboard, see the Optuna Dashboard tutorial.
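A minimal sketch for browsing the study database saved above; this assumes the optuna-dashboard Python package (not installed by default) is available and that its command-line tool is on your PATH:

## one-time setup: install the dashboard into the same virtual environment
reticulate::virtualenv_install("./cissvae_environment", packages = "optuna-dashboard")
## launch the dashboard against the study database saved by autotune_cissvae();
## it serves a local web UI for browsing trials
system("optuna-dashboard sqlite:///optuna_study.db")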