Skip to content

Commit d68d2e9

Browse files
authored
fix mps backend does not implement float64 (#2216)
1 parent f8dd297 commit d68d2e9

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

gptqmodel/utils/linalg_warmup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ def run_torch_linalg_warmup(device: torch.device) -> None:
4444
still runs once per physical device so backend-specific handles are initialized where needed.
4545
"""
4646
with _GLOBAL_WARMUP_LOCK:
47-
dtypes = (torch.float32, torch.float64)
47+
if device.type == "mps":
48+
dtypes = (torch.float32,) # MPS backend does not implement float64.
49+
else:
50+
dtypes = (torch.float32, torch.float64)
51+
4852
for dtype in dtypes:
4953
_run_cholesky_and_eigh(device, dtype)
5054
_run_svd(device, dtype)

0 commit comments

Comments
 (0)