
[BUG] Inference-mode attention mask in TFT causing RuntimeError #2915

@daidahao

Description

Describe the bug
A RuntimeError is thrown when a TFTModel is trained again after forecasting.

To Reproduce
Steps to reproduce the behavior, preferably as a code snippet:

import numpy as np

from darts import TimeSeries
from darts.models import TFTModel

# Create a series with 3 components
series = TimeSeries.from_values(np.random.rand(100, 3).astype(np.float32))

# Define a TFT model
model = TFTModel(10, 10, add_relative_index=True, n_epochs=1)

# Train the model
model.fit(series, verbose=True)

# Make a prediction with `num_samples` equal to the batch size
model.predict(10, num_samples=model.batch_size)

# Train the model again; this raises a `RuntimeError` because the `attention_mask` cached during inference was not cleared
model.fit(series, verbose=True)

A RuntimeError is thrown because attention_mask from inference is being reused for retraining.

RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
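
This matches the standard PyTorch behaviour around torch.inference_mode(): a tensor created in inference mode cannot later be saved for the backward pass. A minimal illustration in plain PyTorch (not darts code; the tensor names are placeholders):

import torch

# A tensor created under torch.inference_mode() is an "inference tensor".
with torch.inference_mode():
    mask = torch.ones(3, 3)  # stands in for the cached attention_mask

weights = torch.randn(3, 3, requires_grad=True)

# Reusing the inference tensor in an autograd-tracked operation raises
# "RuntimeError: Inference tensors cannot be saved for backward."
# because the multiplication must save `mask` for the backward pass.
out = (weights * mask).sum()

# Cloning the mask first, e.g. (weights * mask.clone()).sum(), avoids the error.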

Expected behavior
Retraining the model after a prediction should succeed without errors, just like the initial fit() call.

System (please complete the following information):

  • Python version: 3.13
  • darts version: latest from source

Additional context
The bug should be easy to fix by clearing the cached attention_mask in the TFT module before each training run.
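
A minimal sketch of that idea, assuming the mask is cached as an attribute on the model's Lightning module; the attribute name and the hook used here are assumptions for illustration, not the actual darts internals:

import pytorch_lightning as pl

class _TFTModuleSketch(pl.LightningModule):
    """Illustrative only: clear a cached mask before training starts."""

    def __init__(self):
        super().__init__()
        # Hypothetical cache that predict() would populate under inference mode.
        self.attention_mask = None

    def on_train_start(self) -> None:
        # Drop any mask built under torch.inference_mode() so the next training
        # forward pass recreates it as a normal tensor usable by autograd.
        self.attention_mask = None

Alternatively, cloning the cached mask before reusing it in training (as the error message itself suggests) would also avoid the problem.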

Labels

bug (Something isn't working), triage (Issue waiting for triaging)
