From sciagent-skills
Explains ML predictions using SHAP (Shapley values) with Tree/Deep/Linear/Kernel explainers and plots (beeswarm, waterfall, bar). Debug models, rank features, audit fairness.
SHAP (SHapley Additive exPlanations) is a unified framework for explaining machine learning model predictions using Shapley values from cooperative game theory. It quantifies each feature's contribution to individual predictions and provides both local (per-instance) and global (dataset-level) explanations with theoretical guarantees of consistency and additivity.
```bash
pip install shap matplotlib
# Optional: xgboost lightgbm tensorflow torch (depending on model)
```
```python
import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split

# Load example data
X, y = shap.datasets.adult()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = xgb.XGBClassifier(n_estimators=100).fit(X_train, y_train)

# Explain: select explainer → compute → visualize
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
shap.plots.beeswarm(shap_values)  # Global importance
shap.plots.waterfall(shap_values[0])  # Single prediction
print(f"Base value: {shap_values.base_values[0]:.3f}")
print(f"SHAP values shape: {shap_values.values.shape}")  # (n_samples, n_features)
```
Choose an explainer based on model type:
| Model Type | Explainer | Speed | Exactness |
|---|---|---|---|
| Tree-based (XGBoost, LightGBM, RF, CatBoost) | TreeExplainer | Fast | Exact |
| Linear (LogReg, GLM, Ridge) | LinearExplainer | Instant | Exact |
| Deep learning (TensorFlow, PyTorch) | DeepExplainer | Fast | Approximate |
| Deep learning (gradient-based) | GradientExplainer | Fast | Approximate |
| Any model (black-box) | KernelExplainer | Slow | Approximate |
| Any model (permutation-based) | PermutationExplainer | Very slow | Approximate (additivity is exact) |
| Unsure? | shap.Explainer | Auto | Auto |
```python
# Tree-based models (most common)
explainer = shap.TreeExplainer(model)

# Linear models
explainer = shap.LinearExplainer(model, X_train)

# Deep learning (background must be in the framework's input format, e.g. a tensor)
explainer = shap.DeepExplainer(model, X_train[:100])

# Any model (model-agnostic, slower)
explainer = shap.KernelExplainer(model.predict, shap.kmeans(X_train, 50))

# Auto-select
explainer = shap.Explainer(model, X_train)
```
```python
shap_values = explainer(X_test)

# shap_values object contains:
#   .values: SHAP values array (n_samples, n_features)
#   .base_values: expected model output (baseline)
#   .data: original feature values

# Verify additivity: raw model output = base_value + sum(SHAP values)
# (for an XGBClassifier explained with tree explainer defaults, the raw output is log-odds)
print(f"{shap_values.base_values[0]:.3f} + {shap_values.values[0].sum():.3f} = "
      f"{shap_values.base_values[0] + shap_values.values[0].sum():.3f}")
```
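When the explained output is a log-odds margin (as for a binary XGBClassifier under TreeExplainer defaults), base_value + sum(SHAP values) has to be passed through the sigmoid to become a probability. The stdlib sketch below assumes a logistic link and uses made-up numbers for a single sample; it also shows why individual SHAP values cannot be converted to probability contributions one at a time.

```python
import math


def sigmoid(z):
    """Logistic link: maps a log-odds margin to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical explanation of one sample in log-odds space (not real model output)
base_value = -1.2
shap_row = [0.8, -0.3, 1.5]

# Local accuracy: the explained raw output is base value + sum of SHAP values
margin = base_value + sum(shap_row)
prob = sigmoid(margin)
print(f"margin={margin:.3f}, probability={prob:.3f}")

# Converting each SHAP value to a probability change in isolation and summing
# does not reproduce the actual probability shift, because the sigmoid is nonlinear.
actual_shift = prob - sigmoid(base_value)
naive_shift = sum(sigmoid(base_value + v) - sigmoid(base_value) for v in shap_row)
print(f"actual shift={actual_shift:.3f}, naive sum={naive_shift:.3f}")
```

If probability-space attributions are needed, explaining the probability output directly (model_output="probability") is generally the cleaner route.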
```python
import numpy as np

# Beeswarm: feature importance + value distributions (most informative)
shap.plots.beeswarm(shap_values, max_display=15)

# Bar: clean mean |SHAP| importance
shap.plots.bar(shap_values)

# Waterfall: detailed breakdown of one prediction
shap.plots.waterfall(shap_values[0])

# Force: additive force visualization (in notebooks, call shap.initjs() first)
shap.plots.force(shap_values[0])

# Scatter: how a feature affects predictions
shap.plots.scatter(shap_values[:, "Age"])
# Colored by interaction feature
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education-Num"])

# Heatmap: multi-sample SHAP grid
shap.plots.heatmap(shap_values[:100])

# Decision plot: cumulative SHAP paths
shap.plots.decision(shap_values.base_values[0], shap_values.values[:10],
                    feature_names=X_test.columns.tolist())

# Cohort comparison (boolean NumPy mask over rows)
mask_a = (X_test["Age"] < 40).to_numpy()
shap.plots.bar({
    "Under 40": shap_values[mask_a],
    "40+": shap_values[~mask_a],
})
```
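The bar plot's global ranking is the mean absolute SHAP value per feature, and passing a dictionary of cohorts repeats that statistic per subgroup. A stdlib-only sketch of the arithmetic, using a hypothetical 4×3 SHAP matrix and hypothetical ages rather than real model output:

```python
def mean_abs_shap(rows):
    """Mean |SHAP| per feature (column) across per-sample SHAP rows."""
    n = len(rows)
    return [sum(abs(row[j]) for row in rows) / n for j in range(len(rows[0]))]


def ranked(names, importances):
    """(feature, importance) pairs sorted from most to least important."""
    return sorted(zip(names, importances), key=lambda pair: pair[1], reverse=True)


# Hypothetical SHAP values: 4 samples x 3 features (not real model output)
feature_names = ["Age", "Education-Num", "Hours per week"]
shap_rows = [
    [0.5, -0.2, 0.1],
    [-0.7, 0.4, 0.0],
    [0.1, -0.6, 0.3],
    [-0.3, 0.2, -0.1],
]
ages = [25, 52, 33, 61]  # hypothetical feature values used to define cohorts

global_importance = mean_abs_shap(shap_rows)
print("All:", ranked(feature_names, global_importance))

# Cohort comparison: the same statistic computed separately per subgroup
cohorts = {
    "Under 40": [row for row, age in zip(shap_rows, ages) if age < 40],
    "40+": [row for row, age in zip(shap_rows, ages) if age >= 40],
}
for label, rows in cohorts.items():
    print(f"{label}:", ranked(feature_names, mean_abs_shap(rows)))
```

Here Age ranks first overall, but Education-Num ranks first within the Under 40 cohort, which is the kind of difference a cohort bar plot is meant to expose.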
| Parameter | Explainer/Function | Default | Effect |
|---|---|---|---|
| feature_perturbation | TreeExplainer | "tree_path_dependent" (no background data) | "interventional" for causal interpretation (requires background data) |
| model_output | TreeExplainer | "raw" | "probability" to explain probabilities instead of log-odds (requires feature_perturbation="interventional" with background data) |
| data (background) | KernelExplainer, DeepExplainer | Required | 100-1000 representative samples; use shap.kmeans(X, 50) for efficiency |
| nsamples | KernelExplainer | "auto" | Higher = more accurate but slower; minimum 2×features |
| max_display | All plot functions | 10 | Number of features shown in plots |
| alpha | scatter/beeswarm | 1.0 | Point transparency for dense datasets |
| show | All plot functions | True | Set False to get matplotlib figure for saving |
| clustering | beeswarm | None | shap.utils.hclust(...) to cluster correlated features |
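KernelExplainer is slow because exact Shapley values require evaluating the model on all 2^M coalitions of M features. Kernel SHAP instead samples coalitions and fits a weighted linear regression using the Shapley kernel weight (M - 1) / (C(M, s) · s · (M - s)) for a coalition of size s. The sketch below only computes those weights and coalition counts with the standard library; it is an illustration, not the library's sampler.

```python
from math import comb


def shapley_kernel_weight(M, s):
    """Shapley kernel weight for one coalition of size s out of M features.

    The empty and full coalitions (s == 0 or s == M) have infinite weight;
    Kernel SHAP treats them as constraints rather than regression samples.
    """
    if s == 0 or s == M:
        return float("inf")
    return (M - 1) / (comb(M, s) * s * (M - s))


def total_weight_by_size(M):
    """Total kernel weight carried by all coalitions of each size 1..M-1."""
    return {s: comb(M, s) * shapley_kernel_weight(M, s) for s in range(1, M)}


# Exact enumeration grows as 2^M, which is why KernelExplainer samples coalitions
for M in (5, 12, 30):
    print(f"M={M}: 2^M = {2 ** M:,} coalitions")

# Weight mass concentrates on the smallest and largest coalitions
weights6 = total_weight_by_size(6)
print({s: round(w, 4) for s, w in weights6.items()})
```

Because the total weight per size is symmetric and largest for s = 1 and s = M - 1, those coalitions are the most informative per model evaluation, which is one reason a larger nsamples budget matters most when M is large.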
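SHAP values are the unique additive feature attribution satisfying three theoretical guarantees:
- Local accuracy (additivity): prediction = base_value + sum(SHAP values), an exact decomposition
- Missingness: a feature that is missing from the simplified input receives a SHAP value of zero
- Consistency: if a model changes so that a feature's marginal contribution increases or stays the same, that feature's SHAP value does not decrease

Interpretation: Positive SHAP → pushes prediction higher; Negative → lower; Magnitude → strength of impact.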
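To make these properties concrete, here is a brute-force sketch of exact Shapley values for a toy model. It assumes a simplified value function in which features outside a coalition are set to a fixed baseline (SHAP's explainers estimate expectations over background data instead), and it enumerates every coalition, so it is only feasible for a handful of features.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x by enumerating every coalition.

    Simplifying assumption: features outside a coalition are set to `baseline`.
    Returns (phi, base_value), where base_value = f(baseline).
    """
    M = len(x)

    def v(coalition):
        z = [x[i] if i in coalition else baseline[i] for i in range(M)]
        return f(z)

    phi = []
    for i in range(M):
        others = [j for j in range(M) if j != i]
        total = 0.0
        for size in range(M):
            # Shapley weight |S|! (M - |S| - 1)! / M!
            weight = factorial(size) * factorial(M - size - 1) / factorial(M)
            for S in combinations(others, size):
                S = set(S)
                total += weight * (v(S | {i}) - v(S))
        phi.append(total)
    return phi, v(set())


def model(z):
    """Toy model: x0 and x1 interact, x2 is linear, x3 is ignored."""
    return z[0] * z[1] + 2.0 * z[2]


x = [2.0, 3.0, 5.0, 7.0]
baseline = [0.0, 0.0, 0.0, 0.0]
phi, base_value = shapley_values(model, x, baseline)
print("phi =", [round(p, 6) for p in phi])
print("base + sum(phi) =", round(base_value + sum(phi), 6), "| f(x) =", model(x))
```

The interaction term x0·x1 = 6 is split equally (φ0 = φ1 = 3), the linear term gives φ2 = 10, the unused feature gets exactly 0, and base + sum(phi) reproduces f(x) = 16.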
Understand what your model outputs; SHAP explains the model's output space:
- Tree classifiers are explained in log-odds (raw margin) space by default
- Use model_output="probability" for probability explanations

| Method | Local | Global | Consistent | Model-agnostic |
|---|---|---|---|---|
| SHAP | Yes | Yes | Yes | Yes |
| Permutation importance | No | Yes | No | Yes |
| Gini/split importance | No | Yes | No | Trees only |
| LIME | Yes | No | No | Yes |
| Integrated Gradients | Yes | No | Partial | NN only |
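```python
# Interaction values are provided by TreeExplainer (tree_path_dependent, no background data)
tree_explainer = shap.TreeExplainer(model)
shap_interaction = tree_explainer.shap_interaction_values(X_test)
# Shape: (n_samples, n_features, n_features)
# Diagonal = main effects; off-diagonal = pairwise interactions
```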
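SHAP interaction values split each feature's attribution into a main effect (diagonal) and pairwise terms (off-diagonal), with each row summing to that feature's SHAP value. The brute-force sketch below applies the Shapley interaction index to a tiny toy model, assuming absent features are set to a fixed baseline (TreeExplainer computes these values efficiently from the tree structure instead). It shows a pure x0·x1 interaction landing entirely off the diagonal.

```python
from itertools import combinations
from math import factorial


def interaction_values(f, x, baseline):
    """Brute-force SHAP interaction values (Shapley interaction index) of f at x.

    Simplifying assumption: features outside a coalition are set to `baseline`.
    Off-diagonal (i, j) and (j, i) each get half of the pairwise interaction;
    the diagonal holds the main effect = Shapley value minus the off-diagonal row.
    """
    M = len(x)

    def v(S):
        return f([x[k] if k in S else baseline[k] for k in range(M)])

    def shapley(i):
        others = [k for k in range(M) if k != i]
        return sum(
            factorial(s) * factorial(M - s - 1) / factorial(M) * (v(set(S) | {i}) - v(set(S)))
            for s in range(M)
            for S in combinations(others, s)
        )

    Phi = [[0.0] * M for _ in range(M)]
    for i in range(M):
        for j in range(M):
            if i == j:
                continue
            others = [k for k in range(M) if k not in (i, j)]
            total = 0.0
            for s in range(M - 1):
                weight = factorial(s) * factorial(M - s - 2) / (2 * factorial(M - 1))
                for S in combinations(others, s):
                    S = set(S)
                    total += weight * (v(S | {i, j}) - v(S | {i}) - v(S | {j}) + v(S))
            Phi[i][j] = total
    for i in range(M):
        Phi[i][i] = shapley(i) - sum(Phi[i][j] for j in range(M) if j != i)
    return Phi


def model(z):
    """Toy model: a pure x0 * x1 interaction plus a linear x2 term."""
    return z[0] * z[1] + z[2]


x, baseline = [2.0, 3.0, 5.0], [0.0, 0.0, 0.0]
Phi = interaction_values(model, x, baseline)
for row in Phi:
    print([round(value, 6) for value in row])
```

The interaction x0·x1 = 6 appears as 3 in both off-diagonal cells, x2's effect of 5 stays on the diagonal, and all entries sum to f(x) - f(baseline) = 11.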
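Background data establishes the baseline (the expected model output over the background set). Its selection shifts the base value and the SHAP magnitudes; relative importance is often similar across reasonable background sets but is not guaranteed to be identical.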
- Use shap.kmeans(X_train, 50) for efficient summarization
- tree_path_dependent: no background data needed (uses tree structure)

```python
import numpy as np

# Find misclassified samples
predictions = model.predict(X_test)
y_true = np.asarray(y_test)  # positional indexing works for arrays and Series alike
errors = predictions != y_true
error_indices = np.where(errors)[0]

# Explain errors
for idx in error_indices[:3]:
    print(f"Sample {idx}: predicted={predictions[idx]}, actual={y_true[idx]}")
    shap.plots.waterfall(shap_values[idx])

# Check for data leakage: unexpected high-importance features (largest first)
mean_abs_shap = np.abs(shap_values.values).mean(0)
top_features = X_test.columns[mean_abs_shap.argsort()[::-1][:5]]
print(f"Top features (check for leakage): {list(top_features)}")
```
```python
import numpy as np

# Compare SHAP distributions across groups (Sex is encoded as 0/1 here)
sex = X_test["Sex"].to_numpy()
group_a = shap_values[sex == 0]
group_b = shap_values[sex == 1]
shap.plots.bar({"Female": group_a, "Male": group_b})

# Check protected attribute importance as a share of the summed per-feature mean |SHAP|
per_feature_importance = np.abs(shap_values.values).mean(0)
sex_importance = np.abs(shap_values[:, "Sex"].values).mean()
total_importance = per_feature_importance.sum()
print(f"Sex contribution: {sex_importance/total_importance:.1%} of total importance")
```
```python
import joblib
import numpy as np

# Save explainer for reuse
joblib.dump(explainer, 'explainer.pkl')
explainer = joblib.load('explainer.pkl')

# Batch computation for API responses
def explain_batch(X_batch, explainer, top_n=5):
    """Return each row's raw model output and its top_n features by |SHAP|, largest first."""
    sv = explainer(X_batch)
    results = []
    for i in range(len(X_batch)):
        top_idx = np.abs(sv.values[i]).argsort()[::-1][:top_n]
        results.append({
            # Raw model output (log-odds for tree classifiers by default)
            'prediction': float(sv.base_values[i] + sv.values[i].sum()),
            'top_features': {X_batch.columns[j]: float(sv.values[i][j]) for j in top_idx},
        })
    return results
```
```python
import mlflow
import matplotlib.pyplot as plt
import numpy as np
import shap
import xgboost as xgb

with mlflow.start_run():
    model = xgb.XGBClassifier().fit(X_train, y_train)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_test)

    # Log the beeswarm plot as an artifact
    shap.plots.beeswarm(shap_values, show=False)
    mlflow.log_figure(plt.gcf(), "shap_beeswarm.png")
    plt.close()

    # Log mean |SHAP| per feature as metrics
    for feat, imp in zip(X_test.columns, np.abs(shap_values.values).mean(0)):
        mlflow.log_metric(f"shap_{feat}", float(imp))
```
| Output | Type | Description |
|---|---|---|
| shap_values | shap.Explanation | Object with .values (n_samples, n_features), .base_values (baseline), .data (input features) |
| Waterfall plot | matplotlib figure | Single-instance explanation showing feature contributions from base value to prediction |
| Beeswarm plot | matplotlib figure | Global summary: feature importance × direction for all samples |
| Bar plot | matplotlib figure | Mean absolute SHAP values per feature (global importance ranking) |
| Force plot | HTML/matplotlib | Interactive or static visualization of a single prediction |
| mean_abs_shap | np.ndarray | Per-feature mean absolute SHAP value for ranking and reporting (wrap in a pd.Series indexed by feature name if needed) |
| Problem | Cause | Solution |
|---|---|---|
| Very slow computation | Using KernelExplainer for tree model | Use TreeExplainer for tree-based models |
| Slow on large dataset | Computing all samples at once | Sample subset: explainer(X_test[:1000]) or batch |
| SHAP values don't sum to prediction | Wrong model output type | Check model_output parameter; verify additivity |
| Log-odds vs probability confusion | Tree classifier defaults to log-odds | Use TreeExplainer(model, X_train[:100], feature_perturbation="interventional", model_output="probability"); probability output requires background data |
| Plots too cluttered | Too many features shown | Set max_display=10 or use feature clustering |
| DeepExplainer error | Background data too small | Use 100-1000 background samples |
| Memory error | Large dataset + many features | Reduce background data with shap.kmeans(X, 50) |
| Force plot not rendering | Missing JS in notebook | Run shap.initjs() at notebook start |
| Inconsistent importance across runs | KernelExplainer sampling variance | Increase nsamples or use deterministic explainer |
| Negative importance for relevant feature | Feature interactions or correlations | Use feature_perturbation="interventional" or scatter plots |
references/theory.md — Mathematical foundations: Shapley value formula, key properties (additivity, symmetry, dummy, monotonicity), computation algorithms (Tree SHAP, Kernel SHAP, Deep SHAP, Linear SHAP), conditional expectations (interventional vs observational), comparison with LIME/DeepLIFT/LRP/Integrated Gradients, interaction values, theoretical limitations

Not migrated from original: references/explainers.md (340 lines) — detailed constructor parameters, methods, and performance benchmarks for each explainer class. Explainer selection guide and common usage are covered inline in Workflow Step 1 and Key Parameters.
Not migrated from original: references/plots.md (508 lines) — comprehensive parameter reference for all 9 plot types with advanced customization (violin, decision, feature clustering). Main plot types are covered inline in Workflow Steps 3-6.
Not migrated from original: references/workflows.md (606 lines) — detailed step-by-step workflows for feature engineering, model comparison, deep learning explanation, production deployment, and time series. Core patterns are covered in Common Recipes; consult original for extended workflows.
- Prefer specialized explainers: TreeExplainer > LinearExplainer > DeepExplainer > KernelExplainer. Only use model-agnostic explainers when no specialized one exists
- Summarize background data with shap.kmeans() for efficiency
- For correlated features, use feature_perturbation="interventional" for causal interpretation or feature clustering for grouped importance