Description
Hello,
First of all, thanks a lot for putting together this nice library.
I have a question about how to properly call the `plot_prediction()` method.
Assume I have a `TemporalFusionTransformer` model, which I refer to below as `model`.
To make predictions, I do the following:

```python
preds = model.predict(
    test_dataloader,
    mode="raw",  # raw network output: one quantile tensor per target
    return_x=True,  # also return the input batch (needed for plotting)
    return_y=True,
    mode_kwargs={"n_samples": n_samples},
    trainer_kwargs=dict(accelerator="mps"),
)
```

Now, using these predictions, I want to plot the results.
Note that if I check

```python
type(self.predictions.output.prediction)
```

it is a list of size 3, where each element of the list is a different variable I am forecasting. If I inspect the shape of the first element of the list,

```python
self.predictions.output.prediction[0].shape
```

I get

```
torch.Size([1, 60, 3])
```

where the first dimension is the batch, the second dimension is the prediction length, and the last dimension is the quantiles I obtain from the QuantileLoss loss function.
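To make the structure above concrete, here is a minimal sketch that uses random placeholder tensors in place of `self.predictions.output.prediction` (they are not real model output). `select_target`, `target_idx` and `sample_idx` are hypothetical names of mine, not part of pytorch-forecasting. The sketch picks one target and one sample from the batch, then takes the middle quantile column.

```python
import torch

# Placeholder for self.predictions.output.prediction: a list with one tensor
# per target, each shaped [batch, prediction_length, n_quantiles].
n_targets, batch, pred_len, n_quantiles = 3, 1, 60, 3
prediction = [torch.randn(batch, pred_len, n_quantiles) for _ in range(n_targets)]


def select_target(prediction, target_idx, sample_idx=0):
    """Hypothetical helper: return the [prediction_length, n_quantiles]
    tensor for one target and one sample of the batch."""
    return prediction[target_idx][sample_idx]


first_target = select_target(prediction, target_idx=0)
print(first_target.shape)  # torch.Size([60, 3])

# Middle quantile column. This is the median only if the QuantileLoss was
# built with symmetric quantiles such as [0.1, 0.5, 0.9] (an assumption).
median = first_target[:, n_quantiles // 2]
print(median.shape)  # torch.Size([60])
```

Indexing the list first and the tensor second keeps the per-target view intact. The quantile axis stays last, so any column can be pulled out the same way.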
What I am struggling with is the following: I'd like to call the `plot_prediction()` method so that it uses only one element of the list `self.predictions.output.prediction`, i.e. only one of the forecast variables.
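One possible workaround, sketched under an assumption I have not verified for every version: in the pytorch-forecasting code I have read, calling `plot_prediction()` on a multi-target model draws one figure per target and returns them as a list. If that holds, one target's figure could simply be picked from that list. `select_target_figure` is a hypothetical helper, not a library function.

```python
def select_target_figure(figs, target_idx):
    """Hypothetical helper: return the figure for one target from
    plot_prediction()'s return value.

    Assumption: a multi-target model returns a list with one figure per
    target, while a single-target model returns a single figure.
    """
    if isinstance(figs, (list, tuple)):
        return figs[target_idx]
    # Single figure: only one target exists, so return it unchanged.
    return figs


# Usage sketch (not executed here; assumes `model` and `preds` from above):
# figs = model.plot_prediction(preds.x, preds.output, idx=0)  # idx = sample in the batch
# fig = select_target_figure(figs, target_idx=0)
```

Figures that are not kept could be closed with `matplotlib.pyplot.close(fig)` to free memory.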
I hope the way I phrased the question makes sense. Please bear with me, as I am new to the library and its workflow. If you have suggestions on how to do things differently, they would be very much appreciated.