Skip to content

Commit 35d1c9b

Browse files
committed
fix: 🐛 fix a bug in TimeMixer
1 parent 29ab595 commit 35d1c9b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/basicts/models/TimeMixer/arch/timemixer_arch.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _decomposition(
6565
trend_list.append(trend)
6666
return seasonal_list, trend_list
6767

68-
def _multi_scale_process_inputs(
68+
def _prepare_multi_scale_inputs(
6969
self,
7070
inputs: torch.Tensor,
7171
inputs_timestamps: Optional[torch.Tensor] = None
@@ -82,7 +82,8 @@ def _multi_scale_process_inputs(
8282
sample = down_sampled
8383

8484
if inputs_timestamps is not None:
85-
multi_scale_timestamps.append(sample_ts[:, :, ::self.down_sampling_window])
85+
multi_scale_timestamps.append(
86+
sample_ts[:, :, ::self.down_sampling_window].permute(0, 2, 1))
8687
sample_ts = sample_ts[:, :, ::self.down_sampling_window]
8788

8889
return multi_scale_inputs, multi_scale_timestamps
@@ -94,7 +95,7 @@ def forward(self,
9495
decomp: bool = False,
9596
) -> tuple[list[torch.Tensor], Optional[list[torch.Tensor]]]:
9697

97-
x_list, x_ts_list = self._multi_scale_process_inputs(inputs, inputs_timestamps)
98+
x_list, x_ts_list = self._prepare_multi_scale_inputs(inputs, inputs_timestamps)
9899
num_scales = len(x_list)
99100

100101
for i in range(num_scales):

0 commit comments

Comments
 (0)