Skip to content

Commit 5c91f34

Browse files
abdulfatirshchur
andauthored
Handle DatasetDict in clean_and_validate_predictions (#22)
Co-authored-by: Oleksandr Shchur <[email protected]>
1 parent 6eb2a51 commit 5c91f34

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

src/fev/task.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -599,17 +599,16 @@ def _to_dataset(preds: datasets.Dataset | list[dict]) -> datasets.Dataset:
599599
raise ValueError(f"predictions must be of type `datasets.Dataset` (received {type(preds)})")
600600
return preds
601601

602-
if self.is_multivariate:
603-
if isinstance(predictions, datasets.DatasetDict):
604-
pass
605-
elif isinstance(predictions, dict):
606-
predictions = datasets.DatasetDict({col: _to_dataset(preds) for col, preds in predictions.items()})
602+
if not isinstance(predictions, datasets.DatasetDict):
603+
if self.is_multivariate:
604+
if isinstance(predictions, dict):
605+
predictions = datasets.DatasetDict({col: _to_dataset(preds) for col, preds in predictions.items()})
606+
else:
607+
raise ValueError(
608+
f"predictions for multivariate tasks must be of type `datasets.DatasetDict` or `dict` (received {type(predictions)})"
609+
)
607610
else:
608-
raise ValueError(
609-
f"predictions for multivariate tasks must be of type `datasets.DatasetDict` or `dict` (received {type(predictions)})"
610-
)
611-
else:
612-
predictions = datasets.DatasetDict({self.target_column: _to_dataset(predictions)})
611+
predictions = datasets.DatasetDict({self.target_column: _to_dataset(predictions)})
613612

614613
predictions = predictions.cast(self.predictions_schema).with_format("numpy")
615614
for target_column, predictions_for_column in predictions.items():

test/test_task.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,31 @@ def test_when_multivariate_task_is_created_then_data_contains_correct_columns(ta
145145
assert set(future_data.column_names) == set(all_column_names) - set(target_column)
146146

147147

148+
@pytest.mark.parametrize("return_list", [True, False])
149+
def test_when_predictions_provided_as_dataset_dict_for_univariate_task_then_predictions_can_be_scores(return_list):
150+
def naive_forecast_univariate(task: fev.Task) -> list[dict]:
151+
past_data, future_data = task.get_input_data()
152+
predictions = []
153+
for ts in past_data:
154+
predictions.append({"predictions": [ts[task.target_column][-1] for _ in range(task.horizon)]})
155+
if return_list:
156+
return predictions
157+
else:
158+
return datasets.DatasetDict({task.target_column: datasets.Dataset.from_list(predictions)})
159+
160+
task = fev.Task(
161+
dataset_path="autogluon/chronos_datasets",
162+
dataset_config="monash_m1_yearly",
163+
eval_metric="MASE",
164+
extra_metrics=["WAPE"],
165+
horizon=4,
166+
)
167+
predictions = naive_forecast_univariate(task)
168+
summary = task.evaluation_summary(predictions, model_name="naive")
169+
for metric in ["MASE", "WAPE"]:
170+
assert np.isfinite(summary[metric])
171+
172+
148173
@pytest.mark.parametrize("target_column", [["OT"], ["OT", "LULL", "HULL"]])
149174
@pytest.mark.parametrize("return_dict", [True, False])
150175
def test_when_multivariate_task_is_used_then_predictions_can_be_scored(target_column, return_dict):

0 commit comments

Comments
 (0)