Skip to content

Commit df979a4

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
authored andcommitted
Changing Backend Tensor initialization
Summary: X-link: meta-pytorch/torchrec#3484 X-link: facebookresearch/FBGEMM#2066 **Context:** Currently, we are enabling SSD optimizer offloading for the ssd tbe kernel **In this diff:** We retrieve the newly added parameters from the tbe config and pass it down to the tbe Differential Revision: D85353134
1 parent b7c013a commit df979a4

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
4545
std::optional<at::Tensor> table_dims = std::nullopt,
4646
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
4747
int64_t flushing_block_size = 2000000000 /*2GB*/,
48-
bool disable_random_init = false)
48+
bool disable_random_init = false,
49+
bool enable_optimizer_offloading = false,
50+
int64_t optimizer_D = 0)
4951
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
5052
path,
5153
num_shards,
@@ -76,7 +78,9 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
7678
table_dims,
7779
hash_size_cumsum,
7880
flushing_block_size,
79-
disable_random_init)) {}
81+
disable_random_init,
82+
enable_optimizer_offloading,
83+
optimizer_D)) {}
8084

8185
void set_cuda(
8286
at::Tensor indices,

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,9 @@ static auto embedding_rocks_db_wrapper =
809809
std::optional<at::Tensor>,
810810
std::optional<at::Tensor>,
811811
int64_t,
812-
bool>(),
812+
bool,
813+
bool,
814+
int64_t>(),
813815
"",
814816
{
815817
torch::arg("path"),
@@ -842,6 +844,8 @@ static auto embedding_rocks_db_wrapper =
842844
torch::arg("hash_size_cumsum") = std::nullopt,
843845
torch::arg("flushing_block_size") = 2000000000 /* 2GB */,
844846
torch::arg("disable_random_init") = false,
847+
torch::arg("enable_optimizer_offloading") = false,
848+
torch::arg("optimizer_D") = 0,
845849
})
846850
.def(
847851
"set_cuda",

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
13781378
std::vector<std::string> db_paths_;
13791379

13801380
bool disable_random_init_;
1381+
bool enable_optimizer_offloading;
13811382
}; // class EmbeddingRocksDB
13821383

13831384
/// @ingroup embedding-ssd

0 commit comments

Comments
 (0)