
Commit 2980010

jiayisuse authored and facebook-github-bot committed
remove .int() for cpu indices and values (#1590)
Summary:
Pull Request resolved: #1590

CPU DI is serving large models with big embedding tables (2TB); the values and indices would overflow with the .int() conversion. Remove .int() for CPU only.

Reviewed By: zyan0, tissue3

Differential Revision: D52225777

fbshipit-source-id: 0bf7973a91a7b7daed6eaed3a55bb8dca25fcdef
1 parent 80b19a2 commit 2980010


1 file changed: +8 -1 lines changed


torchrec/distributed/quant_embedding_kernel.py

Lines changed: 8 additions & 1 deletion
@@ -118,7 +118,14 @@ def _quantize_weight(
 def _unwrap_kjt(
     features: KeyedJaggedTensor,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    return features.values().int(), features.offsets().int(), features.weights_or_none()
+    if features.device().type == "cuda":
+        return (
+            features.values().int(),
+            features.offsets().int(),
+            features.weights_or_none(),
+        )
+    else:
+        return features.values(), features.offsets(), features.weights_or_none()
 
 
 class QuantBatchedEmbeddingBag(
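
The overflow mentioned in the summary can be illustrated without PyTorch. `torch.Tensor.int()` converts to 32-bit signed integers, so any value above 2**31 - 1 cannot survive the conversion. The sketch below is a standard-library stand-in, not the TorchRec code: `narrow_to_int32` and `unwrap_index` are hypothetical helpers that use `ctypes.c_int32` to mimic an unchecked 64-to-32-bit narrowing on a single Python integer, and `unwrap_index` mirrors only the device check from the patch.

```python
import ctypes

INT32_MAX = 2**31 - 1


def narrow_to_int32(value: int) -> int:
    """Keep the low 32 bits of value and read them as a signed int32.

    ctypes performs no overflow checking, so out-of-range values wrap.
    """
    return ctypes.c_int32(value).value


def unwrap_index(value: int, device_type: str) -> int:
    """Hypothetical scalar analogue of the patched branch: narrow only on CUDA."""
    if device_type == "cuda":
        return narrow_to_int32(value)
    # CPU path: leave the value at its original width.
    return value


small_index = 123_456
large_index = INT32_MAX + 1  # first value that no longer fits in int32

print(narrow_to_int32(small_index))      # unchanged, fits in 32 bits
print(narrow_to_int32(large_index))      # wraps to a negative number
print(unwrap_index(large_index, "cpu"))  # preserved on the CPU path
```

Under these assumptions, a sufficiently large offset or index would turn negative once narrowed, which is why the patch skips the conversion on CPU while leaving the CUDA path unchanged.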
