diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c168ca157..a87ed2eec9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co ### For developers of the library: + - Reworked the `_validate_model_params` function of `TorchForecastingModel` to support more complicated cases of class inheritance. [#2908](https://github.com/unit8co/darts/pull/2908) by [Tim Rosenflanz](https://github.com/trosenflanz). + ## [0.37.1](https://github.com/unit8co/darts/tree/0.37.1) (2025-08-18) ### For users of the library: diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 74c28d4162..e2de5bb876 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -363,14 +363,24 @@ def encode_year(idx): @classmethod def _validate_model_params(cls, **kwargs): - """validate that parameters used at model creation are part of :class:`TorchForecastingModel`, - :class:`PLForecastingModule` or cls __init__ methods. - """ - valid_kwargs = ( - set(inspect.signature(TorchForecastingModel.__init__).parameters.keys()) - | set(inspect.signature(PLForecastingModule.__init__).parameters.keys()) - | set(inspect.signature(cls.__init__).parameters.keys()) + # initiate with PLForecastingModule params that isn't part of the base class + valid_kwargs = set( + inspect.signature(PLForecastingModule.__init__).parameters.keys() ) + # add params from the full list of base classes + for base in inspect.getmro(cls): + if base is object: + break + try: + sig = inspect.signature(base.__init__) + valid_kwargs.update(sig.parameters.keys()) + except (ValueError, TypeError): + # In case a built-in class throws or __init__ is not introspectable + continue + + # Remove 'self','args,'kwargs' from consideration + for generic_arg in ["self", "args", "kwargs"]: + valid_kwargs.discard(generic_arg) invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_kwargs]