I am studying the performance of ML survival models under severe censoring.
With 50% censoring, tuned DeepSurv fails in two out of five outer folds; with 90% censoring, all five folds fail with this error:
Error in check_prediction_data.PredictionDataSurv(pdata, train_task = task) : Assertion on 'pdata$crank' failed: Contains missing values (element 1).
Interestingly, tuned DeepSurv works fine with three outer folds at 90% censoring.
This issue has been posted on SO (here and here) without reproducible data, and it was resolved in mlr3's GitHub issue "Error on missing values without missing values". There, the cause was identified as a very unbalanced dataset, which caused some cross-validation resamples not to include all of the labels. Unfortunately, I tried the first solution suggested there by @mllg (applied during task instantiation), without success.
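To check whether that kind of imbalance applies here, one can count how many events land in each outer test fold with and without the "stratum" role. The sketch below is only illustrative: it uses simulated data with roughly 10% events as a stand-in for the posted data set, and sim, count_events(), tsk_plain and tsk_strat are hypothetical names of mine, not part of mlr3.

```r
library(mlr3)
library(mlr3proba)

# Hypothetical stand-in data: 500 rows, roughly 10% events (about 90% censoring)
set.seed(1)
sim = data.frame(
  age   = round(runif(500, 30, 70)),
  time  = sample(1:30, 500, replace = TRUE),
  event = rbinom(500, 1, 0.1)
)

# Hypothetical helper: instantiate a resampling and count the events in each test fold
count_events = function(task, resampling) {
  resampling$instantiate(task)
  vapply(seq_len(resampling$iters), function(i) {
    sum(task$data(rows = resampling$test_set(i), cols = "event")$event)
  }, numeric(1))
}

# Same data, once without and once with stratification on the event indicator
tsk_plain = as_task_surv(sim, time = "time", event = "event", id = "plain")
tsk_strat = tsk_plain$clone(deep = TRUE)
tsk_strat$set_col_roles("event", c("target", "stratum"))

ev_plain = count_events(tsk_plain, rsmp("cv", folds = 5))
ev_strat = count_events(tsk_strat, rsmp("cv", folds = 5))
rbind(plain = ev_plain, stratified = ev_strat)
```

With the stratum role set, the per-fold event counts should differ by at most one, so if the real tsk_sim shows the same pattern, the outer folds themselves are probably not the problem.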
A less serious issue occurs with the survival logistic-hazard learner (surv.loghaz), which fails in a few runs with this error:
> Without pipelines:
> Error in assert_surv_matrix(x) : Survival probabilities must be (non-strictly) decreasing and between [0, 1]
>
> With pipelines:
> Error in assert_surv_matrix(x) : Survival probabilities must be (non-strictly) decreasing and between [0, 1]
> This happened PipeOp ml_ann_loghaz.tuned's $train()
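To see which predictions would trip this assertion, a base-R check that mirrors the condition stated in the error message can help. This is a minimal sketch and not mlr3proba's actual assert_surv_matrix() implementation; bad_surv_rows() is a hypothetical helper, and it assumes a numeric matrix with one row per observation and columns ordered by increasing time. How to extract that matrix from a prediction depends on the mlr3proba version, so that step is left out.

```r
# Hypothetical helper: indices of rows that contain missing values, fall outside
# [0, 1], or increase over time (i.e. are not non-strictly decreasing)
bad_surv_rows = function(surv, tol = 0) {
  out_of_range = apply(surv, 1, function(s) any(is.na(s) | s < 0 | s > 1))
  increasing   = apply(surv, 1, function(s) any(diff(s) > tol, na.rm = TRUE))
  which(out_of_range | increasing)
}

# Toy matrix: row 1 is valid, row 2 goes up between t2 and t3, row 3 is NaN
# (the kind of output a diverged network might produce)
surv = rbind(
  c(1.0, 0.9, 0.8, 0.8),
  c(1.0, 0.7, 0.75, 0.6),
  c(1.0, NaN, NaN, NaN)
)
bad_surv_rows(surv)
```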
My first suspicion is that tuning does not respect the task's "stratum" column role when the inner resampling is instantiated.
An easy workaround is to reduce the number of folds, as I did earlier (from 5 to 3), but I am hoping for a more robust solution.
Here is the sample code:
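One way to test this suspicion is to reproduce the inner split by hand: restrict the task to one outer training set, instantiate the holdout on it, and compare event proportions. This sketch assumes the AutoTuner instantiates its inner resampling on the task restricted to the outer training rows (my assumption, not checked against the mlr3tuning source). The data are again simulated with roughly 10% events, and event_share() is a hypothetical helper.

```r
library(mlr3)
library(mlr3proba)

# Hypothetical stand-in data: 500 rows, roughly 10% events
set.seed(1)
sim = data.frame(
  age   = round(runif(500, 30, 70)),
  time  = sample(1:30, 500, replace = TRUE),
  event = rbinom(500, 1, 0.1)
)
tsk = as_task_surv(sim, time = "time", event = "event", id = "sim")
tsk$set_col_roles("event", c("target", "stratum"))

# Outer 5-fold CV, stratified on the event indicator
outer = rsmp("cv", folds = 5)
outer$instantiate(tsk)

# Assumed inner step: holdout (default ratio 2/3) on the outer training rows only;
# the cloned task keeps the "stratum" role
inner_task = tsk$clone(deep = TRUE)$filter(outer$train_set(1))
inner = rsmp("holdout")
inner$instantiate(inner_task)

# Hypothetical helper: share of events among the given row ids
event_share = function(task, rows) mean(task$data(rows = rows, cols = "event")$event)

shares = c(
  outer_train = event_share(tsk, outer$train_set(1)),
  inner_train = event_share(inner_task, inner$train_set(1)),
  inner_test  = event_share(inner_task, inner$test_set(1))
)
shares
```

If the three shares are close, stratification is being carried into the inner split, and the failures are more likely caused by something other than missing labels.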
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines) # po(), %>>%
library(mlr3tuning)    # auto_tuner()
library(reticulate)
# control Python warnings through reticulate:
warnings <- import("warnings")
# ignore / suppress:
warnings$simplefilter("ignore")

# inner resampling and terminator must exist before auto_tuner() uses them
set.seed(1)
rsmp_tune_cv = rsmp("holdout")
trm_tune = trm("evals", n_evals = 100)

search_space_ann = ps(
  dropout = p_dbl(lower = 0, upper = 1),
  weight_decay = p_dbl(lower = 0, upper = 0.5),
  learning_rate = p_dbl(lower = 0, upper = 1),
  num_layers = p_int(lower = 1, upper = 4),
  num_nodes_per_layer = p_int(lower = 1, upper = 32),
  # expand (num_layers, num_nodes_per_layer) into the num_nodes vector
  .extra_trafo = function(x, param_set) {
    x$num_nodes = rep(x$num_nodes_per_layer, x$num_layers)
    x$num_layers = NULL
    x$num_nodes_per_layer = NULL
    x
  }
)

# With pipelines, each AutoTuner below was wrapped as
# as_learner(po("scale") %>>% auto_tuner(...))
at_lrn_deepsurv = auto_tuner(
  learner = lrn("surv.deepsurv",
    id = "ml_ann_deepsurv", frac = 0.3,
    optimizer = "adam", early_stopping = TRUE, epochs = 100),
  search_space = search_space_ann,
  resampling = rsmp_tune_cv,
  measure = msr("surv.cindex"),
  terminator = trm_tune,
  tuner = tnr("random_search")
)

at_lrn_loghaz = auto_tuner(
  learner = lrn("surv.loghaz",
    id = "ml_ann_loghaz", frac = 0.3,
    optimizer = "adam", early_stopping = TRUE, epochs = 100),
  search_space = search_space_ann,
  resampling = rsmp_tune_cv,
  measure = msr("surv.cindex"),
  terminator = trm_tune,
  tuner = tnr("random_search")
)

# df is created by the dput() output at the end of this post; originally:
# df <- read_csv("simulation_data/PH_90_500_linear/sim_seed_1134.csv")
tsk_sim = as_task_surv(df, time = "time", event = "event", type = "right", id = "sim1")
tsk_sim$set_col_roles("event", c("target", "stratum"))

set.seed(1)
rsmp_outer_5cv = rsmp("cv", folds = 5)
rsmp_outer_5cv$instantiate(tsk_sim)
rsmp_outer_3cv = rsmp("cv", folds = 3)
rsmp_outer_3cv$instantiate(tsk_sim)
# Error
set.seed(1)
at_lrn_deepsurv$train(tsk_sim, row_ids = rsmp_outer_5cv$train_set(1))
#> INFO [16:24:38.892] [bbotk] Starting to optimize 5 parameter(s) with '<OptimizerBatchRandomSearch>' and '<TerminatorEvals> [n_evals=100, k=0]'
#> INFO [16:24:38.977] [bbotk] Evaluating 1 configuration(s)
#> INFO [16:24:38.994] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [16:24:39.008] [mlr3] Applying learner 'ml_ann_deepsurv' on task 'sim1' (iter 1/1)
#> ...
#> INFO [16:24:40.759] [mlr3] Finished benchmark
#> INFO [16:24:40.812] [bbotk] Result of batch 7:
#> INFO [16:24:40.818] [bbotk] dropout weight_decay learning_rate num_layers num_nodes_per_layer surv.cindex warnings
#> INFO [16:24:40.818] [bbotk] 0.6097489 0.1321245 0.4230986 2 31 0.6007143 0
#> INFO [16:24:40.818] [bbotk] errors runtime_learners uhash
#> INFO [16:24:40.818] [bbotk] 0 0.113 544d865f-4801-488d-83e5-2e30ce55b80c
#> INFO [16:24:40.850] [bbotk] Evaluating 1 configuration(s)
#> INFO [16:24:40.869] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [16:24:40.882] [mlr3] Applying learner 'ml_ann_deepsurv' on task 'sim1' (iter 1/1)
#> Error in check_prediction_data.PredictionDataSurv(pdata, train_task = task) :
#>   Assertion on 'pdata$crank' failed: Contains missing values (element 1).

# No error
set.seed(1)
lrn("surv.deepsurv")$train(tsk_sim, row_ids = rsmp_outer_5cv$train_set(1))

# No error
set.seed(1)
at_lrn_deepsurv$train(tsk_sim, row_ids = rsmp_outer_3cv$train_set(1))

# Error after 2 to 3 trials
set.seed(1)
at_lrn_loghaz$train(tsk_sim, row_ids = rsmp_outer_5cv$train_set(1))
#> Error in assert_surv_matrix(x) :
#>   Survival probabilities must be (non-strictly) decreasing and between [0, 1]
Here is the data (500 rows, in dput() format):
df <- structure(list(age = c(64, 47, 52, 70, 62, 59, 38, 63, 61, 51,
57, 63, 61, 69, 66, 69, 52, 66, 40, 66, 56, 66, 52, 60, 54, 68,
52, 60, 57, 54, 43, 67, 54, 60, 62, 68, 66, 61, 61, 58, 61, 54,
68, 66, 59, 57, 58, 67, 50, 57, 67, 60, 59, 44, 48, 47, 52, 66,
64, 65, 62, 68, 45, 65, 53, 61, 55, 60, 64, 66, 59, 59, 48, 61,
55, 54, 62, 53, 59, 57, 57, 52, 70, 64, 57, 62, 57, 55, 63, 63,
51, 54, 63, 45, 63, 57, 66, 63, 65, 54, 59, 64, 53, 57, 61, 64,
61, 44, 53, 64, 69, 69, 64, 55, 69, 56, 49, 40, 58, 66, 66, 62,
64, 61, 66, 58, 63, 60, 50, 69, 57, 39, 49, 59, 54, 63, 57, 35,
61, 55, 41, 52, 67, 58, 60, 50, 60, 62, 69, 70, 62, 53, 51, 57,
57, 62, 64, 65, 57, 59, 31, 60, 69, 64, 44, 53, 63, 67, 62, 55,
70, 63, 51, 37, 50, 60, 62, 52, 62, 57, 58, 51, 46, 58, 48, 57,
49, 57, 59, 60, 54, 54, 49, 49, 56, 67, 50, 52, 59, 58, 59, 59,
67, 53, 68, 59, 65, 55, 49, 51, 51, 64, 47, 65, 48, 58, 38, 68,
63, 59, 65, 45, 69, 61, 51, 49, 43, 63, 39, 51, 62, 59, 58, 67,
57, 62, 64, 59, 60, 67, 54, 38, 52, 48, 54, 55, 57, 50, 65, 65,
61, 68, 56, 57, 67, 64, 64, 56, 59, 63, 55, 65, 63, 60, 57, 68,
59, 61, 40, 50, 66, 56, 55, 55, 56, 62, 49, 56, 54, 62, 57, 62,
52, 59, 70, 55, 64, 65, 57, 66, 61, 64, 55, 66, 56, 67, 68, 60,
60, 50, 68, 54, 65, 61, 51, 57, 69, 63, 56, 54, 58, 70, 32, 64,
56, 56, 67, 58, 66, 65, 65, 67, 64, 55, 55, 63, 62, 67, 42, 58,
64, 63, 62, 58, 58, 59, 48, 62, 55, 55, 55, 49, 52, 56, 54, 64,
68, 46, 65, 69, 58, 59, 69, 62, 58, 63, 65, 61, 50, 49, 52, 59,
59, 53, 48, 45, 63, 58, 62, 66, 57, 45, 52, 49, 64, 69, 68, 60,
56, 65, 42, 49, 66, 59, 40, 68, 62, 69, 53, 49, 53, 58, 51, 66,
68, 62, 64, 56, 54, 51, 65, 58, 49, 60, 63, 69, 63, 68, 60, 55,
67, 62, 57, 55, 67, 69, 67, 50, 55, 63, 57, 60, 60, 56, 47, 65,
60, 63, 49, 61, 49, 69, 62, 51, 54, 67, 70, 55, 69, 62, 59, 60,
69, 69, 51, 60, 59, 42, 65, 63, 66, 59, 40, 62, 61, 70, 59, 62,
67, 64, 53, 68, 65, 60, 55, 46, 52, 65, 45, 60, 67, 47, 58, 53,
62, 52, 53, 56, 50, 59, 57, 63, 59, 68, 60, 56, 68, 60, 56, 49,
58, 58, 58, 67, 60, 49, 65, 67, 64, 58), gender = c(0, 0, 1,
0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1,
0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1,
0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1,
0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,
0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0,
1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1,
1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,
1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0,
0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,
0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0,
1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,
0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1,
0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1,
0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1,
1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1,
1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0), event = c(0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0), time = c(2, 15, 12,
1, 2, 9, 21, 1, 6, 4, 1, 2, 7, 4, 10, 2, 1, 4, 11, 1, 7, 4, 7,
11, 2, 1, 4, 9, 5, 4, 17, 6, 13, 6, 4, 5, 1, 7, 8, 6, 1, 10,
1, 7, 8, 4, 14, 4, 6, 14, 2, 7, 1, 13, 13, 5, 3, 1, 5, 4, 3,
4, 25, 1, 7, 7, 6, 2, 2, 2, 10, 2, 9, 6, 4, 7, 6, 4, 11, 11,
1, 9, 6, 1, 3, 15, 13, 3, 2, 5, 17, 2, 3, 7, 3, 1, 3, 2, 7, 3,
8, 5, 4, 1, 7, 3, 2, 3, 17, 5, 6, 2, 9, 1, 5, 7, 3, 1, 6, 6,
1, 5, 6, 3, 6, 2, 5, 2, 17, 1, 7, 18, 9, 5, 8, 6, 5, 4, 4, 5,
10, 1, 4, 12, 5, 5, 5, 7, 10, 1, 4, 2, 7, 13, 1, 6, 1, 1, 2,
2, 7, 1, 1, 1, 6, 4, 7, 1, 1, 6, 3, 2, 5, 24, 8, 5, 8, 18, 2,
4, 15, 10, 3, 3, 2, 11, 4, 7, 7, 2, 13, 8, 29, 9, 4, 1, 27, 17,
1, 8, 5, 2, 1, 10, 3, 1, 12, 8, 3, 9, 3, 8, 2, 3, 15, 5, 2, 3,
4, 4, 2, 6, 6, 6, 8, 8, 16, 1, 19, 4, 1, 5, 3, 4, 6, 1, 3, 4,
2, 2, 2, 5, 6, 4, 10, 12, 2, 16, 9, 1, 1, 5, 1, 5, 4, 2, 4, 5,
5, 7, 2, 2, 1, 4, 6, 1, 8, 10, 4, 6, 4, 7, 7, 13, 4, 11, 1, 7,
5, 9, 2, 8, 9, 13, 14, 1, 8, 7, 3, 3, 10, 1, 3, 4, 8, 5, 5, 3,
8, 8, 2, 18, 1, 1, 2, 1, 5, 2, 7, 4, 1, 11, 29, 2, 3, 12, 3,
6, 2, 3, 1, 1, 4, 2, 19, 10, 2, 2, 8, 5, 5, 1, 4, 7, 4, 4, 8,
5, 4, 7, 7, 1, 8, 5, 2, 8, 2, 2, 11, 1, 6, 2, 1, 9, 2, 6, 2,
12, 7, 6, 3, 14, 3, 6, 1, 30, 1, 9, 4, 2, 11, 6, 2, 8, 10, 2,
7, 4, 10, 7, 12, 13, 1, 7, 9, 6, 6, 1, 2, 1, 14, 6, 7, 2, 2,
2, 3, 3, 6, 2, 1, 9, 15, 5, 3, 1, 1, 1, 7, 4, 6, 2, 2, 9, 2,
1, 5, 8, 12, 9, 4, 1, 12, 10, 6, 5, 7, 1, 16, 1, 7, 4, 1, 5,
5, 6, 1, 1, 6, 17, 4, 1, 3, 4, 19, 4, 6, 13, 8, 3, 1, 20, 18,
10, 7, 1, 7, 10, 2, 7, 15, 4, 3, 1, 13, 16, 7, 7, 4, 2, 3, 14,
6, 3, 8, 1, 3, 5, 7, 3, 2, 9, 8, 6, 3, 3, 3, 2, 11, 8, 16, 7,
9, 3, 2, 8, 3, 4, 1, 18)), row.names = c(NA, -500L), class = c("tbl_df",
"tbl", "data.frame"))