Skip to content

Commit 7f4a9d2

Browse files
Hao Yanmeta-codesync[bot]
authored andcommitted
Inference test e2e [1/n] (#5091)
Summary: Pull Request resolved: #5091 X-link: https://github.com/facebookresearch/FBGEMM/pull/2099 In this test, we run following step 1. Create a DramKVInferenceEmbedding with TTL eviction for 1 min 2. Insert 1 embedding with current Unixtime - 2 mins (it is already expired) as timestamp 3. Read from it and check correctness 4. Read for multiple times 5. Evict it 6. Read it --- this time should be inconsistent Reviewed By: emlin Differential Revision: D86268606 fbshipit-source-id: edc2dc24e5327399421d20229a0b1af2ca29ea7a
1 parent 16aa87b commit 7f4a9d2

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h"
10+
11+
#include <fmt/format.h>
12+
#include <glog/logging.h>
13+
#include <gtest/gtest.h>
14+
#include <chrono>
15+
#include <random>
16+
#include <vector>
17+
18+
namespace kv_mem {
19+
20+
class KVEmbeddingInferenceTest : public ::testing::Test {
21+
protected:
22+
static constexpr int EMBEDDING_DIM = 128;
23+
static constexpr int NUM_SHARDS = 8;
24+
25+
void SetUp() override {
26+
FLAGS_logtostderr = true;
27+
FLAGS_minloglevel = 0;
28+
FLAGS_v = 1;
29+
30+
auto feature_evict_config = c10::make_intrusive<FeatureEvictConfig>(
31+
3,
32+
4,
33+
std::nullopt,
34+
std::nullopt,
35+
std::vector<int64_t>{1},
36+
std::nullopt,
37+
std::nullopt,
38+
std::nullopt,
39+
std::nullopt,
40+
std::nullopt,
41+
std::nullopt,
42+
std::vector<int64_t>{EMBEDDING_DIM},
43+
std::nullopt,
44+
std::nullopt,
45+
0,
46+
0,
47+
0);
48+
49+
auto hash_size_cumsum = at::tensor({0, 100000}, at::kLong);
50+
51+
backend_ = std::make_unique<DramKVInferenceEmbedding<float>>(
52+
EMBEDDING_DIM,
53+
-0.1,
54+
0.1,
55+
feature_evict_config,
56+
NUM_SHARDS,
57+
32,
58+
32,
59+
false,
60+
std::nullopt,
61+
hash_size_cumsum,
62+
false);
63+
}
64+
65+
void TearDown() override {
66+
backend_.reset();
67+
}
68+
69+
static std::vector<float> generateEmbedding(int64_t embedding_id) {
70+
std::vector<float> embedding(EMBEDDING_DIM);
71+
72+
// Use both embedding_id and current time as seed for randomness
73+
auto now = std::chrono::system_clock::now();
74+
auto time_seed = std::chrono::duration_cast<std::chrono::nanoseconds>(
75+
now.time_since_epoch())
76+
.count();
77+
uint32_t combined_seed = static_cast<uint32_t>(embedding_id ^ time_seed);
78+
79+
std::mt19937 rng(combined_seed);
80+
std::uniform_real_distribution<float> dist(-0.1f, 0.1f);
81+
for (int i = 0; i < EMBEDDING_DIM; ++i) {
82+
embedding[i] = dist(rng);
83+
}
84+
return embedding;
85+
}
86+
87+
std::unique_ptr<DramKVInferenceEmbedding<float>> backend_;
88+
};
89+
90+
TEST_F(KVEmbeddingInferenceTest, InferenceLifecycleWithMetadata) {
91+
const int64_t embedding_id = 12345;
92+
93+
auto now = std::chrono::system_clock::now();
94+
auto now_seconds =
95+
std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch())
96+
.count();
97+
const uint32_t snapshot_timestamp = static_cast<uint32_t>(now_seconds - 120);
98+
99+
auto embedding_data = generateEmbedding(embedding_id);
100+
101+
LOG(INFO) << "STEP 1: Define test embedding";
102+
LOG(INFO) << "Embedding ID: " << embedding_id;
103+
LOG(INFO) << "Timestamp: " << snapshot_timestamp
104+
<< " (current time - 2 minutes)";
105+
LOG(INFO) << "Dimension: " << EMBEDDING_DIM;
106+
LOG(INFO) << "First 5 elements: [" << embedding_data[0] << ", "
107+
<< embedding_data[1] << ", " << embedding_data[2] << ", "
108+
<< embedding_data[3] << ", " << embedding_data[4] << "]";
109+
110+
auto indices_tensor = at::tensor({embedding_id}, at::kLong);
111+
auto weights_tensor = at::from_blob(
112+
embedding_data.data(),
113+
{1, EMBEDDING_DIM},
114+
at::TensorOptions().dtype(at::kFloat));
115+
auto count_tensor = at::tensor({1}, at::kInt);
116+
117+
LOG(INFO) << "STEP 2: Insert embedding into cache";
118+
folly::coro::blockingWait(backend_->inference_set_kv_db_async(
119+
indices_tensor, weights_tensor, count_tensor, snapshot_timestamp));
120+
LOG(INFO) << "Insertion completed";
121+
122+
auto retrieved_embedding = at::zeros({1, EMBEDDING_DIM}, at::kFloat);
123+
124+
LOG(INFO) << "STEP 3: Retrieve embedding from cache";
125+
folly::coro::blockingWait(backend_->get_kv_db_async(
126+
indices_tensor, retrieved_embedding, count_tensor));
127+
LOG(INFO) << "Retrieval completed";
128+
129+
auto retrieved_ptr = retrieved_embedding.data_ptr<float>();
130+
bool all_match = true;
131+
int mismatch_count = 0;
132+
133+
LOG(INFO) << "STEP 4: Verify embedding consistency";
134+
for (int i = 0; i < EMBEDDING_DIM; ++i) {
135+
if (std::abs(retrieved_ptr[i] - embedding_data[i]) > 1e-5f) {
136+
all_match = false;
137+
mismatch_count++;
138+
}
139+
}
140+
141+
if (all_match) {
142+
LOG(INFO) << "All " << EMBEDDING_DIM << " dimensions match";
143+
} else {
144+
LOG(ERROR) << "Found " << mismatch_count << " mismatches out of "
145+
<< EMBEDDING_DIM << " dimensions";
146+
}
147+
148+
ASSERT_TRUE(all_match) << "Retrieved embedding must match inserted embedding";
149+
150+
LOG(INFO) << "STEP 5: Test repeated reads";
151+
for (int iteration = 1; iteration <= 3; ++iteration) {
152+
auto read_again = at::zeros({1, EMBEDDING_DIM}, at::kFloat);
153+
folly::coro::blockingWait(
154+
backend_->get_kv_db_async(indices_tensor, read_again, count_tensor));
155+
156+
auto read_ptr = read_again.data_ptr<float>();
157+
bool matches = true;
158+
for (int i = 0; i < EMBEDDING_DIM; ++i) {
159+
if (std::abs(read_ptr[i] - embedding_data[i]) > 1e-5f) {
160+
matches = false;
161+
break;
162+
}
163+
}
164+
LOG(INFO) << "Read #" << iteration << ": "
165+
<< (matches ? "Match" : "Mismatch");
166+
}
167+
168+
LOG(INFO) << "STEP 6: Trigger eviction";
169+
auto eviction_time = std::chrono::system_clock::now();
170+
auto eviction_seconds = std::chrono::duration_cast<std::chrono::seconds>(
171+
eviction_time.time_since_epoch())
172+
.count();
173+
uint32_t eviction_threshold = static_cast<uint32_t>(eviction_seconds - 60);
174+
175+
LOG(INFO) << "Eviction threshold: " << eviction_threshold;
176+
backend_->trigger_feature_evict(eviction_threshold);
177+
backend_->wait_until_eviction_done();
178+
LOG(INFO) << "Eviction completed";
179+
180+
auto post_eviction_embedding = at::zeros({1, EMBEDDING_DIM}, at::kFloat);
181+
182+
LOG(INFO) << "STEP 7: Read embedding after eviction";
183+
folly::coro::blockingWait(backend_->get_kv_db_async(
184+
indices_tensor, post_eviction_embedding, count_tensor));
185+
186+
auto post_eviction_ptr = post_eviction_embedding.data_ptr<float>();
187+
bool values_changed = false;
188+
int differences = 0;
189+
190+
for (int i = 0; i < EMBEDDING_DIM; ++i) {
191+
if (std::abs(post_eviction_ptr[i] - embedding_data[i]) > 1e-5f) {
192+
values_changed = true;
193+
differences++;
194+
}
195+
}
196+
197+
LOG(INFO) << "Differences found: " << differences << "/" << EMBEDDING_DIM;
198+
199+
if (values_changed) {
200+
LOG(INFO) << "Eviction successful - values changed";
201+
} else {
202+
LOG(ERROR) << "Eviction may have failed - values unchanged";
203+
}
204+
205+
LOG(INFO) << "Original (cached): [" << embedding_data[0] << ", "
206+
<< embedding_data[1] << ", " << embedding_data[2] << ", "
207+
<< embedding_data[3] << ", " << embedding_data[4] << "]";
208+
LOG(INFO) << "After eviction: [" << post_eviction_ptr[0] << ", "
209+
<< post_eviction_ptr[1] << ", " << post_eviction_ptr[2] << ", "
210+
<< post_eviction_ptr[3] << ", " << post_eviction_ptr[4] << "]";
211+
212+
ASSERT_TRUE(values_changed) << "Embedding should be different after eviction";
213+
214+
LOG(INFO) << "Test completed successfully";
215+
}
216+
217+
} // namespace kv_mem

0 commit comments

Comments
 (0)