[Feature] SHAP-based explainer for torch models#3049
Conversation
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
This ensure that the last possible index is always explained when `add_encoders` is used. Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
|
FYI, SHAP has just release a new minor version 0.52.0, which include major updates to binary build and distribution. I have re-run the explainer unit tests locally (overriding the Darts newer package cap) and all of them have passed. cc: @CloseChoice |
|
View / edit / reply to this conversation on ReviewNB daidahao commented on 2026-06-10T19:03:54Z I would suggest using the Darts-style here, just like other notebooks:
set_option("plotting.use_darts_style", True)
|
|
View / edit / reply to this conversation on ReviewNB daidahao commented on 2026-06-10T19:03:55Z Are we certain that |
dennisbader
left a comment
There was a problem hiding this comment.
Looks great, thanks a lot @daidahao 💯
I pushed several changes and updated the PR description accordingly.
|
Thank you for all the refactoring and improvements. Glad that we pushed it past the finish line! 🏁 |
Checklist before merging this PR:
Fixes #871. Fixes #2788. Fixes #2296. Fixes #2571. Fixes #1262. Fixes #2566. Fixes #1332.
Summary
This PR:
Improved
ShapExplainer:ShapExplainercan now also explain anyTorchForecastingModelincluding regular torch models (TiDEModel, ...) as well as foundation models (Chronos2, ...). It supports global and local explanations and can output SHAP values for further analysis.explain_single()to explain a single model forecast in detail, in addition to the existing batched methodexplain(). This is useful for local explanations of individual predictions with reduced computational cost.summary_plot()can now also be computed on any optional foreground series using parametersforeground_series,foreground_past_covariates,foreground_future_covariates.ShapExplainercan now also explain the forecasted likelihood parameter of probabilistic forecasts.ShapExplainer.force_plot_from_ts()toforce_plot()to simplify.Fixed
ShapExplainerincluding mismatched SHAP method enum values, feature naming conventions, inconsistent instance count inexplain().all().For developers of the library:
ShapSingleExplainabilityResultclass as the return type ofexplain_single()method inShapExplainerand to store the SHAP results of a single instance explanation. This is in contrast to the existingShapExplainabilityResultwhich stores results for batched explanations.Motivation
An increasing number of models in Darts are torch-based (recently #3002, #2980, #2944) and users need a consistent way to explain their forecasts.
For scikit-learn models, the existing
ShapExplainerprovides SHAP-based explanations with method selection based on model type.For torch models, we need an explainer that can handle the different model architectures, while conforming to existing explainability API patterns.
permutationprovides general applicability and faster explanations thankernelorsampling. Users can choose other SHAP methods if desired.PLForecastingModulein a genericnn.Modulethat can be explained by these methods, in addition to the current numpy-based function wrapper.Design
ShapExplainernow supports TorchForecastingModel under same unified APIMethods
explain()for horizon/component-level explanations over forecastable timestamps.explain_single()for one forecast instance (equivalent prediction context topredict(n=output_chunk_length)).summary_plot()shows distributions of feature contributions.force_plot()shows feature contributions for a specific horizon/component.Use Cases
Summary Plot
Feature-importance distribution analysis per horizon/component for torch models.
Force Plot
Local additive contribution view for a selected horizon and target component.
Explaining Multiple Instances
Batch explanations from foreground data with optional sampling controls for performance.
Explaining Single Instance
Per-instance explanation API (
explain_single()) for local interpretability.Explaining Probabilistic Forecasts
Probabilistic torch models are supported by explaining each likelihood parameter component, treating them as separate targets. This is useful for understanding how features contribute to uncertainty estimates.
Bug Fixes
generate_fit_predict_encodings), improving consistency with forecasting behavior.