
Commit 61cc25e

Add multivariate <-> univariate conversion support (#27)
1 parent 1161caa · commit 61cc25e

5 files changed (+330, -103 lines changed)


src/fev/__about__.py (1 addition, 1 deletion)

@@ -1 +1 @@
-__version__ = "0.5.0"
+__version__ = "0.6.0b"
src/fev/adapters.py (183 additions, 65 deletions)
@@ -1,31 +1,59 @@
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Literal, Type
 
 import datasets
 import numpy as np
 import pandas as pd
 
+from . import utils
 from .task import Task
 
 if TYPE_CHECKING:
     import autogluon.timeseries
     import gluonts.dataset.pandas
 
 
-class DatasetAdapter:
+class DatasetAdapter(ABC):
     """Convert a time series dataset into format suitable for other frameworks."""
 
+    @classmethod
+    @abstractmethod
     def convert_input_data(
-        self,
+        cls,
         past: datasets.Dataset,
         future: datasets.Dataset,
-        task: Task,
+        *,
+        target_column: str | list[str],
+        id_column: str,
+        timestamp_column: str,
+        static_columns: list[str],
     ) -> Any:
-        raise NotImplementedError
+        """Convert the input data of the task into a format compatible with the framework."""
+        pass
+
+
+class DatasetsAdapter(DatasetAdapter):
+    """Keeps data formatted as datasets.Dataset objects."""
+
+    @classmethod
+    def convert_input_data(
+        cls,
+        past: datasets.Dataset,
+        future: datasets.Dataset,
+        *,
+        target_column: str | list[str],
+        id_column: str,
+        timestamp_column: str,
+        static_columns: list[str],
+    ) -> tuple[datasets.Dataset, datasets.Dataset]:
+        return past, future
 
 
 class PandasAdapter(DatasetAdapter):
+    """Converts data to pandas.DataFrame objects."""
+
     @staticmethod
     def _to_long_df(dataset: datasets.Dataset, id_column: str) -> pd.DataFrame:
         """Convert time series dataset into long DataFrame format.
@@ -44,28 +72,30 @@ def _to_long_df(dataset: datasets.Dataset, id_column: str) -> pd.DataFrame:
             df_dict[col] = np.concatenate(df[col])
         return pd.DataFrame(df_dict).astype({id_column: str})
 
+    @classmethod
     def convert_input_data(
-        self,
+        cls,
         past: datasets.Dataset,
         future: datasets.Dataset,
-        task: Task,
-    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
-        past_df = self._to_long_df(past.remove_columns(task.static_columns), id_column=task.id_column)
-        future_df = self._to_long_df(future.remove_columns(task.static_columns), id_column=task.id_column)
-        if len(task.static_columns) > 0:
-            static_df = past.select_columns([task.id_column] + task.static_columns).to_pandas()
+        *,
+        target_column: str | list[str],
+        id_column: str,
+        timestamp_column: str,
+        static_columns: list[str],
+    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
+        past_df = cls._to_long_df(past.remove_columns(static_columns), id_column=id_column)
+        future_df = cls._to_long_df(future.remove_columns(static_columns), id_column=id_column)
+        if len(static_columns) > 0:
+            static_df = past.select_columns([id_column] + static_columns).to_pandas()
             # Infer numeric dtypes if possible (e.g., object -> float), but make sure that id_column has str dtype
-            static_df = static_df.infer_objects().astype({task.id_column: str})
+            static_df = static_df.infer_objects().astype({id_column: str})
         else:
             static_df = None
         return past_df, future_df, static_df
 
 
 class GluonTSAdapter(PandasAdapter):
-    """Converts dataset to format required by GluonTS.
-
-    Optionally, this adapter can fill in missing values in the dynamic & static feature columns.
-    """
+    """Converts dataset to format required by GluonTS."""
 
     @staticmethod
     def _convert_dtypes(df: pd.DataFrame, float_dtype: str = "float32") -> pd.DataFrame:
@@ -78,48 +108,64 @@ def _convert_dtypes(df: pd.DataFrame, float_dtype: str = "float32") -> pd.DataFrame:
                 astype_dict[col] = float_dtype
         return df.astype(astype_dict)
 
+    @classmethod
     def convert_input_data(
-        self,
+        cls,
         past: datasets.Dataset,
         future: datasets.Dataset,
-        task: Task,
+        *,
+        target_column: str | list[str],
+        id_column: str,
+        timestamp_column: str,
+        static_columns: list[str],
     ) -> tuple["gluonts.dataset.pandas.PandasDataset", "gluonts.dataset.pandas.PandasDataset"]:
         try:
             from gluonts.dataset.pandas import PandasDataset
         except ModuleNotFoundError:
-            raise ModuleNotFoundError(f"Please install GluonTS before using {self.__class__.__name__}")
-        if task.is_multivariate:
-            raise ValueError(f"{self.__class__.__name__} currently does not support multivariate tasks.")
-        past_df, future_df, static_df = super().convert_input_data(past=past, future=future, task=task)
+            raise ModuleNotFoundError(f"Please install GluonTS before using {cls.__name__}")
+        assert isinstance(target_column, str), f"{cls.__name__} does not support multivariate tasks."
+
+        past_df, future_df, static_df = super().convert_input_data(
+            past=past,
+            future=future,
+            target_column=target_column,
+            id_column=id_column,
+            timestamp_column=timestamp_column,
+            static_columns=static_columns,
+        )
 
-        past_df = self._convert_dtypes(past_df)
-        future_df = self._convert_dtypes(future_df)
+        past_df = cls._convert_dtypes(past_df)
+        future_df = cls._convert_dtypes(future_df)
         if static_df is not None:
-            static_df = self._convert_dtypes(static_df.set_index(task.id_column))
+            static_df = cls._convert_dtypes(static_df.set_index(id_column))
+        else:
+            static_df = pd.DataFrame()
 
+        # GluonTS needs to know the data frequency, we infer it from the timestamps
+        freq = pd.infer_freq(np.concatenate([past[0][timestamp_column], future[0][timestamp_column]]))
         # GluonTS uses pd.Period, which requires frequencies like 'M' instead of 'ME'
-        gluonts_freq = pd.tseries.frequencies.get_period_alias(task.freq)
+        gluonts_freq = pd.tseries.frequencies.get_period_alias(freq)
         # We compute names of feature columns after non-numeric columns have been removed
-        feat_dynamic_real = list(future_df.columns.drop([task.id_column, task.timestamp_column]))
-        past_feat_dynamic_real = list(past_df.columns.drop(list(future_df.columns) + [task.target_column]))
+        feat_dynamic_real = list(future_df.columns.drop([id_column, timestamp_column]))
+        past_feat_dynamic_real = list(past_df.columns.drop(list(future_df.columns) + [target_column]))
         past_dataset = PandasDataset.from_long_dataframe(
             past_df,
-            item_id=task.id_column,
-            timestamp=task.timestamp_column,
-            target=task.target_column,
+            item_id=id_column,
+            timestamp=timestamp_column,
+            target=target_column,
             static_features=static_df,
             freq=gluonts_freq,
             feat_dynamic_real=feat_dynamic_real,
             past_feat_dynamic_real=past_feat_dynamic_real,
         )
         prediction_dataset = PandasDataset.from_long_dataframe(
             pd.concat([past_df, future_df]),
-            item_id=task.id_column,
-            timestamp=task.timestamp_column,
-            target=task.target_column,
+            item_id=id_column,
+            timestamp=timestamp_column,
+            target=target_column,
             static_features=static_df,
             freq=gluonts_freq,
-            future_length=task.horizon,
+            future_length=len(future[0][timestamp_column]),
             feat_dynamic_real=feat_dynamic_real,
             past_feat_dynamic_real=past_feat_dynamic_real,
         )
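Instead of reading `freq` and `horizon` from the Task, the GluonTS adapter now derives both from the data itself. A standalone sketch of that inference, using made-up daily timestamps rather than a real task:

import numpy as np
import pandas as pd

past_timestamps = pd.date_range("2024-01-01", periods=5, freq="D").to_numpy()
future_timestamps = pd.date_range("2024-01-06", periods=2, freq="D").to_numpy()

# Frequency is inferred from the first series' past + future timestamps.
freq = pd.infer_freq(np.concatenate([past_timestamps, future_timestamps]))  # "D"
# GluonTS works with pd.Period, which expects period aliases ("M" rather than "ME").
gluonts_freq = pd.tseries.frequencies.get_period_alias(freq)
# The prediction length is simply the number of future timestamps.
future_length = len(future_timestamps)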
@@ -143,30 +189,42 @@ class NixtlaAdapter(PandasAdapter):
     timestamp_column: str = "ds"
     target_column: str = "y"
 
+    @classmethod
     def convert_input_data(
-        self,
+        cls,
         past: datasets.Dataset,
         future: datasets.Dataset,
-        task: Task,
-    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
-        if task.is_multivariate:
-            raise ValueError(f"{self.__class__.__name__} currently does not support multivariate tasks.")
-        past_df, future_df, static_df = super().convert_input_data(past=past, future=future, task=task)
+        *,
+        target_column: str | list[str],
+        id_column: str,
+        timestamp_column: str,
+        static_columns: list[str],
+    ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
+        assert isinstance(target_column, str), f"{cls.__name__} does not support multivariate tasks."
+
+        past_df, future_df, static_df = super().convert_input_data(
+            past=past,
+            future=future,
+            target_column=target_column,
+            id_column=id_column,
+            timestamp_column=timestamp_column,
+            static_columns=static_columns,
+        )
         past_df = past_df.rename(
             columns={
-                task.id_column: self.id_column,
-                task.timestamp_column: self.timestamp_column,
-                task.target_column: self.target_column,
+                id_column: cls.id_column,
+                timestamp_column: cls.timestamp_column,
+                target_column: cls.target_column,
             }
         )
         future_df = future_df.rename(
             columns={
-                task.id_column: self.id_column,
-                task.timestamp_column: self.timestamp_column,
+                id_column: cls.id_column,
+                timestamp_column: cls.timestamp_column,
             }
         )
         if static_df is not None:
-            static_df = static_df.rename(columns={task.id_column: self.id_column})
+            static_df = static_df.rename(columns={id_column: cls.id_column})
 
         return past_df, future_df, static_df
 
@@ -176,38 +234,53 @@ class AutoGluonAdapter(PandasAdapter):
 
     Returns
     -------
-    past_df : autogluon.timeseries.TimeSeriesDataFrame
+    past_data : autogluon.timeseries.TimeSeriesDataFrame
         Dataframe containing the past values of the time series as well as all dynamic features.
 
-        If static features are present in the dataset, they are stored as `past_df.static_features`.
+        Target column is always renamed to "target".
+
+        If static features are present in the dataset, they are stored as `past_data.static_features`.
     known_covariates : autogluon.timeseries.TimeSeriesDataFrame
         Dataframe containing the future values of the dynamic features that are known in the future.
     """
 
+    target_column: str = "target"
+
+    @classmethod
     def convert_input_data(
-        self,
+        cls,
         past: datasets.Dataset,
         future: datasets.Dataset,
-        task: Task,
+        *,
+        target_column: str | list[str],
+        id_column: str,
+        timestamp_column: str,
+        static_columns: list[str],
     ) -> tuple["autogluon.timeseries.TimeSeriesDataFrame", "autogluon.timeseries.TimeSeriesDataFrame"]:
         try:
             from autogluon.timeseries import TimeSeriesDataFrame
         except ModuleNotFoundError:
-            raise ModuleNotFoundError(f"Please install AutoGluon before using {self.__class__.__name__}")
-        if task.is_multivariate:
-            raise ValueError(f"{self.__class__.__name__} currently does not support multivariate tasks.")
-
-        past_df, future_df, static_df = super().convert_input_data(past=past, future=future, task=task)
+            raise ModuleNotFoundError(f"Please install AutoGluon before using {cls.__name__}")
+        assert isinstance(target_column, str), f"{cls.__name__} does not support multivariate tasks."
+
+        past_df, future_df, static_df = super().convert_input_data(
+            past=past,
+            future=future,
+            target_column=target_column,
+            id_column=id_column,
+            timestamp_column=timestamp_column,
+            static_columns=static_columns,
+        )
         past_data = TimeSeriesDataFrame.from_data_frame(
-            past_df,
-            id_column=task.id_column,
-            timestamp_column=task.timestamp_column,
+            past_df.rename(columns={target_column: cls.target_column}),
+            id_column=id_column,
+            timestamp_column=timestamp_column,
             static_features_df=static_df,
         )
         known_covariates = TimeSeriesDataFrame.from_data_frame(
             future_df,
-            id_column=task.id_column,
-            timestamp_column=task.timestamp_column,
+            id_column=id_column,
+            timestamp_column=timestamp_column,
         )
         return past_data, known_covariates
 
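On the AutoGluon side, the target column is now always renamed to "target" before the TimeSeriesDataFrame is built. A minimal sketch of that step, assuming AutoGluon is installed and using illustrative column names ("item_id", "timestamp", "sales"):

import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame

past_df = pd.DataFrame(
    {
        "item_id": ["A", "A", "A"],
        "timestamp": pd.date_range("2024-01-01", periods=3, freq="D"),
        "sales": [1.0, 2.0, 3.0],
    }
)
# Rename the (assumed) target column "sales" to the fixed name "target", as the adapter does.
past_data = TimeSeriesDataFrame.from_data_frame(
    past_df.rename(columns={"sales": "target"}),
    id_column="item_id",
    timestamp_column="timestamp",
)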
@@ -218,6 +291,7 @@ class DartsAdapter(DatasetAdapter):
 
 DATASET_ADAPTERS: dict[str, Type[DatasetAdapter]] = {
     "pandas": PandasAdapter,
+    "datasets": DatasetsAdapter,
     "gluonts": GluonTSAdapter,
     "nixtla": NixtlaAdapter,
     "darts": DartsAdapter,
@@ -227,21 +301,65 @@ class DartsAdapter(DatasetAdapter):
 
 def convert_input_data(
     task: Task,
-    adapter: Literal["pandas", "gluonts", "nixtla", "darts", "autogluon"] = "pandas",
+    adapter: Literal["pandas", "datasets", "gluonts", "nixtla", "darts", "autogluon"] = "pandas",
+    *,
+    as_univariate: bool = False,
+    univariate_target_column: str = "target",
     **kwargs,
 ) -> Any:
     """Convert the output of `task.get_input_data()` to a format compatible with popular forecasting frameworks.
 
     Parameters
     ----------
-    task : fev.Task
+    task
         Task object for which input data must be converted.
-    adapter : {"pandas", "gluonts", "nixtla", "darts", "autogluon"}
+    adapter : {"pandas", "datasets", "gluonts", "nixtla", "darts", "autogluon"}
         Format to which the dataset must be converted.
+    as_univariate
+        If True, separate instances will be created from each target column before passing the data to the adapter.
+
+        Equivalent to setting `generate_univariate_targets_from = "__ALL__"` in `Task` constructor.
+    univariate_target_column
+        Target column name used when as_univariate=True. Only used by the "datasets" adapter.
     **kwargs
        Keyword arguments passed to :meth:`fev.Task.get_input_data`.
     """
     past, future = task.get_input_data(**kwargs)
+
+    if as_univariate:
+        if univariate_target_column in past.column_names and univariate_target_column != task.target_column:
+            raise ValueError(
+                f"Column '{univariate_target_column}' already exists. Choose a different univariate_target_column."
+            )
+        target_column = univariate_target_column
+        if task.is_multivariate:
+            past = utils.generate_univariate_targets_from_multivariate(
+                past,
+                id_column=task.id_column,
+                new_target_column=target_column,
+                generate_univariate_targets_from=task.target_columns_list,
+            )
+            # We cannot apply generate_univariate_targets_from_multivariate to future since it does not contain target cols,
+            # so we just repeat each entry and insert the IDs from past, repeating entries as [0, 0, ..., 1, 1, ..., N - 1, N - 1, ...]
+            original_column_order = future.column_names
+            future = future.select([i for i in range(len(future)) for _ in range(len(task.target_columns_list))])
+            future = future.remove_columns(task.id_column).add_column(name=task.id_column, column=past[task.id_column])
+            future = future.select_columns(original_column_order)
+        else:
+            if target_column not in past.column_names:
+                past = past.rename_column(task.target_column, target_column)
+    else:
+        target_column = task.target_column
+
     if adapter not in DATASET_ADAPTERS:
         raise KeyError(f"`adapter` must be one of {list(DATASET_ADAPTERS)}")
-    return DATASET_ADAPTERS[adapter]().convert_input_data(past=past, future=future, task=task)
+    adapter_cls = DATASET_ADAPTERS[adapter]
+
+    return adapter_cls().convert_input_data(
+        past=past,
+        future=future,
+        target_column=target_column,
+        id_column=task.id_column,
+        timestamp_column=task.timestamp_column,
+        static_columns=task.static_columns,
+    )
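The new `as_univariate` path splits `past` per target column via `utils.generate_univariate_targets_from_multivariate`, but `future` carries no target columns, so it is instead expanded by repeating each row once per target and copying the item ids from the expanded `past`. A self-contained sketch of that row-repetition trick on a toy datasets.Dataset (column names and the number of targets are illustrative):

import datasets

future = datasets.Dataset.from_dict(
    {"id": ["A", "B"], "timestamp": [["2024-01-03"], ["2024-01-03"]]}
)
n_targets = 2  # number of target columns in a hypothetical multivariate task

# Repeat each row n_targets times: indices [0, 0, 1, 1] for two rows and two targets.
expanded = future.select([i for i in range(len(future)) for _ in range(n_targets)])
print(expanded["id"])  # ['A', 'A', 'B', 'B']

From the caller's perspective, the whole feature reduces to passing the new flag, e.g. fev.convert_input_data(task, adapter="gluonts", as_univariate=True).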
