Skip to content

Commit 35bd3b4

Browse files
authored
Implement DartsAdapter (#30)
1 parent 38be741 commit 35bd3b4

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

src/fev/adapters.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
if TYPE_CHECKING:
1515
import autogluon.timeseries
16+
import darts
1617
import gluonts.dataset.pandas
1718

1819

@@ -291,7 +292,75 @@ def convert_input_data(
291292

292293

293294
class DartsAdapter(DatasetAdapter):
294-
pass
295+
@classmethod
296+
def convert_input_data(
297+
cls,
298+
past: datasets.Dataset,
299+
future: datasets.Dataset,
300+
*,
301+
target_column: str | list[str],
302+
id_column: str,
303+
timestamp_column: str,
304+
static_columns: list[str],
305+
) -> tuple["list[darts.TimeSeries]", "list[darts.TimeSeries] | None", "list[darts.TimeSeries] | None"]:
306+
try:
307+
from darts import TimeSeries
308+
except ModuleNotFoundError:
309+
raise ModuleNotFoundError(f"Please install darts before using {cls.__name__}")
310+
311+
if isinstance(target_column, str):
312+
target_column = [target_column]
313+
314+
past_covariates_names = []
315+
future_covariates_names = []
316+
for col, feat in past.features.items():
317+
if col not in [id_column, timestamp_column, *target_column, *static_columns]:
318+
assert isinstance(feat, datasets.Sequence)
319+
# Only include numeric dtypes for past/future covariates
320+
if any(t in feat.feature.dtype for t in ["int", "float", "double"]):
321+
if col in future.column_names:
322+
future_covariates_names.append(col)
323+
else:
324+
past_covariates_names.append(col)
325+
326+
target_series = []
327+
past_covariates = []
328+
future_covariates = []
329+
for i in range(len(past)):
330+
past_i = past[i]
331+
future_i = future[i]
332+
target_series.append(
333+
TimeSeries(
334+
times=pd.DatetimeIndex(past_i[timestamp_column]),
335+
values=np.stack([past_i[col] for col in target_column], axis=1).astype("float32"),
336+
static_covariates=pd.Series({col: past_i[col] for col in static_columns}),
337+
components=target_column,
338+
),
339+
)
340+
if len(past_covariates_names) > 0:
341+
past_covariates.append(
342+
TimeSeries(
343+
times=pd.DatetimeIndex(past_i[timestamp_column]),
344+
values=np.stack([past_i[col] for col in past_covariates_names], axis=1),
345+
components=past_covariates_names,
346+
).astype("float32"),
347+
)
348+
if len(future_covariates_names) > 0:
349+
future_covariates.append(
350+
TimeSeries(
351+
times=pd.DatetimeIndex(np.concatenate([past_i[timestamp_column], future_i[timestamp_column]])),
352+
values=np.stack(
353+
[np.concatenate([past_i[col], future_i[col]]) for col in future_covariates_names],
354+
axis=1,
355+
).astype("float32"),
356+
components=future_covariates_names,
357+
)
358+
)
359+
if len(past_covariates_names) == 0:
360+
past_covariates = None
361+
if len(future_covariates_names) == 0:
362+
future_covariates = None
363+
return target_series, past_covariates, future_covariates
295364

296365

297366
DATASET_ADAPTERS: dict[str, Type[DatasetAdapter]] = {

0 commit comments

Comments
 (0)