@@ -341,11 +341,7 @@ def _forward_transform(
341341 ) -> tuple [torch .Tensor , tuple [torch .Tensor , torch .Tensor ]]:
342342 """Input is of shape [B, N, P]."""
343343 mu , sigma = self ._timesfm_masked_mean_std (inputs , patched_pads )
344- sigma = torch .where (
345- sigma < self .config .tolerance ,
346- torch .tensor (1.0 , dtype = sigma .dtype , device = sigma .device ),
347- sigma ,
348- )
344+ sigma = torch .clamp (sigma , min = self .config .tolerance )
349345
350346 # Normalize each patch
351347 outputs = (inputs - mu [:, None , None ]) / sigma [:, None , None ]
@@ -524,24 +520,16 @@ def _get_patch_index(arr: torch.Tensor):
524520
525521 # Calculate the number of valid elements
526522 num_valid_elements = torch .sum (mask , dim = 1 )
527- num_valid_elements = torch .where (
528- num_valid_elements == 0 ,
529- torch .tensor (1 , dtype = num_valid_elements .dtype , device = num_valid_elements .device ),
530- num_valid_elements ,
531- )
523+ num_valid_elements = torch .clamp (num_valid_elements , min = 1.0 )
532524
533- # Calculate the masked sum and squared sum
525+ # Calculate the masked sum and mean
534526 masked_sum = torch .sum (arr * mask , dim = 1 )
535- masked_squared_sum = torch .sum ((arr * mask ) ** 2 , dim = 1 )
536-
537- # Calculate the masked mean and standard deviation
538- masked_mean = masked_sum / num_valid_elements
539- masked_var = masked_squared_sum / num_valid_elements - masked_mean ** 2
540- masked_var = torch .where (
541- masked_var < 0.0 ,
542- torch .tensor (0.0 , dtype = masked_var .dtype , device = masked_var .device ),
543- masked_var ,
544- )
527+ masked_mean = masked_sum / num_valid_elements # [b]
528+
529+ # Calculate the masked variance using centered values
530+ masked_centered_arr = (arr - masked_mean .unsqueeze (- 1 )) * mask
531+ masked_var = torch .sum (masked_centered_arr ** 2 , dim = 1 ) / num_valid_elements
532+ masked_var = torch .clamp (masked_var , min = 0.0 )
545533 masked_std = torch .sqrt (masked_var )
546534
547535 return masked_mean , masked_std
0 commit comments