Context
Suppose I want to predict sepal length in the Iris dataset using CatBoost.
Objective
My main objective is to understand how each value of the categorical feature target affects the fitted model.
For that I am using SHAP to inspect each feature's contribution to the prediction.
from catboost import CatBoostRegressor
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
# Load Iris dataset
iris = load_iris()
data = pd.DataFrame(data=iris.data, columns=iris.feature_names)
data['target'] = iris.target
# Set the target to sepal length
X = data.drop(columns=['sepal length (cm)'])
y = data['sepal length (cm)']
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize and train the CatBoostRegressor, treating 'target' as a categorical feature
cat_features = ['target']
model = CatBoostRegressor(iterations=100, learning_rate=0.1, depth=6, verbose=False)
model.fit(X_train, y_train, cat_features=cat_features)
Now I use SHAP to explain the model:
import shap
shap.initjs()
# Calculate SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# Visualize SHAP values
shap.summary_plot(shap_values, X, feature_names=X.columns)
But I am only interested in the effect of one specific value of the target variable (for example the value 0, which corresponds to setosa).
For that I filter the test set down to the observations where target is 0.
# Filter the dataset for target = 0
X_test_target_0 = X_test[X_test['target'] == 0]
# Get the corresponding SHAP values
shap_values_target_0 = explainer.shap_values(X_test_target_0)
# Summary plot of SHAP values for all target = 0 observations
shap.summary_plot(shap_values_target_0, X_test_target_0, feature_names=X.columns)
# SHAP values for the observation at position 1 among the target = 0 rows
shap_values_instance = shap_values_target_0[1]
# Force plot for that single observation
shap.force_plot(explainer.expected_value, shap_values_instance, X_test_target_0.iloc[1], feature_names=X.columns, matplotlib=True)
Question to solve
However, when I check the SHAP values of target for these observations, they differ a lot from one row to the next, whereas I assumed they would all be equal.
print(X.columns)
target_index = list(X.columns).index('target')
# SHAP values of the 'target' feature for every target = 0 observation
print(shap_values_target_0[:, target_index])
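To quantify how much the values differ, something like the following sketch could summarize one feature's SHAP column. The helper name `summarize_feature_shap` and the demo numbers are hypothetical placeholders, not output from the model above; in practice `shap_values_target_0` and `X.columns` would be passed in instead.

```python
import numpy as np

def summarize_feature_shap(shap_matrix, feature_names, feature):
    """Return min, max, mean and std of one feature's SHAP column."""
    col = list(feature_names).index(feature)
    values = np.asarray(shap_matrix, dtype=float)[:, col]
    return {
        "min": float(values.min()),
        "max": float(values.max()),
        "mean": float(values.mean()),
        "std": float(values.std()),
    }

# Hypothetical placeholder numbers standing in for shap_values_target_0
demo_shap = [
    [0.10, -0.30, 0.05, -0.40],
    [0.12, -0.25, 0.02, -0.55],
    [0.08, -0.35, 0.04, -0.31],
]
demo_columns = ["sepal width (cm)", "petal length (cm)", "petal width (cm)", "target"]

summary = summarize_feature_shap(demo_shap, demo_columns, "target")
print(summary)
```

A non-zero `std` here would confirm that rows sharing the same category still receive different contributions.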
Why are they different? My assumption was that they should be the same.
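To make my assumption concrete, here is a minimal sketch that computes exact Shapley values by brute force. Everything in it is an assumption for illustration: `toy_model`, its coefficients, and the four background rows are made up, and averaging over a background set is a simplification that may not match what `TreeExplainer` does for CatBoost. In this toy case, two rows with the same species value get different attributions for species as soon as the model contains an interaction term.

```python
from itertools import permutations
from statistics import mean

def toy_model(row):
    # Hypothetical model: the species effect depends on petal length
    species, petal = row
    return 5.0 + 0.5 * petal - 0.4 * species * petal

def coalition_value(row, known, background):
    # Average prediction when only the features in `known` take the row's
    # values and the remaining ones come from each background row
    return mean(
        toy_model(tuple(row[i] if i in known else ref[i] for i in range(len(row))))
        for ref in background
    )

def shapley_values(row, background):
    # Exact Shapley values: average marginal contribution over all orderings
    n = len(row)
    phi = [0.0] * n
    orders = list(permutations(range(n)))
    for order in orders:
        known = set()
        for i in order:
            before = coalition_value(row, known, background)
            known.add(i)
            after = coalition_value(row, known, background)
            phi[i] += (after - before) / len(orders)
    return phi

# Made-up background data: (species flag, petal length)
background = [(1, 1.0), (1, 2.0), (0, 4.0), (0, 5.0)]
row_a = (1, 1.0)  # same species value ...
row_b = (1, 2.0)  # ... but a different petal length

phi_a = shapley_values(row_a, background)
phi_b = shapley_values(row_b, background)
print("species attribution, row a:", phi_a[0])
print("species attribution, row b:", phi_b[0])
```

If this sketch is representative, the attribution for a category would only be constant across rows when the model is purely additive in that feature, but I am not sure whether the same reasoning applies to my CatBoost model.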
Thank you in advance for the answer! 🙂