|
15 | 15 | from darts.dataprocessing.transformers import BoxCox, Scaler |
16 | 16 | from darts.metrics import mape |
17 | 17 | from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs, tfm_kwargs_dev |
| 18 | +from darts.utils.data.torch_datasets.inference_dataset import ( |
| 19 | + SequentialTorchInferenceDataset, |
| 20 | +) |
| 21 | +from darts.utils.data.torch_datasets.training_dataset import ( |
| 22 | + SequentialTorchTrainingDataset, |
| 23 | +) |
18 | 24 |
|
19 | 25 | if not TORCH_AVAILABLE: |
20 | 26 | pytest.skip( |
@@ -232,7 +238,8 @@ def on_train_epoch_end(self, trainer, pl_module): |
232 | 238 | no_train_model.predict(n=4) |
233 | 239 | assert str(err.value) == ( |
234 | 240 | "Input `series` must be provided. This is the result either from fitting on multiple series, " |
235 | | - "from not having fit the model yet, or from loading a model saved with `clean=True`." |
| 241 | + "from fitting with `fit_from_dataset()`, from not having fit the model yet, or from loading a " |
| 242 | + "model saved with `clean=True`." |
236 | 243 | ) |
237 | 244 |
|
238 | 245 | model_manual_save.fit(self.series, epochs=1) |
@@ -284,7 +291,8 @@ def on_train_epoch_end(self, trainer, pl_module): |
284 | 291 | model_manual_save.predict(n=4) |
285 | 292 | assert str(err.value) == ( |
286 | 293 | "Input `series` must be provided. This is the result either from fitting on multiple series, " |
287 | | - "from not having fit the model yet, or from loading a model saved with `clean=True`." |
| 294 | + "from fitting with `fit_from_dataset()`, from not having fit the model yet, or from loading a " |
| 295 | + "model saved with `clean=True`." |
288 | 296 | ) |
289 | 297 | # Predicting while giving the training series in args should yield same prediction |
290 | 298 | assert model_manual_save.predict( |
@@ -2374,6 +2382,173 @@ def test_fit_with_stride(self, stride): |
2374 | 2382 | assert len(train_set) == len(val_set) == math.ceil(3 / stride) |
2375 | 2383 | assert train_set.stride == val_set.stride == stride |
2376 | 2384 |
|
| 2385 | + def test_predict_after_fit_from_dataset(self): |
| 2386 | + """Test that the model can predict after being trained with `fit_from_dataset` using all covariates.""" |
| 2387 | + icl, ocl = kwargs["input_chunk_length"], kwargs["output_chunk_length"] |
| 2388 | + n = 1 |
| 2389 | + series = [ |
| 2390 | + self.series[: icl + ocl].with_static_covariates(pd.DataFrame({"sc": [0.0]})) |
| 2391 | + ] |
| 2392 | + pc = [self.series[: icl + ocl]] |
| 2393 | + fc = [self.series[: icl + ocl + n]] |
| 2394 | + |
| 2395 | + model = TiDEModel(**kwargs) |
| 2396 | + |
| 2397 | + train_dataset = SequentialTorchTrainingDataset( |
| 2398 | + series=series, |
| 2399 | + past_covariates=pc, |
| 2400 | + future_covariates=fc, |
| 2401 | + input_chunk_length=icl, |
| 2402 | + output_chunk_length=ocl, |
| 2403 | + use_static_covariates=True, |
| 2404 | + ) |
| 2405 | + |
| 2406 | + # check training works and covariates are used |
| 2407 | + model.fit_from_dataset(train_dataset=train_dataset) |
| 2408 | + assert model.uses_past_covariates |
| 2409 | + assert model.uses_future_covariates |
| 2410 | + assert model.uses_static_covariates |
| 2411 | + assert model._expect_past_covariates |
| 2412 | + assert model._expect_future_covariates |
| 2413 | + assert model._expect_static_covariates |
| 2414 | + |
| 2415 | + with pytest.raises(ValueError) as exc: |
| 2416 | + _ = model.predict(n=n) |
| 2417 | + assert str(exc.value).startswith("Input `series` must be provided.") |
| 2418 | + |
| 2419 | + hfc_kwargs = {"forecast_horizon": n, "retrain": False, "overlap_end": True} |
| 2420 | + self.helper_predict_raise_on_missing_input( |
| 2421 | + model, "predict", series, pc, fc, n=n |
| 2422 | + ) |
| 2423 | + self.helper_predict_raise_on_missing_input( |
| 2424 | + model, "historical_forecasts", series, pc, fc, **hfc_kwargs |
| 2425 | + ) |
| 2426 | + self.helper_predict_from_ds_raise_on_missing_input( |
| 2427 | + model, |
| 2428 | + series, |
| 2429 | + pc, |
| 2430 | + fc, |
| 2431 | + n=n, |
| 2432 | + input_chunk_length=icl, |
| 2433 | + output_chunk_length=ocl, |
| 2434 | + ) |
| 2435 | + |
| 2436 | + # check predict methods |
| 2437 | + inference_dataset = SequentialTorchInferenceDataset( |
| 2438 | + series=series, |
| 2439 | + past_covariates=pc, |
| 2440 | + future_covariates=fc, |
| 2441 | + n=n, |
| 2442 | + input_chunk_length=icl, |
| 2443 | + output_chunk_length=ocl, |
| 2444 | + use_static_covariates=True, |
| 2445 | + ) |
| 2446 | + pred1 = model.predict( |
| 2447 | + n=n, series=series, past_covariates=pc, future_covariates=fc |
| 2448 | + )[0] |
| 2449 | + pred2 = model.predict_from_dataset(n=n, dataset=inference_dataset)[0] |
| 2450 | + pred3 = model.historical_forecasts( |
| 2451 | + forecast_horizon=n, |
| 2452 | + series=series, |
| 2453 | + past_covariates=pc, |
| 2454 | + future_covariates=fc, |
| 2455 | + retrain=False, |
| 2456 | + overlap_end=True, |
| 2457 | + )[0] |
| 2458 | + # extract only the last hist fc which should be the same as the regular predictions |
| 2459 | + pred3 = pred3[-1] |
| 2460 | + assert pred1 == pred2 |
| 2461 | + np.testing.assert_array_almost_equal(pred3.all_values(), pred1.all_values()) |
| 2462 | + assert pred3.time_index.equals(pred1.time_index) |
| 2463 | + assert pred3.static_covariates.equals(series[0].static_covariates) |
| 2464 | + |
| 2465 | + def helper_predict_raise_on_missing_input( |
| 2466 | + self, model, fn: str, series, pc, fc, **kwargs |
| 2467 | + ): |
| 2468 | + """Helper function to test that the model raises an error when calling `predict()` or `historical_forecasts()` |
| 2469 | + after `fit_from_dataset()` with missing inputs.""" |
| 2470 | + with pytest.raises(ValueError) as exc: |
| 2471 | + _ = getattr(model, fn)(series=series, **kwargs) |
| 2472 | + assert str(exc.value).startswith("The model was trained with past covariates.") |
| 2473 | + with pytest.raises(ValueError) as exc: |
| 2474 | + _ = getattr(model, fn)(series=series, past_covariates=pc, **kwargs) |
| 2475 | + assert str(exc.value).startswith( |
| 2476 | + "The model was trained with future covariates." |
| 2477 | + ) |
| 2478 | + with pytest.raises(ValueError) as exc: |
| 2479 | + _ = getattr(model, fn)(series=series, future_covariates=fc, **kwargs) |
| 2480 | + assert str(exc.value).startswith("The model was trained with past covariates.") |
| 2481 | + with pytest.raises(ValueError) as exc: |
| 2482 | + _ = getattr(model, fn)( |
| 2483 | + series=[series[0].with_static_covariates(None)], |
| 2484 | + past_covariates=pc, |
| 2485 | + future_covariates=fc, |
| 2486 | + **kwargs, |
| 2487 | + ) |
| 2488 | + assert str(exc.value).startswith( |
| 2489 | + "The model was trained with static covariates." |
| 2490 | + ) |
| 2491 | + |
| 2492 | + def helper_predict_from_ds_raise_on_missing_input( |
| 2493 | + self, |
| 2494 | + model, |
| 2495 | + series, |
| 2496 | + pc, |
| 2497 | + fc, |
| 2498 | + n, |
| 2499 | + **kwargs, |
| 2500 | + ): |
| 2501 | + """Helper function to test that the model raises an error when calling `predict_from_dataset()` after |
| 2502 | + `fit_from_dataset()` with missing inputs.""" |
| 2503 | + inf_dataset = SequentialTorchInferenceDataset( |
| 2504 | + n=n, |
| 2505 | + series=series, |
| 2506 | + **kwargs, |
| 2507 | + ) |
| 2508 | + with pytest.raises(ValueError) as exc: |
| 2509 | + _ = model.predict_from_dataset(n=n, dataset=inf_dataset) |
| 2510 | + assert str(exc.value).startswith( |
| 2511 | + "This model has been trained with `past_covariates`" |
| 2512 | + ) |
| 2513 | + |
| 2514 | + inf_dataset = SequentialTorchInferenceDataset( |
| 2515 | + n=n, |
| 2516 | + series=series, |
| 2517 | + past_covariates=pc, |
| 2518 | + **kwargs, |
| 2519 | + ) |
| 2520 | + with pytest.raises(ValueError) as exc: |
| 2521 | + _ = model.predict_from_dataset(n=n, dataset=inf_dataset) |
| 2522 | + assert str(exc.value).startswith( |
| 2523 | + "This model has been trained with `historic_future_covariates`" |
| 2524 | + ) |
| 2525 | + |
| 2526 | + inf_dataset = SequentialTorchInferenceDataset( |
| 2527 | + n=n, |
| 2528 | + series=series, |
| 2529 | + future_covariates=fc, |
| 2530 | + **kwargs, |
| 2531 | + ) |
| 2532 | + with pytest.raises(ValueError) as exc: |
| 2533 | + _ = model.predict_from_dataset(n=n, dataset=inf_dataset) |
| 2534 | + assert str(exc.value).startswith( |
| 2535 | + "This model has been trained with `past_covariates`" |
| 2536 | + ) |
| 2537 | + |
| 2538 | + inf_dataset = SequentialTorchInferenceDataset( |
| 2539 | + n=n, |
| 2540 | + series=series, |
| 2541 | + past_covariates=pc, |
| 2542 | + future_covariates=fc, |
| 2543 | + use_static_covariates=False, |
| 2544 | + **kwargs, |
| 2545 | + ) |
| 2546 | + with pytest.raises(ValueError) as exc: |
| 2547 | + _ = model.predict_from_dataset(n=n, dataset=inf_dataset) |
| 2548 | + assert str(exc.value).startswith( |
| 2549 | + "This model has been trained with `static_covariates`" |
| 2550 | + ) |
| 2551 | + |
2377 | 2552 | def helper_equality_encoders( |
2378 | 2553 | self, first_encoders: dict[str, Any], second_encoders: dict[str, Any] |
2379 | 2554 | ): |
|
0 commit comments