Skip to content

Commit c116405

Browse files
authored
fix bug in create lagged component names without target lags (#2576)
* fix bug in create lagged component names without target lags * update changelog * clean up diffs
1 parent 0b9efd0 commit c116405

File tree

3 files changed

+77
-18
lines changed

3 files changed

+77
-18
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1313

1414
**Fixed**
1515

16+
- Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader).
17+
1618
**Dependencies**
1719

1820
### For developers of the library:

darts/tests/utils/tabularization/test_create_lagged_training_data.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,7 +2252,9 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
22522252
None,
22532253
None,
22542254
False,
2255+
1,
22552256
["no_static_target_lag-2", "no_static_target_lag-1"],
2257+
["no_static_target_hrz0"],
22562258
),
22572259
# target with static covariate (but don't use them in feature names)
22582260
(
@@ -2263,12 +2265,19 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
22632265
None,
22642266
None,
22652267
False,
2268+
2,
22662269
[
22672270
"static_0_target_lag-4",
22682271
"static_1_target_lag-4",
22692272
"static_0_target_lag-1",
22702273
"static_1_target_lag-1",
22712274
],
2275+
[
2276+
"static_0_target_hrz0",
2277+
"static_1_target_hrz0",
2278+
"static_0_target_hrz1",
2279+
"static_1_target_hrz1",
2280+
],
22722281
),
22732282
# target with static covariate (acting on global target components)
22742283
(
@@ -2279,13 +2288,18 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
22792288
None,
22802289
None,
22812290
True,
2291+
1,
22822292
[
22832293
"static_0_target_lag-4",
22842294
"static_1_target_lag-4",
22852295
"static_0_target_lag-1",
22862296
"static_1_target_lag-1",
22872297
"dummy_statcov_target_global_components",
22882298
],
2299+
[
2300+
"static_0_target_hrz0",
2301+
"static_1_target_hrz0",
2302+
],
22892303
),
22902304
# target with static covariate (component specific)
22912305
(
@@ -2296,6 +2310,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
22962310
None,
22972311
None,
22982312
True,
2313+
1,
22992314
[
23002315
"static_0_target_lag-4",
23012316
"static_1_target_lag-4",
@@ -2304,6 +2319,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23042319
"dummy_statcov_target_static_0",
23052320
"dummy_statcov_target_static_1",
23062321
],
2322+
[
2323+
"static_0_target_hrz0",
2324+
"static_1_target_hrz0",
2325+
],
23072326
),
23082327
# target with static covariate (component specific & multivariate)
23092328
(
@@ -2314,6 +2333,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23142333
None,
23152334
None,
23162335
True,
2336+
1,
23172337
[
23182338
"static_0_target_lag-4",
23192339
"static_1_target_lag-4",
@@ -2324,6 +2344,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23242344
"dummy1_statcov_target_static_0",
23252345
"dummy1_statcov_target_static_1",
23262346
],
2347+
[
2348+
"static_0_target_hrz0",
2349+
"static_1_target_hrz0",
2350+
],
23272351
),
23282352
# target + past
23292353
(
@@ -2334,13 +2358,15 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23342358
[-1],
23352359
None,
23362360
False,
2361+
1,
23372362
[
23382363
"no_static_target_lag-4",
23392364
"no_static_target_lag-3",
23402365
"past_0_pastcov_lag-1",
23412366
"past_1_pastcov_lag-1",
23422367
"past_2_pastcov_lag-1",
23432368
],
2369+
["no_static_target_hrz0"],
23442370
),
23452371
# target + future
23462372
(
@@ -2351,6 +2377,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23512377
None,
23522378
[3],
23532379
False,
2380+
1,
23542381
[
23552382
"no_static_target_lag-2",
23562383
"no_static_target_lag-1",
@@ -2359,6 +2386,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23592386
"future_2_futcov_lag3",
23602387
"future_3_futcov_lag3",
23612388
],
2389+
["no_static_target_hrz0"],
23622390
),
23632391
# past + future
23642392
(
@@ -2369,6 +2397,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23692397
[-1],
23702398
[2],
23712399
False,
2400+
1,
23722401
[
23732402
"past_0_pastcov_lag-1",
23742403
"past_1_pastcov_lag-1",
@@ -2378,6 +2407,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23782407
"future_2_futcov_lag2",
23792408
"future_3_futcov_lag2",
23802409
],
2410+
["no_static_target_hrz0"],
23812411
),
23822412
# target with static (not used) + past + future
23832413
(
@@ -2388,6 +2418,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
23882418
[-1],
23892419
[2],
23902420
False,
2421+
1,
23912422
[
23922423
"static_0_target_lag-2",
23932424
"static_1_target_lag-2",
@@ -2401,6 +2432,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
24012432
"future_2_futcov_lag2",
24022433
"future_3_futcov_lag2",
24032434
],
2435+
[
2436+
"static_0_target_hrz0",
2437+
"static_1_target_hrz0",
2438+
],
24042439
),
24052440
# multiple series with same components names, including past/future covariates
24062441
(
@@ -2411,6 +2446,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
24112446
[-1],
24122447
[2],
24132448
False,
2449+
1,
24142450
[
24152451
"static_0_target_lag-3",
24162452
"static_1_target_lag-3",
@@ -2422,6 +2458,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
24222458
"future_2_futcov_lag2",
24232459
"future_3_futcov_lag2",
24242460
],
2461+
[
2462+
"static_0_target_hrz0",
2463+
"static_1_target_hrz0",
2464+
],
24252465
),
24262466
# multiple series with different components will use the first series as reference
24272467
(
@@ -2435,6 +2475,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
24352475
[-1],
24362476
[2],
24372477
False,
2478+
1,
24382479
[
24392480
"static_0_target_lag-2",
24402481
"static_1_target_lag-2",
@@ -2448,6 +2489,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
24482489
"future_2_futcov_lag2",
24492490
"future_3_futcov_lag2",
24502491
],
2492+
[
2493+
"static_0_target_hrz0",
2494+
"static_1_target_hrz0",
2495+
],
24512496
),
24522497
],
24532498
)
@@ -2466,10 +2511,12 @@ def test_create_lagged_component_names(self, config):
24662511
lags_pc,
24672512
lags_fc,
24682513
use_static_cov,
2514+
ocl,
24692515
expected_lagged_features,
2516+
expected_lagged_labels,
24702517
) = config
24712518
# lags as list
2472-
created_lagged_features, _ = create_lagged_component_names(
2519+
created_lagged_features, created_lagged_labels = create_lagged_component_names(
24732520
target_series=ts_tg,
24742521
past_covariates=ts_pc,
24752522
future_covariates=ts_fc,
@@ -2478,6 +2525,7 @@ def test_create_lagged_component_names(self, config):
24782525
lags_future_covariates=lags_fc,
24792526
concatenate=False,
24802527
use_static_covariates=use_static_cov,
2528+
output_chunk_length=ocl,
24812529
)
24822530

24832531
# converts lags to dictionary format
@@ -2490,18 +2538,23 @@ def test_create_lagged_component_names(self, config):
24902538
lags_fc,
24912539
)
24922540

