|
import shap
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
|
|
class ShapExplainer:
    """Wrap a tree model with a SHAP ``TreeExplainer`` and explanation helpers.

    ``model`` may be a bare estimator or a wrapper object. When the wrapper
    exposes a ``model`` attribute, that attribute is what gets explained;
    when it exposes a non-None ``preprocessor`` attribute, inputs are passed
    through ``preprocessor.transform`` before SHAP values are computed.
    """

    # Maximum number of features mentioned in a textual explanation.
    _TOP_FEATURES = 5
    # Absolute SHAP value a feature must exceed to be mentioned.
    _MIN_IMPORTANCE = 0.01

    def __init__(self, model, feature_names=None):
        """Store the model; the SHAP explainer itself is created by ``fit``.

        Args:
            model: Estimator or wrapper object (see class docstring).
            feature_names: Optional sequence of names, one per SHAP column.
                If omitted, ``model.feature_names`` is used lazily by
                ``generate_explanation`` when that attribute exists.
        """
        self.model = model
        self.explainer = None
        self.feature_names = feature_names

    def _preprocess(self, X):
        """Apply the model's preprocessor (if any) and densify sparse output."""
        preprocessor = getattr(self.model, 'preprocessor', None)
        if preprocessor is None:
            return X

        X_processed = preprocessor.transform(X)
        # Sparse matrices expose ``toarray``; convert them to a dense array.
        if hasattr(X_processed, "toarray"):
            X_processed = X_processed.toarray()
        return X_processed

    @staticmethod
    def _positive_class(shap_values):
        """Reduce per-class SHAP output to the values for class index 1.

        Handles the list-per-class layout and a 3-D array whose last axis is
        assumed to index the classes (NOTE(review): confirm against the SHAP
        version in use). Any other value is returned unchanged.
        """
        if isinstance(shap_values, list):
            return shap_values[1]
        if isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
            # A single-output model may still carry a trailing axis of size 1.
            class_index = 1 if shap_values.shape[-1] > 1 else 0
            return shap_values[..., class_index]
        return shap_values

    def fit(self, X_background):
        """Create the ``TreeExplainer`` for the wrapped model.

        Args:
            X_background: Accepted for interface compatibility. The explainer
                is constructed from the model alone; this data was previously
                preprocessed and then discarded, so that dead work is gone.
        """
        estimator = getattr(self.model, 'model', self.model)
        self.explainer = shap.TreeExplainer(estimator)

    def explain_instance(self, instance):
        """Return SHAP values for ``instance``.

        Args:
            instance: Input rows in the form the preprocessor (or, without a
                preprocessor, the explainer) accepts.

        Returns:
            The explainer's SHAP values, reduced to class index 1 when the
            explainer reports one set of values per class.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.explainer is None:
            raise RuntimeError(
                "ShapExplainer is not fitted; call fit() before requesting explanations"
            )

        instance_processed = self._preprocess(instance)
        shap_values = self.explainer.shap_values(instance_processed)
        return self._positive_class(shap_values)

    def generate_explanation(self, instance, original_features=None):
        """Build a short textual explanation from the largest SHAP values.

        Args:
            instance: Input passed to ``explain_instance``; only its first
                row of SHAP values is described.
            original_features: Mapping-like object indexed by feature name
                whose items expose ``.values`` (e.g. a one-row DataFrame);
                supplies the raw value shown for each feature.

        Returns:
            Clauses joined by ``" and "`` for up to five features whose
            absolute SHAP value exceeds 0.01, or a fixed message when
            ``original_features`` is None.

        Raises:
            ValueError: If no feature names are available.
            RuntimeError: If ``fit`` has not been called yet.
        """
        # Lazily adopt the model's feature names when none were supplied.
        if self.feature_names is None and hasattr(self.model, 'feature_names'):
            self.feature_names = self.model.feature_names

        # Nothing to describe without raw values, so skip the SHAP computation.
        if original_features is None:
            return "Explanation requires original feature values"

        if self.feature_names is None:
            raise ValueError(
                "feature_names must be provided to generate an explanation"
            )

        row = self.explain_instance(instance)[0]

        # Keep the signed contribution next to its name so the direction never
        # has to be looked up again by name (wrong for duplicated names).
        contributions = [(name, row[i]) for i, name in enumerate(self.feature_names)]
        contributions.sort(key=lambda item: abs(item[1]), reverse=True)

        explanation = []
        for feature, contribution in contributions[:self._TOP_FEATURES]:
            if abs(contribution) > self._MIN_IMPORTANCE:
                value = original_features[feature].values[0]
                direction = "increased" if contribution > 0 else "decreased"
                explanation.append(f"{feature} = {value} {direction} the likelihood")

        return " and ".join(explanation)

    def plot_force(self, instance, matplotlib=True):
        """Draw a SHAP force plot for ``instance``.

        Args:
            instance: Input passed to ``explain_instance`` and shown as the
                feature values in the plot.
            matplotlib: When true, render through matplotlib and return None;
                otherwise return whatever ``shap.force_plot`` returns.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        shap_values = self.explain_instance(instance)

        # When the explainer reports one base value per class, use the one for
        # class index 1 so it matches the SHAP values selected above.
        base_value = self.explainer.expected_value
        if isinstance(base_value, (list, tuple, np.ndarray)) and np.size(base_value) > 1:
            base_value = base_value[1]

        if matplotlib:
            shap.force_plot(base_value,
                            shap_values,
                            instance,
                            feature_names=self.feature_names,
                            matplotlib=True)
            return None

        return shap.force_plot(base_value,
                               shap_values,
                               instance,
                               feature_names=self.feature_names)