Skip to content

[Feature] SHAP-based explainer for torch models#3049

Merged
dennisbader merged 159 commits into
unit8co:masterfrom
daidahao:feature/shap-torch
Jun 19, 2026
Merged

[Feature] SHAP-based explainer for torch models#3049
dennisbader merged 159 commits into
unit8co:masterfrom
daidahao:feature/shap-torch

Conversation

@daidahao

@daidahao daidahao commented Mar 27, 2026

Copy link
Copy Markdown
Contributor

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #871. Fixes #2788. Fixes #2296. Fixes #2571. Fixes #1262. Fixes #2566. Fixes #1332.

Summary

This PR:

Improved

  • Improvements to ShapExplainer :
    • 🚀🚀 ShapExplainer can now also explain any TorchForecastingModel including regular torch models (TiDEModel, ...) as well as foundation models (Chronos2, ...). It supports global and local explanations and can output SHAP values for further analysis.
    • Added method explain_single() to explain a single model forecast in detail, in addition to the existing batched method explain(). This is useful for local explanations of individual predictions with reduced computational cost.
    • Method summary_plot() can now also be computed on any optional foreground series using parameters foreground_series, foreground_past_covariates, foreground_future_covariates.
    • ShapExplainer can now also explain the forecasted likelihood parameter of probabilistic forecasts.
    • Added a new notebook for Explainability of Forecasting Models including detailed usage examples of ShapExplainer.
    • 🔴 Renamed method force_plot_from_ts() to force_plot() to simplify.

Fixed

  • Fixed several bugs in ShapExplainer including mismatched SHAP method enum values, feature naming conventions, inconsistent instance count in explain().
  • Fixed a bug in explainability utils where stationarity tests were not properly conducted due to usage of all().

For developers of the library:

  • Added ShapSingleExplainabilityResult class as the return type of explain_single() method in ShapExplainer and to store the SHAP results of a single instance explanation. This is in contrast to the existing ShapExplainabilityResult which stores results for batched explanations.

Motivation

An increasing number of models in Darts are torch-based (recently #3002, #2980, #2944) and users need a consistent way to explain their forecasts.

For scikit-learn models, the existing ShapExplainer provides SHAP-based explanations with method selection based on model type.
For torch models, we need an explainer that can handle the different model architectures, while conforming to existing explainability API patterns.

  • Why SHAP? SHAP gives additive, model-agnostic feature attributions that are consistent across explainers.
  • Why Permutation Explainer? For torch models, defaulting to permutation provides general applicability and faster explanations than kernel or sampling. Users can choose other SHAP methods if desired.
  • Why not DeepExplainer or GradientExplainer? Both are designed for deep learning models and are faster than KernelSHAP. However, they have limitations (from my experiments):
    • DeepExplainer is incompatible with many torch models due to reused layers.
    • Both do not output base values, which are needed for consistent SHAP result objects and visualizations (e.g., waterfall, force plots).
  • Why not captum? Meta's PyTorch native library supports various attribution methods (Integrated Gradients, DeepLIFT, etc.) and is efficient for torch models. However, as of now, it does not support multi-target explanations. Forecasts in Darts are multi-target in nature (multiple horizons x components x likelihood parameters), so using captum would incur for-loop overhead.
  • Future: We can consider supporting DeepExplainer/GradientExplainer as additional SHAP methods in the future if they yield better efficiency for some torch models. This would require wrapping PLForecastingModule in a generic nn.Module that can be explained by these methods, in addition to the current numpy-based function wrapper.

Design

  • ShapExplainer now supports TorchForecastingModel under same unified API
    • This is achieved via new model-specific SHAP adapters (SKLearnShapAdapter and TorchShapAdapter)
  • For torch: It builds SHAP inputs from torch inference datasets to stay consistent with Darts prediction semantics.
  • It handles deterministic and probabilistic models (for probabilistic models, explanations are produced for likelihood parameter components).

Methods

  • explain() for horizon/component-level explanations over forecastable timestamps.
  • explain_single() for one forecast instance (equivalent prediction context to predict(n=output_chunk_length)).
  • summary_plot() shows distributions of feature contributions.
  • force_plot() shows feature contributions for a specific horizon/component.

Use Cases

Summary Plot

Feature-importance distribution analysis per horizon/component for torch models.

import shap
shap.initjs()

from darts.datasets import WineDataset
from darts.explainability import ShapExplainer
from darts.models import TiDEModel

series = WineDataset().load().astype("float32")
model = TiDEModel(12, 12).fit(series[:36])
explainer = ShapExplainer(model)
explainer.summary_plot(horizons=[1])
summary

Force Plot

Local additive contribution view for a selected horizon and target component.

explainer.force_plot(horizon=1)
Screenshot 2026-03-27 at 10 41 52

Explaining Multiple Instances

Batch explanations from foreground data with optional sampling controls for performance.

result = explainer.explain(series[:36])
# return a `TimeSeries` of SHAP values where time index
# corresponds to the instance timestamps
result.get_explanation(horizon=1)
# return the raw SHAP explanation object for custom visualizations
shap_object = result.get_shap_explanation_object(horizon=1)
# plot waterfall for the first forecast instance
shap.plots.waterfall(shap_object[0])
waterfall

Explaining Single Instance

Per-instance explanation API (explain_single()) for local interpretability.

single_result = explainer.explain_single(series[:36])
# return a `TimeSeries` of SHAP values where time index corresponds to the **prediction** timestamp
single_result.get_explanation()
# return the raw SHAP explanation object for custom visualizations
single_shap_object = single_result.get_shap_explanation_object()
# plot heatmap for the single instance explanation along the horizon
shap.plots.heatmap(single_shap_object, instance_order=np.arange(12))
heatmap

Explaining Probabilistic Forecasts

Probabilistic torch models are supported by explaining each likelihood parameter component, treating them as separate targets. This is useful for understanding how features contribute to uncertainty estimates.

from darts.utils.likelihood_models import QuantileRegression
# fit a probabilistic model with quantile regression likelihood
prob_model = TiDEModel(12, 12, likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9]))
prob_model.fit(series[:36])
# create an explainer for the probabilistic model
prob_explainer = ShapExplainer(prob_model)
# explain the probabilistic forecasts
# this will produce explanations for each likelihood parameter component
# (e.g., Y_q0.100, Y_q0.500, Y_q0.900)
prob_result = prob_explainer.explain(series[:36])
# get SHAP values as a `TimeSeries` for the 0.1 quantile at horizon 1
prob_result.get_explanation(horizon=1, component="Y_q0.100")
            Y_target_lag-12  Y_target_lag-11  Y_target_lag-10  Y_target_lag-9  Y_target_lag-8  ...  Y_target_lag-5  Y_target_lag-4  Y_target_lag-3  Y_target_lag-2  Y_target_lag-1
