Description
Describe the bug
A RuntimeError can be thrown when a TFTModel is trained again after forecasting, for example when the prediction used `num_samples` equal to the batch size.
To Reproduce
Steps to reproduce the behavior, preferably code snippet.
import numpy as np
from darts import TimeSeries
from darts.models import TFTModel
# Create a series with 3 components
series = TimeSeries.from_values(np.random.rand(100, 3).astype(np.float32))
# Define a TFT model
model = TFTModel(10, 10, add_relative_index=True, n_epochs=1)
# Train the model
model.fit(series, verbose=True)
# Make a prediction with `num_samples` equal to the batch size
model.predict(10, num_samples=model.batch_size)
# Train the model again, causing a `RuntimeError` due to the stale `attention_mask`
model.fit(series, verbose=True)
A RuntimeError is thrown because the attention_mask created during inference is reused for retraining.
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
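The mechanism behind this error can be reproduced in plain PyTorch, independent of darts. The sketch below is only an illustration under assumptions: `CachedMaskLayer` and its `mask` attribute are hypothetical stand-ins for the TFT module and its `attention_mask`, and it assumes the prediction pass runs under `torch.inference_mode()`.

```python
import torch


class CachedMaskLayer(torch.nn.Module):
    """Hypothetical layer that caches a mask between forward calls."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4))
        self.mask = None  # cached across calls, similar to `attention_mask`

    def forward(self, x):
        # Rebuild the mask only when there is none or the input shape changed.
        if self.mask is None or self.mask.shape != x.shape:
            self.mask = torch.ones_like(x)
        return (x * self.weight * self.mask).sum()


layer = CachedMaskLayer()
x = torch.rand(4)

# Inference pass: the mask is created as an inference tensor and cached.
with torch.inference_mode():
    layer(x)

# Training pass with the same input shape reuses the cached mask,
# which autograd refuses to save for the backward pass.
try:
    layer(x).backward()
    error = None
except RuntimeError as exc:
    error = str(exc)

print(error)
```

Because the cache is only rebuilt when the shape changes, the failure depends on the inference and training batches having matching shapes, which would be consistent with the reproduction needing `num_samples` equal to the batch size.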
Expected behavior
The second call to `fit` should complete without raising an error, just like the first one.
System (please complete the following information):
- Python version: 3.13
- darts version: latest from source
Additional context
The bug should be easily fixable by clearing the attention_mask in TFT before each training.
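A minimal sketch of that idea in plain PyTorch follows. It is not the darts implementation: `CachedMaskLayer`, `mask`, and `clear_mask` are hypothetical names, and where exactly the reset should be called inside darts is left open here.

```python
import torch


class CachedMaskLayer(torch.nn.Module):
    """Hypothetical layer that caches a mask between forward calls."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4))
        self.mask = None  # cached across calls, similar to `attention_mask`

    def clear_mask(self):
        # Hypothetical helper: drop the cache so the next forward rebuilds it.
        self.mask = None

    def forward(self, x):
        if self.mask is None or self.mask.shape != x.shape:
            self.mask = torch.ones_like(x)
        return (x * self.weight * self.mask).sum()


layer = CachedMaskLayer()
x = torch.rand(4)

# Inference pass caches a mask that is an inference tensor.
with torch.inference_mode():
    layer(x)

# Clearing the cache before training forces a fresh, normal tensor.
layer.clear_mask()
loss = layer(x)
loss.backward()

print(layer.weight.grad)
```

Resetting the cache is cheap because the mask is rebuilt on the next forward call anyway. Cloning the cached tensor, as the error message suggests, would be an alternative workaround.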