
Slicing predictions for use with the plot_prediction() method of the Temporal Fusion Transformer #1945

@nicoliKim

Description


Hello,

First of all, thank you for putting together this nice library.

I have a question about how to properly call the plot_prediction() method.
Assume I have a TemporalFusionTransformer model, which I refer to below as model.

To make predictions I do the following:

preds = model.predict(
            test_dataloader,
            mode='raw', 
            return_x=True,
            return_y=True, 
            mode_kwargs={"n_samples": n_samples}, 
            trainer_kwargs=dict(accelerator="mps")
            )
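
For context, since I pass return_x=True and return_y=True, my understanding is that the returned object bundles the raw network output together with the inputs and targets, which I access roughly like this:

raw_output = preds.output  # raw network output, contains .prediction among other fields
x = preds.x                # network inputs (return_x=True)
y = preds.y                # targets (return_y=True)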

Now using these predictions, I want to plot the results.
Note that if I do

type(preds.output.prediction)

this returns a list of size 3, where each element of the list is a different variable I am forecasting. If I check the shape of the first element of the list,

preds.output.prediction[0].shape

I get

torch.Size([1, 60, 3])

where the first dimension is the batch, the second is the prediction length, and the last holds the quantiles I obtain from the QuantileLoss loss function.
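
For example, if I just want the median forecast of the first variable, I can slice that tensor directly, something like this (assuming the middle of the three quantiles is the median):

median_forecast = preds.output.prediction[0][..., 1]  # shape [1, 60]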

What I am struggling with is the following: I'd like to call the plot_prediction() method so that only one element of the list preds.output.prediction is used.
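
What I have in mind is something along these lines, though I'm not sure this is the intended way to use the API (target_idx and the dict-based rebuild of the raw output are just my guesses):

target_idx = 0  # which of the three forecast variables to plot

# my guess: rebuild the raw output so that "prediction" holds only one target
raw_output = dict(preds.output.items())  # assuming the raw output supports .items()
raw_output["prediction"] = preds.output.prediction[target_idx]

model.plot_prediction(
    preds.x,     # network inputs from return_x=True
    raw_output,
    idx=0,       # which sample in the batch to plot
)

Is something like this reasonable, or is there a built-in way to select a single target?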

I hope the way I phrased the question makes sense. Please bear with me, as I am new to the library and its workflow. If you have suggestions on how to do things differently, that would be very much appreciated.
