Skip to content

Commit 152f5b6

Browse files
authored
Fix TimesFM patch normalization instability (#42099)
* Fix TimesFM potential numerical instability in masked mean/std calculation. * Fix sigma clamping to 1 instead of config.tolerance in TimesFM.
1 parent 69f0036 commit 152f5b6

File tree

2 files changed

+18
-42
lines changed

2 files changed

+18
-42
lines changed

src/transformers/models/timesfm/modeling_timesfm.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -341,11 +341,7 @@ def _forward_transform(
341341
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
342342
"""Input is of shape [B, N, P]."""
343343
mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
344-
sigma = torch.where(
345-
sigma < self.config.tolerance,
346-
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
347-
sigma,
348-
)
344+
sigma = torch.clamp(sigma, min=self.config.tolerance)
349345

350346
# Normalize each patch
351347
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
@@ -524,24 +520,16 @@ def _get_patch_index(arr: torch.Tensor):
524520

525521
# Calculate the number of valid elements
526522
num_valid_elements = torch.sum(mask, dim=1)
527-
num_valid_elements = torch.where(
528-
num_valid_elements == 0,
529-
torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
530-
num_valid_elements,
531-
)
523+
num_valid_elements = torch.clamp(num_valid_elements, min=1.0)
532524

533-
# Calculate the masked sum and squared sum
525+
# Calculate the masked sum and mean
534526
masked_sum = torch.sum(arr * mask, dim=1)
535-
masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
536-
537-
# Calculate the masked mean and standard deviation
538-
masked_mean = masked_sum / num_valid_elements
539-
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
540-
masked_var = torch.where(
541-
masked_var < 0.0,
542-
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
543-
masked_var,
544-
)
527+
masked_mean = masked_sum / num_valid_elements # [b]
528+
529+
# Calculate the masked variance using centered values
530+
masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask
531+
masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements
532+
masked_var = torch.clamp(masked_var, min=0.0)
545533
masked_std = torch.sqrt(masked_var)
546534

547535
return masked_mean, masked_std

src/transformers/models/timesfm/modular_timesfm.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,7 @@ def _forward_transform(
297297
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
298298
"""Input is of shape [B, N, P]."""
299299
mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
300-
sigma = torch.where(
301-
sigma < self.config.tolerance,
302-
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
303-
sigma,
304-
)
300+
sigma = torch.clamp(sigma, min=self.config.tolerance)
305301

306302
# Normalize each patch
307303
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
@@ -480,24 +476,16 @@ def _get_patch_index(arr: torch.Tensor):
480476

481477
# Calculate the number of valid elements
482478
num_valid_elements = torch.sum(mask, dim=1)
483-
num_valid_elements = torch.where(
484-
num_valid_elements == 0,
485-
torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
486-
num_valid_elements,
487-
)
479+
num_valid_elements = torch.clamp(num_valid_elements, min=1.0)
488480

489-
# Calculate the masked sum and squared sum
481+
# Calculate the masked sum and mean
490482
masked_sum = torch.sum(arr * mask, dim=1)
491-
masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
492-
493-
# Calculate the masked mean and standard deviation
494-
masked_mean = masked_sum / num_valid_elements
495-
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
496-
masked_var = torch.where(
497-
masked_var < 0.0,
498-
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
499-
masked_var,
500-
)
483+
masked_mean = masked_sum / num_valid_elements # [b]
484+
485+
# Calculate the masked variance using centered values
486+
masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask
487+
masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements
488+
masked_var = torch.clamp(masked_var, min=0.0)
501489
masked_std = torch.sqrt(masked_var)
502490

503491
return masked_mean, masked_std

0 commit comments

Comments
 (0)