2493-
created_lagged_features_dict_lags, _ = create_lagged_component_names(
2494-
target_series=ts_tg,
2495-
past_covariates=ts_pc,
2496-
future_covariates=ts_fc,
2497-
lags=lags_as_dict["target"],
2498-
lags_past_covariates=lags_as_dict["past"],
2499-
lags_future_covariates=lags_as_dict["future"],
2500-
concatenate=False,
2501-
use_static_covariates=use_static_cov,
2541+
created_lagged_features_dict_lags, created_lagged_labels_dict_lags = (
2542+
create_lagged_component_names(
2543+
target_series=ts_tg,
2544+
past_covariates=ts_pc,
2545+
future_covariates=ts_fc,
2546+
lags=lags_as_dict["target"],
2547+
lags_past_covariates=lags_as_dict["past"],
2548+
lags_future_covariates=lags_as_dict["future"],
2549+
concatenate=False,
2550+
use_static_covariates=use_static_cov,
2551+
output_chunk_length=ocl,
2552+
)
25022553
)
25032554
assert expected_lagged_features == created_lagged_features
25042555
assert expected_lagged_features == created_lagged_features_dict_lags
2556+
assert expected_lagged_labels == created_lagged_labels
2557+
assert expected_lagged_labels == created_lagged_labels_dict_lags
25052558

25062559
@pytest.mark.parametrize(
25072560
"config",

darts/utils/data/tabularization.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -859,10 +859,21 @@ def create_lagged_component_names(
859859
[lags, lags_past_covariates, lags_future_covariates],
860860
["target", "pastcov", "futcov"],
861861
):
862-
if variate is None or variate_lags is None:
862+
if variate is None:
863863
continue
864864

865865
components = get_single_series(variate).components.tolist()
866+
# target labels
867+
if variate_type == "target":
868+
label_feature_names = [
869+
f"{name}_target_hrz{lag}"
870+
for lag in range(output_chunk_length)
871+
for name in components
872+
]
873+
874+
if variate_lags is None:
875+
continue
876+
866877
if isinstance(variate_lags, dict):
867878
if "default_lags" in variate_lags:
868879
raise_log(
@@ -894,13 +905,6 @@ def create_lagged_component_names(
894905
for name in components
895906
]
896907

897-
if variate_type == "target" and lags:
898-
label_feature_names = [
899-
f"{name}_target_hrz{lag}"
900-
for lag in range(output_chunk_length)
901-
for name in components
902-
]
903-
904908
# static covariates
905909
if use_static_covariates:
906910
static_covs = get_single_series(target_series).static_covariates

0 commit comments

Comments
 (0)