Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For developers of the library:

- Reworked the `_validate_model_params` function of `TorchForecastingModel` to support more complicated cases of class inheritance. [#2908](https://github.com/unit8co/darts/pull/2908) by [Tim Rosenflanz](https://github.com/trosenflanz).

## [0.37.1](https://github.com/unit8co/darts/tree/0.37.1) (2025-08-18)

### For users of the library:
Expand Down
24 changes: 17 additions & 7 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,24 @@ def encode_year(idx):

@classmethod
def _validate_model_params(cls, **kwargs):
"""validate that parameters used at model creation are part of :class:`TorchForecastingModel`,
:class:`PLForecastingModule` or cls __init__ methods.
"""
valid_kwargs = (
set(inspect.signature(TorchForecastingModel.__init__).parameters.keys())
| set(inspect.signature(PLForecastingModule.__init__).parameters.keys())
| set(inspect.signature(cls.__init__).parameters.keys())
# initiate with PLForecastingModule params that isn't part of the base class
valid_kwargs = set(
inspect.signature(PLForecastingModule.__init__).parameters.keys()
)
# add params from the full list of base classes
for base in inspect.getmro(cls):
if base is object:
break
try:
sig = inspect.signature(base.__init__)
valid_kwargs.update(sig.parameters.keys())
except (ValueError, TypeError):
# In case a built-in class throws or __init__ is not introspectable
continue

# Remove 'self','args,'kwargs' from consideration
for generic_arg in ["self", "args", "kwargs"]:
valid_kwargs.discard(generic_arg)

invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_kwargs]

Expand Down