@@ -65,7 +65,7 @@ def _decomposition(
6565 trend_list .append (trend )
6666 return seasonal_list , trend_list
6767
68- def _multi_scale_process_inputs (
68+ def _prepare_multi_scale_inputs (
6969 self ,
7070 inputs : torch .Tensor ,
7171 inputs_timestamps : Optional [torch .Tensor ] = None
@@ -82,7 +82,8 @@ def _multi_scale_process_inputs(
8282 sample = down_sampled
8383
8484 if inputs_timestamps is not None :
85- multi_scale_timestamps .append (sample_ts [:, :, ::self .down_sampling_window ])
85+ multi_scale_timestamps .append (
86+ sample_ts [:, :, ::self .down_sampling_window ].permute (0 , 2 , 1 ))
8687 sample_ts = sample_ts [:, :, ::self .down_sampling_window ]
8788
8889 return multi_scale_inputs , multi_scale_timestamps
@@ -94,7 +95,7 @@ def forward(self,
9495 decomp : bool = False ,
9596 ) -> tuple [list [torch .Tensor ], Optional [list [torch .Tensor ]]]:
9697
97- x_list , x_ts_list = self ._multi_scale_process_inputs (inputs , inputs_timestamps )
98+ x_list , x_ts_list = self ._prepare_multi_scale_inputs (inputs , inputs_timestamps )
9899 num_scales = len (x_list )
99100
100101 for i in range (num_scales ):
0 commit comments