What I do:
I analyse biomarkers from EEG data using several machine learning algorithms and several pre-processing steps. This results in several models for each combination of pre-processing step and algorithm.
Each model is trained using StratifiedGroupKFold with a total of 6 folds.
The model fitted in each fold is saved as a .joblib file.
The biomarkers:
Each band of the EEG signal has a number of biomarkers. These biomarkers in turn consist of all signals from all electrodes of the EEG. A biomarker therefore consists of several features, which must not be separated (each biomarker must contain all electrode data).
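To make the structure concrete, here is a minimal sketch of how the feature columns map onto biomarkers. The column names and the `<band>_<biomarker>_<electrode>` pattern are hypothetical and only for illustration; my real naming scheme may differ, but the idea is that everything before the electrode suffix identifies the biomarker.

```python
from collections import defaultdict

# Hypothetical column names of the form <band>_<biomarker>_<electrode>
columns = [
    "alpha_power_Fz", "alpha_power_Cz",
    "theta_entropy_Fz", "theta_entropy_Cz",
]

def group_by_biomarker(columns):
    """Map each biomarker name to the list of its per-electrode columns."""
    groups = defaultdict(list)
    for name in columns:
        # Split off the electrode suffix; the remainder names the biomarker
        biomarker, _electrode = name.rsplit("_", 1)
        groups[biomarker].append(name)
    return dict(groups)

biomarker_groups = group_by_biomarker(columns)
```

Dropping a biomarker then means dropping every column in its group at once.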
What I would like to do:
In my first approach, I trained each model with all biomarkers. I would now like to use a feature-importance measure to find out whether I can omit some of them.
To do this, I would like to look at each pre-processing step and each model.
I was recommended SHAP, but my problem is that I don’t know how to summarise the folds and the channels of each biomarker.
Here is my first attempt to summarise at least the folds (for this example I am only using the folds of one model):
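    import joblib
    import numpy as np
    import pandas as pd
    import shap

    # X, y, groups, sgkf, fold_files and columns are defined earlier
    all_shap_values = []
    for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
        X_test, y_test = X.iloc[test_index], y.iloc[test_index]
        # Load the model fitted in this fold
        fold_file = fold_files[i]
        clf = joblib.load(fold_file)
        # SHAP explainer, with this fold's test data as background
        explainer = shap.LinearExplainer(clf, X_test)
        shap_values = explainer.shap_values(X_test)
        sv = explainer(X_test)  # Explanation object, needed for the SHAP plots
        all_shap_values.append(shap_values)

    # Assuming each fold yields one (n_samples, n_features) array, stack the
    # folds directly; indexing sv[1] would keep only a single row per fold.
    shap_values_stacked = np.vstack(all_shap_values)
    shap_values_mean = np.abs(shap_values_stacked).mean(0)
    importance_df = pd.DataFrame({
        "feature": columns,
        "shap_values": shap_values_mean
    })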
I first tried explainer.shap_values because that seemed the easiest way. But then I can’t plot the result; for plotting I need sv = explainer(X).
My question has two parts:
- How can I summarise the folds? (Average the values?)
- How can I group the channels of each biomarker? Can I add the values, or will that distort the result?
(The biomarkers are named in such a way that I can easily identify the channels.)
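To show what I mean by both parts, here is a minimal sketch on random stand-in data. Everything in it is an assumption for illustration: the fold sizes, the column-to-biomarker index mapping in `groups`, and the two candidate aggregations. It is not meant as the correct method, only as the options I am weighing.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for per-fold SHAP arrays, each of shape (n_test_samples, n_features)
n_features = 4
fold_shap = [rng.normal(size=(n, n_features)) for n in (5, 7, 6)]

# Hypothetical mapping: biomarker name -> column indices of its channels
groups = {"alpha_power": [0, 1], "beta_power": [2, 3]}

# Part 1: folds. The test sets are disjoint, so stacking gives one row per sample.
pooled = np.vstack(fold_shap)
pooled_mean = np.abs(pooled).mean(axis=0)  # every sample weighted equally
per_fold_mean = np.mean([np.abs(a).mean(axis=0) for a in fold_shap], axis=0)  # every fold weighted equally

# Part 2a: add the signed values of a biomarker's channels per sample,
# then take the mean absolute value of that sum.
group_importance = {
    name: float(np.abs(pooled[:, idx].sum(axis=1)).mean())
    for name, idx in groups.items()
}

# Part 2b: add the per-channel mean |SHAP| values instead.
summed_channel_importance = {
    name: float(pooled_mean[idx].sum()) for name, idx in groups.items()
}
```

The two fold summaries differ only when the folds have different sizes. For the channels, variant 2a can never exceed variant 2b (triangle inequality), because opposite-signed channel contributions cancel in 2a but not in 2b, and 2b also grows with the number of channels in a group. I am unsure which of the two is the less distorting choice.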
Thanks in advance!