1981-01-01     -3697.863974      -252.308866       -41.762030        0.572893    -1353.563396  ...      -91.867447     -128.090894      -39.738832      208.761212      -53.789530
1981-02-01     -2648.507187       -80.287658       -53.070808       45.709788     -821.725195  ...       30.775013      -14.196725     -861.172957      392.385305      139.613268
1981-03-01      -477.149089      -195.594982       -51.709828       14.808723      553.345521  ...      -35.324536      273.863595    -1509.285057     -393.562763       31.391378
1981-04-01     -1998.012530      -171.417969       -66.743827      -61.221131      904.461867  ...     -232.607326      590.834413     1727.039878     -267.061452        7.137201
1981-05-01     -1777.624966      -124.021548        -9.517231       -8.676134     -132.987919  ...     -384.271071     -501.348593      927.618172     -129.376857       22.660183
...                     ...              ...              ...             ...             ...  ...             ...             ...             ...             ...             ...
1982-09-01       169.768241        50.163736       159.011448       63.019742    -1706.948496  ...      -76.503501      -20.519832       74.718658      225.857574      -72.105523
1982-10-01      1054.076957       454.064094       260.788809       60.042956    -1343.217095  ...        6.109713      -43.923037     -872.911818      249.960171      -23.256607
1982-11-01      4563.048261       555.714791       -94.331751       99.717854     -361.191972  ...       15.797448      297.566318    -1125.641917       62.627629      -13.438255
1982-12-01      6351.228711      -202.557555       -23.213595      -24.084026      838.380352  ...     -266.064357      417.857843     -372.773744        8.018883      -16.682221
1983-01-01     -2495.383209      -168.530121       -40.911461      -65.382957      382.564556  ...     -297.340957      112.898120     -220.364707      327.032420      -58.788773

shape: (25, 12, 1), freq: MS, size: 2.34 KB

Bug Fixes

  • Improved input processing for explainers by using prediction-aware encoder generation for foreground data (generate_fit_predict_encodings), improving consistency with forecasting behavior.
  • Better validation and clearer errors in explainability result querying (component/horizon checks).
  • Improved stationarity warnings to indicate the specific component and series index.

daidahao added 30 commits March 27, 2026 10:26
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
This ensure that the last possible index is always explained when
`add_encoders` is used.

Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
@daidahao

daidahao commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

FYI, SHAP has just release a new minor version 0.52.0, which include major updates to binary build and distribution.

I have re-run the explainer unit tests locally (overriding the Darts newer package cap) and all of them have passed.

cc: @CloseChoice

@review-notebook-app

review-notebook-app Bot commented Jun 10, 2026

Copy link
Copy Markdown

View / edit / reply to this conversation on ReviewNB

daidahao commented on 2026-06-10T19:03:54Z
----------------------------------------------------------------

I would suggest using the Darts-style here, just like other notebooks:

set_option("plotting.use_darts_style", True)

@review-notebook-app

review-notebook-app Bot commented Jun 10, 2026

Copy link
Copy Markdown

View / edit / reply to this conversation on ReviewNB

daidahao commented on 2026-06-10T19:03:55Z
----------------------------------------------------------------

Are we certain that ShapExplainer can explain likelihood parameters of SKLearn models? I thought it could only explain model whose likelihood="quantile".


@dennisbader dennisbader left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks a lot @daidahao 💯

I pushed several changes and updated the PR description accordingly.

@dennisbader dennisbader merged commit c96169f into unit8co:master Jun 19, 2026
9 checks passed
@daidahao

Copy link
Copy Markdown
Contributor Author

@dennisbader

Thank you for all the refactoring and improvements. Glad that we pushed it past the finish line! 🏁

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment