Skip to content

Commit d7fc879

Browse files
authored
Merge branch 'master' into feat/ts_representation
2 parents 1661f17 + a0cc926 commit d7fc879

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,10 +799,13 @@ def _randomize(shape) -> Optional[torch.Tensor]:
799799
if self.uses_static_covariates:
800800
input_names.append("x_static")
801801

802+
# TODO: `dynamo=True` should be the way to go since PyTorch 2.9; we have to wait until RNN module onnx exports
803+
# are fixed
802804
self.model.to_onnx(
803805
file_path=path,
804806
input_sample=(input_sample,),
805807
input_names=input_names,
808+
dynamo=False,
806809
**kwargs,
807810
)
808811

darts/tests/explainability/test_shap_explainer.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@
2121
LightGBMModel,
2222
LinearRegressionModel,
2323
SKLearnModel,
24-
XGBModel,
2524
)
2625
from darts.tests.conftest import (
2726
GBM_AVAILABLE,
2827
LGBM_AVAILABLE,
29-
XGB_AVAILABLE,
3028
)
3129
from darts.utils.timeseries_generation import linear_timeseries
3230

@@ -162,16 +160,17 @@ class TestShapExplainer:
162160
"output_chunk_length": 4,
163161
},
164162
},
165-
{
166-
"model_cls": XGBModel,
167-
"config": {
168-
"lags": 4,
169-
"lags_past_covariates": [-1, -2, -3],
170-
"lags_future_covariates": [0],
171-
"output_chunk_length": 4,
172-
"add_encoders": add_encoders,
173-
},
174-
},
163+
# # TODO: add back test once shap fixes issue https://github.com/shap/shap/issues/4184
164+
# {
165+
# "model_cls": XGBModel,
166+
# "config": {
167+
# "lags": 4,
168+
# "lags_past_covariates": [-1, -2, -3],
169+
# "lags_future_covariates": [0],
170+
# "output_chunk_length": 4,
171+
# "add_encoders": add_encoders,
172+
# },
173+
# },
175174
],
176175
)
177176
def test_gbm_creation(self, model):
@@ -724,7 +723,8 @@ def test_shap_explanation_object_validity(self):
724723
@pytest.mark.parametrize(
725724
"config",
726725
[(LinearRegressionModel, {})]
727-
+ ([(XGBModel, {})] if XGB_AVAILABLE else [])
726+
# # TODO: add back test once shap fixes issue https://github.com/shap/shap/issues/4184
727+
# + ([(XGBModel, {})] if XGB_AVAILABLE else [])
728728
+ (
729729
[(LightGBMModel, {"likelihood": "quantile", "quantiles": [0.5]})]
730730
if LGBM_AVAILABLE

0 commit comments

Comments
 (0)