Feature/ Skip resampling in TFT to provide speed-up (#2898)
* Allow skipped resampling in TFT for faster inference
Resampling in TFT's VariableSelectionNetwork adds training overhead
because of the slow `interpolate()` implementation in PyTorch. I've
added an option `skip_resampling` to skip these operations in TFT;
accuracy is largely unaffected.
- `skip_resampling` defaults to `False`, in which case TFT retains the
old behaviour of applying interpolation to feature embeddings.
- When set to `True`, all interpolation operations are skipped (in
`_GatedResidualNetwork`) or replaced by projection (`_ResampleNorm`).
- Several typing errors in TFT are fixed.
* Add back `forward()` to `_ResampleNorm`
* Fix a TFT kernel error on MPS device
* Add TFT tests for MPS devices and `skip_resampling` option
* Update CHANGELOG for TFT skip resampling & MPS bug
* Fix `test_on_mps` w/ `tfm_kwargs` deep copy
The previous shallow copy of `tfm_kwargs` modified `"pl_trainer_kwargs"`
for other tests and caused many test failures. The MPS test now modifies
a deep copy instead.
* Remove TFT test on MPS devices
MPS memory is not available on GitHub. TFT test on MPS is removed.
* Fix a bug in `_VariableSelectionNetwork`
* Expand TFT static covariate test w/ `skip_resampling`
Test static covariate support with and without `skip_resampling`
* Update CHANGELOG.md
Co-authored-by: Dennis Bader <[email protected]>
* Update CHANGELOG.md
* Replace interpolation with linear projection
- Rename `skip_resampling` option to `skip_interpolation`.
- When set to `True`, every interpolation step that resamples feature
embeddings is replaced by a linear projection.
* Update CHANGELOG for renamed `skip_interpolation` option
---------
Co-authored-by: Dennis Bader <[email protected]>
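
To illustrate the difference between the two resampling strategies described above, here is a minimal sketch in plain Python. It is not the Darts implementation: `interpolate_1d` and `linear_projection` are hypothetical helpers, the weights are random rather than learned, and the endpoint-aligned interpolation is an assumption. In PyTorch the two steps would roughly correspond to `torch.nn.functional.interpolate` and `torch.nn.Linear`.

```python
import random


def interpolate_1d(values, out_size):
    """Linearly resample `values` to `out_size` points with endpoints aligned."""
    n = len(values)
    if out_size == 1 or n == 1:
        return [values[0]] * out_size
    out = []
    for i in range(out_size):
        pos = i * (n - 1) / (out_size - 1)  # fractional index into `values`
        lo = int(pos)
        hi = min(lo + 1, n - 1)
        frac = pos - lo
        out.append(values[lo] * (1 - frac) + values[hi] * frac)
    return out


def linear_projection(values, weight, bias):
    """Map `values` to len(weight) outputs via y = W x + b."""
    return [
        sum(w * v for w, v in zip(row, values)) + b
        for row, b in zip(weight, bias)
    ]


# Hypothetical feature embedding of size 4, resized to size 7.
embedding = [0.0, 1.0, 2.0, 3.0]
target_size = 7

# Old behaviour (sketch): resize by interpolation, no parameters involved.
resampled = interpolate_1d(embedding, target_size)

# skip_interpolation=True (sketch): resize with a weight matrix instead.
# The weights here are random placeholders; in a model they would be learned.
rng = random.Random(0)
weight = [[rng.uniform(-0.5, 0.5) for _ in embedding] for _ in range(target_size)]
bias = [0.0] * target_size
projected = linear_projection(embedding, weight, bias)

print(resampled)
print(len(projected))
```

Both paths change the embedding size from 4 to 7. The projection adds parameters but reduces to a single matrix multiplication, which is presumably where the reported speed-up comes from.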
CHANGELOG.md (2 additions & 0 deletions)
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**
+ - Added hyperparameter `skip_interpolation` to `TFTModel` that will replace 1D interpolation on feature embeddings with linear projection. When `True`, it can greatly increase training and inference efficiency while predictive accuracy remains largely unaffected. [#2898](https://github.com/unit8co/darts/pull/2898) by [Zhihao Dai](https://github.com/daidahao).
- Added mixed precision and 16-bit precision support to `TorchForecastingModel`. Simply specify `{"precision": "bf16-mixed" }` for `pl_trainer_kwargs` to enable mixed precision training. Alternatively, declare a custom `pytorch_lightning.Trainer` with a `"precision"` parameter and pass the trainer to `fit()`. Other precision options such as `"64-true"` and `"16-mixed"` supported by `pytorch_lightning` are also allowed. [#2883](https://github.com/unit8co/darts/pull/2883) by [Zhihao Dai](https://github.com/daidahao).
- 🔴 Added future and static covariates support to `BlockRNNModel`. This improvement required changes to the underlying model architecture which means that saved model instances from older Darts versions cannot be loaded any longer. [#2845](https://github.com/unit8co/darts/pull/2845) by [Gabriel Margaria](https://github.com/Jaco-Pastorius).
- `from_group_dataframe()` now supports creating `TimeSeries` from **additional DataFrame backends** (Polars, PyArrow, ...). We leverage `narwhals` as the compatibility layer between DataFrame libraries. See their [documentation](https://narwhals-dev.github.io/narwhals/) for all supported backends. [#2766](https://github.com/unit8co/darts/pull/2766) by [He Weilin](https://github.com/cnhwl).
@@ -23,6 +24,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Fixed**
+ - Fixed a bug causing crashes when running `TFTModel` on MPS devices (macOS with GPUs). [#2898](https://github.com/unit8co/darts/pull/2898) by [Zhihao Dai](https://github.com/daidahao).
- Fixed a bug when saving a `GlobalNaiveModel` directly after fitting it (without performing prediction). [#2895](https://github.com/unit8co/darts/pull/2895), by [Alain Gysi](https://github.com/Kurokabe)
- Fixed a bug when using an `EnsembleModel` with `train_forecasting_models=False` and at least one torch model in `forecasting_models`, where calling `historical_forecasts()` with `retrain=True` raised an exception due to the torch models being unintentionally reset. [#2894](https://github.com/unit8co/darts/pull/2894) by [Dennis Bader](https://github.com/dennisbader).