Skip to content

Commit 2f1b3d9

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Change kvzch_eviction_tbe_config to kvzch_tbe_config (meta-pytorch#3514)
Summary: X-link: pytorch/FBGEMM#5084 X-link: facebookresearch/FBGEMM#2092 Change tbe config name from kvzch_eviction_tbe_config to kvzch_tbe_config, as it may use not only for eviction but also for some other processes like st publish. Differential Revision: D86212643
1 parent 570eb1c commit 2f1b3d9

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,8 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
242242
)
243243
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
244244

245-
if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table:
246-
ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get(
247-
"kvzch_eviction_tbe_config"
248-
)
245+
if "kvzch_tbe_config" in fused_params and config.is_using_virtual_table:
246+
ssd_tbe_params["kvzch_tbe_config"] = fused_params.get("kvzch_tbe_config")
249247

250248
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
251249

@@ -339,10 +337,10 @@ def _populate_zero_collision_tbe_params(
339337
l2_cache_size = tbe_params["l2_cache_size"]
340338

341339
assert (
342-
"kvzch_eviction_tbe_config" in tbe_params
343-
), "kvzch_eviction_tbe_config should be in tbe_params"
344-
eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"]
345-
tbe_params.pop("kvzch_eviction_tbe_config")
340+
"kvzch_tbe_config" in tbe_params
341+
), "kvzch_tbe_config should be in tbe_params"
342+
eviction_tbe_config = tbe_params["kvzch_tbe_config"]
343+
tbe_params.pop("kvzch_tbe_config")
346344
eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode
347345
eviction_free_mem_threshold_gb = (
348346
eviction_tbe_config.eviction_free_mem_threshold_gb

torchrec/distributed/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3434
BoundsCheckMode,
3535
CacheAlgorithm,
36-
KVZCHEvictionTBEConfig,
36+
KVZCHTBEConfig,
3737
MultiPassPrefetchConfig,
3838
)
3939

@@ -668,7 +668,7 @@ class KeyValueParams:
668668
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
669669
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
670670
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
671-
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE
671+
kvzch_tbe_config: Optional[KVZCHTBEConfig]: KVZCH config for TBE
672672
673673
# Parameter Server (PS) Attributes
674674
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -694,7 +694,7 @@ class KeyValueParams:
694694
None # enable raw embedding streaming for SSD TBE
695695
)
696696
res_store_shards: Optional[int] = None # shards to store the raw embeddings
697-
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None
697+
kvzch_tbe_config: Optional[KVZCHTBEConfig] = None
698698

699699
# Parameter Server (PS) Attributes
700700
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -723,7 +723,7 @@ def __hash__(self) -> int:
723723
self.lazy_bulk_init_enabled,
724724
self.enable_raw_embedding_streaming,
725725
self.res_store_shards,
726-
self.kvzch_eviction_tbe_config,
726+
self.kvzch_tbe_config,
727727
)
728728
)
729729

0 commit comments

Comments
 (0)