Skip to content

Commit ba922a6

Browse files
authored
fix tests for m1 (#2682)
1 parent 2000b35 commit ba922a6

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

darts/tests/models/forecasting/test_ensemble_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ def test_save_load_ensemble_models(self, tmpdir_fn, model_cls):
839839
# test clean save
840840
path = os.path.join(full_model_path_str, f"clean_{model_cls.__name__}.pkl")
841841
model.save(path, clean=True)
842-
clean_model = model_cls.load(path)
842+
clean_model = model_cls.load(path, pl_trainer_kwargs={"accelerator": "cpu"})
843843
for i, m in enumerate(clean_model.forecasting_models):
844844
if not issubclass(type(m), LocalForecastingModel):
845845
assert m.training_series is None

0 commit comments

Comments
 (0)