I'm using 10-fold cross-validation (grouped by chromosome) to predict binary labels (y) from embedding inputs (X). I want to save one of the fitted models, ideally the one with the highest ROC AUC, but I'm not sure how: the per-fold ROC AUCs are never stored, so I don't know which fold's model to grab. My code is below, followed by a sketch of what I'm imagining.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold
from sklearn.metrics import roc_curve, precision_recall_curve, auc, accuracy_score

# Features, labels, and group assignments (folds are grouped by chromosome)
X = np.array([np.array(x) for x in df['embeddings'].values])
y = df['label'].values
groups = df['chromosome'].values

group_kfold = GroupKFold(n_splits=n_folds)
# Initialize figure for plotting
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
all_fpr = []
all_tpr = []
all_accuracy = []
all_pr_auc = []
# Perform cross-validation and plot ROC and PR curves for each fold
for i, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):
    X_train_fold, X_val_fold = X[train_idx], X[val_idx]
    y_train_fold, y_val_fold = y[train_idx], y[val_idx]

    # Initialize classifier
    rf_classifier = RandomForestClassifier(n_estimators=n_trees, random_state=42,
                                           max_depth=max_depth, n_jobs=-1)

    # Train the classifier on this fold
    rf_classifier.fit(X_train_fold, y_train_fold)

    # Predict probabilities of the positive class on the validation set
    y_pred_proba = rf_classifier.predict_proba(X_val_fold)[:, 1]

    # Calculate ROC curve and its AUC
    fpr, tpr, _ = roc_curve(y_val_fold, y_pred_proba)
    all_fpr.append(fpr)
    all_tpr.append(tpr)
    roc_auc = auc(fpr, tpr)

    # Plot ROC curve for this fold
    axes[0].plot(fpr, tpr, lw=1, alpha=0.7, label=f'ROC Fold {i+1} (AUC = {roc_auc:.2f})')

    # Calculate precision-recall curve and its AUC
    precision, recall, _ = precision_recall_curve(y_val_fold, y_pred_proba)
    pr_auc = auc(recall, precision)
    all_pr_auc.append(pr_auc)

    # Plot PR curve for this fold
    axes[1].plot(recall, precision, lw=1, alpha=0.7, label=f'PR Curve Fold {i+1} (AUC = {pr_auc:.2f})')

    # Calculate accuracy on the validation set
    accuracy = accuracy_score(y_val_fold, rf_classifier.predict(X_val_fold))
    all_accuracy.append(accuracy)
# Initialize an empty list to store interpolated TPR values
interpolated_tpr = []

# Define a common grid of FPR values to evaluate every fold's ROC curve on
mean_fpr = np.linspace(0, 1, 100)

# Interpolate each fold's TPR values onto the common FPR grid
for fpr, tpr in zip(all_fpr, all_tpr):
    interpolated_tpr.append(np.interp(mean_fpr, fpr, tpr))
# Calculate the mean and standard deviation of interpolated TPR values
mean_tpr = np.mean(interpolated_tpr, axis=0)
std_tpr = np.std(interpolated_tpr, axis=0)
# Plot the mean ROC curve with shaded area representing the standard deviation
axes[0].plot(mean_fpr, mean_tpr, color='black', linestyle='--', lw=2, label=f'Mean ROC curve (AUC = {auc(mean_fpr, mean_tpr):.2f})')
axes[0].fill_between(mean_fpr, mean_tpr - std_tpr, mean_tpr + std_tpr, color='grey', alpha=0.2)
# Plot ROC for random classifier
axes[0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=0.8)
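What I'm imagining for saving the best model is something like the sketch below: track the highest fold ROC AUC while cross-validating, keep a reference to that fold's fitted classifier, and persist it at the end. I'm assuming joblib for persistence (I believe that's the usual way to save sklearn models), and the filename best_rf_model.joblib is just a placeholder. The sketch re-runs the fold loop for clarity; in practice the tracking would go inside the loop above, right after roc_auc = auc(fpr, tpr). Is this a reasonable approach?

import joblib  # assumed here just for model persistence

best_roc_auc = -np.inf
best_model = None

for train_idx, val_idx in group_kfold.split(X, y, groups):
    clf = RandomForestClassifier(n_estimators=n_trees, random_state=42,
                                 max_depth=max_depth, n_jobs=-1)
    clf.fit(X[train_idx], y[train_idx])

    # Score this fold's model by ROC AUC on its validation split
    proba = clf.predict_proba(X[val_idx])[:, 1]
    fpr, tpr, _ = roc_curve(y[val_idx], proba)
    fold_auc = auc(fpr, tpr)

    # Keep a reference to the best-scoring fitted model so far;
    # a fresh classifier is created each fold, so the reference stays valid
    if fold_auc > best_roc_auc:
        best_roc_auc = fold_auc
        best_model = clf

# Persist the winning model (placeholder filename)
joblib.dump(best_model, 'best_rf_model.joblib')

Loading it back later would then be rf = joblib.load('best_rf_model.joblib'), if I understand joblib correctly.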