|
13 | 13 |
|
14 | 14 | if TYPE_CHECKING: |
15 | 15 | import autogluon.timeseries |
| 16 | + import darts |
16 | 17 | import gluonts.dataset.pandas |
17 | 18 |
|
18 | 19 |
|
@@ -291,7 +292,75 @@ def convert_input_data( |
291 | 292 |
|
292 | 293 |
|
class DartsAdapter(DatasetAdapter):
    """Adapter that converts past/future datasets into lists of darts ``TimeSeries``."""

    @classmethod
    def convert_input_data(
        cls,
        past: datasets.Dataset,
        future: datasets.Dataset,
        *,
        target_column: str | list[str],
        id_column: str,
        timestamp_column: str,
        static_columns: list[str],
    ) -> tuple["list[darts.TimeSeries]", "list[darts.TimeSeries] | None", "list[darts.TimeSeries] | None"]:
        """Build darts ``TimeSeries`` objects for targets, past covariates and future covariates.

        Rows of ``past`` and ``future`` are paired by position: row ``i`` of
        ``future`` is used together with row ``i`` of ``past``.

        A column of ``past`` is treated as a covariate when it is not the id,
        timestamp, target or a static column and the dtype string of its
        sequence feature contains ``"int"``, ``"float"`` or ``"double"``.
        Such a column is a *future* covariate if ``future`` also has a column
        with that name, otherwise it is a *past* covariate.

        Args:
            past: Dataset with one row per series holding the historical data.
            future: Dataset with one row per series holding the future data.
            target_column: Name(s) of the target column(s); a single string
                is wrapped into a one-element list.
            id_column: Name of the id column. It is only used here to exclude
                that column from the covariates.
            timestamp_column: Name of the column holding the timestamps.
            static_columns: Names of columns attached to each target series
                as static covariates.

        Returns:
            A tuple ``(target_series, past_covariates, future_covariates)``.
            ``target_series`` has one entry per row of ``past``. Each of the
            covariate entries is ``None`` when no column of that kind was
            found, otherwise a list with one series per row. Future covariate
            series cover the past timestamps followed by the future ones.
            All values are cast to ``float32``.

        Raises:
            ModuleNotFoundError: If ``darts`` is not installed.
            TypeError: If a candidate covariate column is not a
                ``datasets.Sequence`` feature.
        """
        try:
            from darts import TimeSeries
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(f"Please install darts before using {cls.__name__}") from err

        if isinstance(target_column, str):
            target_column = [target_column]

        # Loop invariants: build the lookup sets once instead of per column.
        reserved_columns = {id_column, timestamp_column, *target_column, *static_columns}
        future_column_names = set(future.column_names)
        numeric_markers = ("int", "float", "double")

        past_covariates_names = []
        future_covariates_names = []
        for col, feat in past.features.items():
            if col in reserved_columns:
                continue
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not isinstance(feat, datasets.Sequence):
                raise TypeError(
                    f"Column {col!r} must be a datasets.Sequence feature, got {type(feat).__name__}"
                )
            # Only include numeric dtypes for past/future covariates
            if not any(marker in feat.feature.dtype for marker in numeric_markers):
                continue
            if col in future_column_names:
                future_covariates_names.append(col)
            else:
                past_covariates_names.append(col)

        target_series = []
        past_covariates = []
        future_covariates = []
        for i in range(len(past)):
            past_i = past[i]
            future_i = future[i]
            past_times = pd.DatetimeIndex(past_i[timestamp_column])
            target_series.append(
                TimeSeries(
                    times=past_times,
                    values=np.stack([past_i[col] for col in target_column], axis=1).astype("float32"),
                    static_covariates=pd.Series({col: past_i[col] for col in static_columns}),
                    components=target_column,
                ),
            )
            if past_covariates_names:
                past_covariates.append(
                    TimeSeries(
                        times=past_times,
                        # Cast the raw values (as for the other series) rather than
                        # converting the TimeSeries after it has been built.
                        values=np.stack(
                            [past_i[col] for col in past_covariates_names],
                            axis=1,
                        ).astype("float32"),
                        components=past_covariates_names,
                    ),
                )
            if future_covariates_names:
                # Future covariates must span the history plus the forecast horizon.
                future_covariates.append(
                    TimeSeries(
                        times=pd.DatetimeIndex(np.concatenate([past_i[timestamp_column], future_i[timestamp_column]])),
                        values=np.stack(
                            [np.concatenate([past_i[col], future_i[col]]) for col in future_covariates_names],
                            axis=1,
                        ).astype("float32"),
                        components=future_covariates_names,
                    )
                )

        # Signal "no covariates of this kind" with None rather than an empty list.
        if not past_covariates_names:
            past_covariates = None
        if not future_covariates_names:
            future_covariates = None
        return target_series, past_covariates, future_covariates
295 | 364 |
|
296 | 365 |
|
297 | 366 | DATASET_ADAPTERS: dict[str, Type[DatasetAdapter]] = { |
|
0 commit comments