@@ -646,6 +646,12 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
646646 logger = logger ,
647647 )
648648
649+ @abstractmethod
650+ def _update_covariates_use (self ):
651+ """Based on the Forecasting class and the training_sample attribute, update the
652+ uses_[past/future/static]_covariates attributes."""
653+ pass
654+
649655 def to_onnx (self , path : Optional [str ] = None , ** kwargs ):
650656 """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
651657 :func:`torch.onnx.export` method (`official documentation <https://lightning.ai/docs/pytorch/
@@ -677,6 +683,8 @@ def to_onnx(self, path: Optional[str] = None, **kwargs):
677683 ``input_sample``, ``input_name``). For more information, read the `official documentation
678684 <https://pytorch.org/docs/master/onnx.html#torch.onnx.export>`_.
679685 """
686+ # TODO: LSTM model should be exported with a batch size of 1
687+ # TODO: predictions with TFT and TCN models is incorrect, might be caused by helper function to process inputs
680688 if not self ._fit_called :
681689 raise_log (
682690 ValueError ("`fit()` needs to be called before `to_onnx()`." ), logger
@@ -2133,6 +2141,9 @@ def load_weights_from_checkpoint(
21332141 self .model .load_state_dict (ckpt ["state_dict" ], strict = strict )
21342142 # update the fit_called attribute to allow for direct inference
21352143 self ._fit_called = True
2144+ # based on the shape of train_sample, figure out which covariates are used by the model
2145+ # (usually set in the Darts model prior to fitting it)
2146+ self ._update_covariates_use ()
21362147
21372148 def load_weights (
21382149 self , path : str , load_encoders : bool = True , skip_checks : bool = False , ** kwargs
@@ -2683,6 +2694,13 @@ def extreme_lags(
26832694 None ,
26842695 )
26852696
2697+ def _update_covariates_use (self ):
2698+ """The model is expected to rely on the `PastCovariatesTrainingDataset`"""
2699+ _ , past_covs , static_covs , _ = self .train_sample
2700+ self ._uses_past_covariates = past_covs is not None
2701+ self ._uses_future_covariates = False
2702+ self ._uses_static_covariates = static_covs is not None
2703+
26862704
26872705class FutureCovariatesTorchModel (TorchForecastingModel , ABC ):
26882706 supports_past_covariates = False
@@ -2776,6 +2794,13 @@ def extreme_lags(
27762794 None ,
27772795 )
27782796
2797+ def _update_covariates_use (self ):
2798+ """The model is expected to rely on the `FutureCovariatesTrainingDataset`"""
2799+ _ , future_covs , static_covs , _ = self .train_sample
2800+ self ._uses_past_covariates = False
2801+ self ._uses_future_covariates = future_covs is not None
2802+ self ._uses_static_covariates = static_covs is not None
2803+
27792804
27802805class DualCovariatesTorchModel (TorchForecastingModel , ABC ):
27812806 supports_past_covariates = False
@@ -2870,6 +2895,15 @@ def extreme_lags(
28702895 None ,
28712896 )
28722897
2898+ def _update_covariates_use (self ):
2899+ """The model is expected to rely on the `DualCovariatesTrainingDataset`"""
2900+ _ , historic_future_covs , future_covs , static_covs , _ = self .train_sample
2901+ self ._uses_past_covariates = False
2902+ self ._uses_future_covariates = (
2903+ historic_future_covs is not None or future_covs is not None
2904+ )
2905+ self ._uses_static_covariates = static_covs is not None
2906+
28732907
28742908class MixedCovariatesTorchModel (TorchForecastingModel , ABC ):
28752909 def _build_train_dataset (
@@ -2964,6 +2998,17 @@ def extreme_lags(
29642998 None ,
29652999 )
29663000
3001+ def _update_covariates_use (self ):
3002+ """The model is expected to rely on the `MixedCovariatesTrainingDataset`"""
3003+ _ , past_covs , historic_future_covs , future_covs , static_covs , _ = (
3004+ self .train_sample
3005+ )
3006+ self ._uses_past_covariates = past_covs is not None
3007+ self ._uses_future_covariates = (
3008+ historic_future_covs is not None or future_covs is not None
3009+ )
3010+ self ._uses_static_covariates = static_covs is not None
3011+
29673012
29683013class SplitCovariatesTorchModel (TorchForecastingModel , ABC ):
29693014 def _build_train_dataset (
@@ -3058,3 +3103,12 @@ def extreme_lags(
30583103 self .output_chunk_shift ,
30593104 None ,
30603105 )
3106+
3107+ def _update_covariates_use (self ):
3108+ """The model is expected to rely on the `SplitCovariatesTrainingDataset`"""
3109+ _ , past_covs , historic_future_covs , future_covs , static_covs , _ = (
3110+ self .train_sample
3111+ )
3112+ self ._uses_past_covariates = past_covs is not None
3113+ self ._uses_future_covariates = future_covs is not None
3114+ self ._uses_static_covariates = static_covs is not None
0 commit comments