
Commit e2d6686

Feature/ Skip resampling in TFT to provide speed-up (#2898)
* Allow skipped resampling in TFT for faster inference

  Resampling in TFT's VariableSelectionNetwork introduced training overhead due to the slow `interpolate()` implementation in PyTorch. I've added an option `skip_resampling` to skip such operations in TFT, while accuracy is largely unaffected.

  - `skip_resampling` defaults to `False`, so TFT retains the old behaviour of applying interpolation to feature embeddings.
  - When set to `True`, all interpolation operations are skipped (in `_GatedResidualNetwork`) or replaced by projection (`_ResampleNorm`).
  - Quite a few typing errors in TFT are also fixed.

* Add back `forward()` to `_ResampleNorm`

* Fix a TFT kernel error on MPS devices

* Add TFT tests for MPS devices and the `skip_resampling` option

* Update CHANGELOG for TFT skip resampling & MPS bug

* Fix `test_on_mps` with a deep copy of `tfm_kwargs`

  The previous shallow copy of `tfm_kwargs` modified `"pl_trainer_kwargs"` for other tests and led to many test failures. We now modify a deep copy in the MPS test to fix this.

* Remove TFT test on MPS devices

  MPS memory is not available on GitHub runners, so the TFT test on MPS is removed.

* Fix a bug in `_VariableSelectionNetwork`

* Expand TFT static covariate test with `skip_resampling`

  Test static covariate support with and without `skip_resampling`.

* Update CHANGELOG.md

  Co-authored-by: Dennis Bader <[email protected]>

* Update CHANGELOG.md

* Replace interpolation with linear projection

  - Rename the `skip_resampling` option to `skip_interpolation`.
  - When set to `True`, all interpolation is replaced by linear projection during the feature embedding sampling operations.

* Update CHANGELOG for the renamed `skip_interpolation` option

---------

Co-authored-by: Dennis Bader <[email protected]>
1 parent 9ba3258 · commit e2d6686
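For context, here is a minimal usage sketch of the new option. Only `skip_interpolation` comes from this PR; the dataset and the remaining hyperparameters are illustrative:

```python
from darts.datasets import AirPassengersDataset
from darts.models import TFTModel

series = AirPassengersDataset().load()

# `skip_interpolation=True` replaces the 1D interpolation on feature
# embeddings with a learned linear projection for faster training/inference.
model = TFTModel(
    input_chunk_length=24,
    output_chunk_length=12,
    add_relative_index=True,  # lets TFT run without explicit future covariates
    skip_interpolation=True,
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=12)
```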

File tree: 4 files changed (+110 −49 lines)

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 
 **Improved**
 
+- Added hyperparameter `skip_interpolation` to `TFTModel` that will replace 1D interpolation on feature embeddings with linear projection. When `True`, it can greatly increase training and inference efficiency while predictive accuracy remains largely unaffected. [#2898](https://github.com/unit8co/darts/pull/2898) by [Zhihao Dai](https://github.com/daidahao).
 - Added mixed precision and 16-bit precision support to `TorchForecastingModel`. Simply specify `{"precision": "bf16-mixed" }` for `pl_trainer_kwargs` to enable mixed precision training. Alternatively, declare a custom `pytorch_lightning.Trainer` with a `"precision"` parameter and pass the trainer to `fit()`. Other precision options such as `"64-true"` and `"16-mixed"` supported by `pytorch_lightning` are also allowed. [#2883](https://github.com/unit8co/darts/pull/2883) by [Zhihao Dai](https://github.com/daidahao).
 - 🔴 Added future and static covariates support to `BlockRNNModel`. This improvement required changes to the underlying model architecture which means that saved model instances from older Darts versions cannot be loaded any longer. [#2845](https://github.com/unit8co/darts/pull/2845) by [Gabriel Margaria](https://github.com/Jaco-Pastorius).
 - `from_group_dataframe()` now supports creating `TimeSeries` from **additional DataFrame backends** (Polars, PyArrow, ...). We leverage `narwhals` as the compatibility layer between DataFrame libraries. See their [documentation](https://narwhals-dev.github.io/narwhals/) for all supported backends. [#2766](https://github.com/unit8co/darts/pull/2766) by [He Weilin](https://github.com/cnhwl).
```
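The mixed-precision entry above amounts to a one-line change on any `TorchForecastingModel`; a minimal sketch (model choice and hyperparameters are illustrative, and bfloat16 support depends on the hardware):

```python
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series = AirPassengersDataset().load().astype("float32")

# Mixed precision is enabled purely through the PyTorch Lightning trainer kwargs.
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    pl_trainer_kwargs={"precision": "bf16-mixed"},
    n_epochs=5,
)
model.fit(series)
```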
```diff
@@ -23,6 +24,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 
 **Fixed**
 
+- Fixed a bug causing crashes when running `TFTModel` on MPS devices (macOS with GPUs). [#2898](https://github.com/unit8co/darts/pull/2898) by [Zhihao Dai](https://github.com/daidahao).
 - Fixed a bug when saving a `GlobalNaiveModel` directly after fitting it (without performing prediction). [#2895](https://github.com/unit8co/darts/pull/2895), by [Alain Gysi](https://github.com/Kurokabe)
 - Fixed a bug when using an `EnsembleModel` with `train_forecasting_models=False` and at least one torch model in `forecasting_models`, where calling `historical_forecasts()` with `retrain=True` raised an exception due to the torch models being unintentionally reset. [#2894](https://github.com/unit8co/darts/pull/2894) by [Dennis Bader](https://github.com/dennisbader).
 
```
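And a sketch of the new multi-backend `from_group_dataframe()` mentioned under **Improved**, here with a Polars frame (column names and data are illustrative):

```python
from datetime import date

import polars as pl
from darts import TimeSeries

# A long-format Polars DataFrame holding one daily series per store.
df = pl.DataFrame({
    "store": ["A", "A", "A", "B", "B", "B"],
    "date": [date(2024, 1, 1), date(2024, 1, 2), date(2024, 1, 3)] * 2,
    "sales": [10.0, 11.0, 12.0, 20.0, 21.0, 22.0],
})

# narwhals translates between DataFrame backends, so no conversion to pandas is needed.
series_list = TimeSeries.from_group_dataframe(
    df, group_cols="store", time_col="date", value_cols="sales"
)
assert len(series_list) == 2  # one TimeSeries per store
```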

darts/models/forecasting/tft_model.py

Lines changed: 28 additions & 15 deletions
```diff
@@ -42,7 +42,7 @@ def __init__(
         output_dim: tuple[int, int],
         variables_meta: dict[str, dict[str, list[str]]],
         num_static_components: int,
-        hidden_size: Union[int, list[int]],
+        hidden_size: int,
         lstm_layers: int,
         num_attention_heads: int,
         full_attention: bool,
@@ -51,7 +51,8 @@ def __init__(
         categorical_embedding_sizes: dict[str, tuple[int, int]],
         dropout: float,
         add_relative_index: bool,
-        norm_type: Union[str, nn.Module],
+        norm_type: Union[str, type[nn.Module]],
+        skip_interpolation: bool = False,
         **kwargs,
     ):
         """PyTorch module implementing the TFT architecture from `this paper <https://arxiv.org/pdf/1912.09363.pdf>`_
@@ -98,8 +99,12 @@ def __init__(
         likelihood
             The likelihood model to be used for probabilistic forecasts. By default, the TFT uses
             a ``QuantileRegression`` likelihood.
-        norm_type: str | nn.Module
+        norm_type: str | type[nn.Module]
             The type of LayerNorm variant to use.
+        skip_interpolation: bool
+            Whether to skip interpolation and replace with linear projection on feature embeddings in
+            VariableSelectionNetwork. Setting this to `True` could increase training and inference speed.
+            Defaults to `False` to preserve the permutation in the feature embedding space.
         **kwargs
             all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
             base class.
@@ -119,6 +124,7 @@ def __init__(
         self.feed_forward = feed_forward
         self.dropout = dropout
         self.add_relative_index = add_relative_index
+        self.skip_interpolation = skip_interpolation
 
         if isinstance(norm_type, str):
             try:
@@ -182,6 +188,7 @@ def __init__(
             single_variable_grns={},
             context_size=None,  # no context for static variables
             layer_norm=self.layer_norm,
+            skip_interpolation=self.skip_interpolation,
         )
 
         # variable selection for encoder and decoder
@@ -202,6 +209,7 @@ def __init__(
             prescalers=self.prescalers_linear,
             single_variable_grns={},
             layer_norm=self.layer_norm,
+            skip_interpolation=self.skip_interpolation,
         )
 
         self.decoder_vsn = _VariableSelectionNetwork(
@@ -213,6+221,7 @@ def __init__(
             prescalers=self.prescalers_linear,
             single_variable_grns={},
             layer_norm=self.layer_norm,
+            skip_interpolation=self.skip_interpolation,
         )
 
         # static encoders
@@ -368,11 +377,11 @@ def decoder_variables(self) -> list[str]:
         return self.variables_meta["model_config"]["time_varying_decoder_input"]
 
     @staticmethod
-    def expand_static_context(context: torch.Tensor, time_steps: int) -> torch.Tensor:
+    def expand_static_context(context: torch.Tensor) -> torch.Tensor:
         """
         add time dimension to static context
         """
-        return context[:, None].expand(-1, time_steps, -1)
+        return context.unsqueeze(1).contiguous()
 
     @staticmethod
     def get_relative_index(
@@ -409,7 +418,7 @@ def get_attention_mask_future(
         encoder_length: int,
         decoder_length: int,
         batch_size: int,
-        device: str,
+        device: torch.device,
         full_attention: bool,
     ) -> torch.Tensor:
         """
@@ -466,7 +475,6 @@ def forward(self, x_in: PLModuleInput) -> torch.Tensor:
         batch_size = x_cont_past.shape[dim_samples]
         encoder_length = self.input_chunk_length
         decoder_length = self.output_chunk_length
-        time_steps = encoder_length + decoder_length
 
         # avoid unnecessary regeneration of attention mask
         if batch_size != self.batch_size_last:
@@ -549,23 +557,23 @@ def forward(self, x_in: PLModuleInput) -> torch.Tensor:
             static_covariate_var = None
 
         static_context_expanded = self.expand_static_context(
-            context=self.static_context_grn(static_embedding), time_steps=time_steps
+            self.static_context_grn(static_embedding)
         )
 
         embeddings_varying_encoder = {
             name: input_vectors_past[name] for name in self.encoder_variables
         }
         embeddings_varying_encoder, encoder_sparse_weights = self.encoder_vsn(
             x=embeddings_varying_encoder,
-            context=static_context_expanded[:, :encoder_length],
+            context=static_context_expanded,
         )
 
         embeddings_varying_decoder = {
             name: input_vectors_future[name] for name in self.decoder_variables
         }
         embeddings_varying_decoder, decoder_sparse_weights = self.decoder_vsn(
             x=embeddings_varying_decoder,
-            context=static_context_expanded[:, encoder_length:],
+            context=static_context_expanded,
         )
 
         # LSTM
@@ -603,9 +611,7 @@ def forward(self, x_in: PLModuleInput) -> torch.Tensor:
         static_context_enriched = self.static_context_enrichment(static_embedding)
         attn_input = self.static_enrichment_grn(
             x=lstm_out,
-            context=self.expand_static_context(
-                context=static_context_enriched, time_steps=time_steps
-            ),
+            context=self.expand_static_context(static_context_enriched),
         )
 
         # multi-head attention
```
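The hunks above drop both the explicit `expand(-1, time_steps, -1)` and the encoder/decoder slicing because a `(batch, 1, hidden)` context broadcasts over the time dimension wherever it is combined with a `(batch, time, hidden)` tensor. A standalone sketch of the equivalence (plain addition stands in for whatever combination the GRNs actually perform):

```python
import torch

batch, time_steps, hidden = 4, 30, 16
context = torch.randn(batch, hidden)
x = torch.randn(batch, time_steps, hidden)

# Old behaviour: materialize the static context at every time step.
old = context[:, None].expand(-1, time_steps, -1)

# New behaviour: keep a singleton time dimension; broadcasting does the rest.
new = context.unsqueeze(1).contiguous()  # shape (batch, 1, hidden)

# Combining with a time-major tensor yields identical results either way.
assert torch.allclose(x + old, x + new)
```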
```diff
@@ -660,6 +666,7 @@ def __init__(
             dict[str, Union[int, tuple[int, int]]]
         ] = None,
         add_relative_index: bool = False,
+        skip_interpolation: bool = False,
         loss_fn: Optional[nn.Module] = None,
         likelihood: Optional[TorchLikelihood] = None,
         norm_type: Union[str, nn.Module] = "LayerNorm",
@@ -742,6 +749,10 @@ def __init__(
             This allows to use the TFTModel without having to pass future_covariates to :func:`fit()` and
             :func:`train()`. It gives a value to the position of each step from input and output chunk relative
             to the prediction point. The values are normalized with ``input_chunk_length``.
+        skip_interpolation
+            Whether to skip interpolation and replace with linear projection on feature embeddings in
+            VariableSelectionNetwork. Setting this to ``True`` could increase training and inference speed.
+            Defaults to ``False`` to preserve the permutation in the feature embedding space.
         loss_fn: nn.Module
             PyTorch loss function used for training. By default, the TFT model is probabilistic and uses a
             ``likelihood`` instead (``QuantileRegression``). To make the model deterministic, you can set the `
@@ -949,11 +960,12 @@ def encode_year(idx):
             else {}
         )
         self.add_relative_index = add_relative_index
+        self.skip_interpolation = skip_interpolation
         self.output_dim: Optional[tuple[int, int]] = None
         self.norm_type = norm_type
         self._considers_static_covariates = use_static_covariates
 
-    def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module:
+    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
         """
         `train_sample` contains the following tensors:
         (past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates,
@@ -1140,6 +1152,7 @@ def _create_model(self, train_sample: TorchTrainingSample) -> nn.Module:
             hidden_continuous_size=self.hidden_continuous_size,
             categorical_embedding_sizes=self.categorical_embedding_sizes,
             add_relative_index=self.add_relative_index,
+            skip_interpolation=self.skip_interpolation,
             norm_type=self.norm_type,
             **self.pl_module_params,
         )
@@ -1149,7 +1162,7 @@ def _build_train_dataset(
         series: Sequence[TimeSeries],
         past_covariates: Optional[Sequence[TimeSeries]],
         future_covariates: Optional[Sequence[TimeSeries]],
-        sample_weight: Optional[Sequence[TimeSeries]],
+        sample_weight: Optional[Union[Sequence[TimeSeries], str]],
         max_samples_per_ts: Optional[int],
         stride: int = 1,
     ) -> TorchTrainingDataset:
```
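Finally, the core idea of the PR, resizing feature embeddings with a learned projection instead of 1D interpolation, in isolation. This is an illustration of the technique only; the actual `_ResampleNorm`/`_GatedResidualNetwork` changes live in a file not shown in this diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, h_in, h_out = 4, 7, 16
x = torch.randn(batch, h_in)

# Before: resize the feature dimension via 1D linear interpolation,
# which is slow in PyTorch and the source of the reported overhead.
resampled = F.interpolate(
    x.unsqueeze(1), size=h_out, mode="linear", align_corners=True
).squeeze(1)

# After: a learned linear projection performs the resizing in one matmul.
proj = nn.Linear(h_in, h_out)
projected = proj(x)

assert resampled.shape == projected.shape == (batch, h_out)
```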
