Skip to content

Commit 6da4958

Browse files
committed
feat(search): Probabilistic cutoffs
Signed-off-by: Vladislav Oleshko <[email protected]>
1 parent ceed85a commit 6da4958

File tree

6 files changed

+95
-39
lines changed

6 files changed

+95
-39
lines changed

src/core/search/index_result.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ class IndexResult {
3939

4040
BorrowedView Borrowed() const;
4141

42-
// Move out of owned or copy borrowed
43-
DocVec Take();
42+
// Move out of owned or copy borrowed. Take up to `limit` entries and return original size.
43+
std::pair<DocVec, size_t /* full size */> Take(size_t limit = std::numeric_limits<size_t>::max());
4444

4545
private:
4646
bool IsOwned() const;
@@ -82,20 +82,36 @@ inline IndexResult::BorrowedView IndexResult::Borrowed() const {
8282
return std::visit(cb, value_);
8383
}
8484

85-
inline IndexResult::DocVec IndexResult::Take() {
85+
inline std::pair<IndexResult::DocVec, size_t> IndexResult::Take(size_t limit) {
8686
if (IsOwned()) {
87-
return std::move(std::get<DocVec>(value_));
87+
auto& vec = std::get<DocVec>(value_);
88+
size_t size = vec.size();
89+
return {std::move(vec), size};
8890
}
8991

90-
auto cb = [](auto* set) -> DocVec {
92+
// Numeric ranges need to be filtered and don't know their size ahead of time
93+
// if (std::holds_alternative<RangeResult>(value_)) {
94+
auto cb = [](auto* range) -> std::pair<DocVec, size_t> {
9195
DocVec out;
92-
out.reserve(set->size());
93-
for (auto it = set->begin(); it != set->end(); ++it) {
96+
out.reserve(range->size());
97+
for (auto it = range->begin(); it != range->end(); ++it)
9498
out.push_back(*it);
95-
}
96-
return out;
99+
size_t total = out.size();
100+
return {std::move(out), total};
97101
};
98102
return std::visit(cb, Borrowed());
103+
//}
104+
105+
// Generic borrowed result sets don't need to be filtered, so we can tell the result size ahead of time
106+
/*auto cb = [limit](auto* set) -> std::pair<DocVec, size_t> {
107+
DocVec out;
108+
size_t taken = std::min(limit, set->size());
109+
out.reserve(taken);
110+
for (auto it = set->begin(); it != set->end() && out.size(); ++it)
111+
out.push_back(*it);
112+
return {std::move(out), set->size()};
113+
};
114+
return std::visit(cb, Borrowed());*/
99115
}
100116

101117
inline bool IndexResult::IsOwned() const {

src/core/search/search.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ struct BasicSearch {
289289

290290
// negate -(*subquery*): explicitly compute result complement. Needs further optimizations
291291
IndexResult Search(const AstNegateNode& node, string_view active_field) {
292-
vector<DocId> matched = SearchGeneric(*node.node, active_field).Take();
292+
auto matched = SearchGeneric(*node.node, active_field).Take().first;
293293
vector<DocId> all = indices_->GetAllDocs();
294294

295295
// To negate a result, we have to find the complement of matched to all documents,
@@ -358,7 +358,7 @@ struct BasicSearch {
358358
knn_distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime);
359359
else
360360
knn_distances_ =
361-
vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime, sub_results.Take());
361+
vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime, sub_results.Take().first);
362362
}
363363

364364
// [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit`
@@ -405,7 +405,6 @@ struct BasicSearch {
405405

406406
// Top level results don't need to be sorted, because they will be scored, sorted by fields or
407407
// used by knn
408-
409408
DCHECK(top_level || holds_alternative<AstKnnNode>(node.Variant()) ||
410409
holds_alternative<AstGeoNode>(node.Variant()) ||
411410
visit([](auto* set) { return is_sorted(set->begin(), set->end()); }, result.Borrowed()));
@@ -416,16 +415,15 @@ struct BasicSearch {
416415
return result;
417416
}
418417

419-
SearchResult Search(const AstNode& query) {
418+
SearchResult Search(const AstNode& query, size_t cuttoff_limit) {
420419
IndexResult result = SearchGeneric(query, "", true);
421420

422421
// Extract profile if enabled
423422
optional<AlgorithmProfile> profile =
424423
profile_builder_ ? make_optional(profile_builder_->Take()) : nullopt;
425424

426-
auto out = result.Take();
427-
const size_t total = out.size();
428-
return SearchResult{total, std::move(out), std::move(knn_scores_), std::move(profile),
425+
auto [out, total_size] = result.Take(cuttoff_limit);
426+
return SearchResult{total_size, std::move(out), std::move(knn_scores_), std::move(profile),
429427
std::move(error_)};
430428
}
431429

@@ -654,11 +652,11 @@ bool SearchAlgorithm::Init(string_view query, const QueryParams* params,
654652
return true;
655653
}
656654

657-
SearchResult SearchAlgorithm::Search(const FieldIndices* index) const {
655+
SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_limit) const {
658656
auto bs = BasicSearch{index};
659657
if (profiling_enabled_)
660658
bs.EnableProfiling();
661-
return bs.Search(*query_);
659+
return bs.Search(*query_, cuttoff_limit);
662660
}
663661

664662
optional<KnnScoreSortOption> SearchAlgorithm::GetKnnScoreSortOption() const {

src/core/search/search.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ class SearchAlgorithm {
197197
bool Init(std::string_view query, const QueryParams* params,
198198
const OptionalFilters* filters = nullptr);
199199

200-
SearchResult Search(const FieldIndices* index) const;
200+
// Search on given index with predefined limit for cutting off result ids
201+
SearchResult Search(const FieldIndices* index,
202+
size_t cuttoff_limit = std::numeric_limits<size_t>::max()) const;
201203

202204
// if enabled, return limit & alias for knn query
203205
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const;

src/server/search/doc_index.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ vector<search::SortableValue> ShardDocIndex::KeepTopKSorted(vector<DocId>* ids,
400400
SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
401401
search::SearchAlgorithm* search_algo) const {
402402
size_t limit = params.limit_offset + params.limit_total;
403-
auto result = search_algo->Search(&*indices_);
403+
bool can_cut_off = !params.sort_option && !search_algo->GetKnnScoreSortOption();
404+
size_t id_cutoff_limit = can_cut_off ? limit : numeric_limits<size_t>::max();
405+
auto result = search_algo->Search(&*indices_, id_cutoff_limit);
404406
if (!result.error.empty())
405407
return {facade::ErrorReply(std::move(result.error))};
406408

@@ -441,7 +443,10 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa
441443
}
442444

443445
// Cut off unnecessary items
444-
result.ids.resize(min(result.ids.size(), limit));
446+
size_t serialization_limit = limit;
447+
if (!search_algo->GetKnnScoreSortOption())
448+
serialization_limit = params.limit_serialization;
449+
result.ids.resize(min(result.ids.size(), serialization_limit));
445450

446451
// Serialize documents
447452
vector<SerializedSearchDoc> out;

src/server/search/doc_index.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ struct SearchParams {
9898
size_t limit_offset = 0;
9999
size_t limit_total = 10;
100100

101+
size_t limit_serialization = 0;
101102
/*
102103
1. If not set -> return all fields
103104
2. If set but empty -> no fields should be returned

src/server/search/search_family.cc

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -959,17 +959,27 @@ void PartialSort(absl::Span<SerializedSearchDoc*> docs, size_t limit, SortOrder
959959
partial_sort(docs.begin(), docs.begin() + min(limit, docs.size()), docs.end(), cb);
960960
}
961961

962-
void SearchReply(const SearchParams& params,
962+
bool SearchReply(const SearchParams& params,
963963
std::optional<search::KnnScoreSortOption> knn_sort_option,
964964
absl::Span<SearchResult> results, SinkReplyBuilder* builder) {
965+
// Count total number of hits
965966
size_t total_hits = 0;
967+
for (auto& shard_results : results)
968+
total_hits += shard_results.total_hits;
969+
970+
// Arrange documents in a stride
966971
absl::InlinedVector<SerializedSearchDoc*, 5> docs;
967972
docs.reserve(results.size());
968-
for (auto& shard_results : results) {
969-
total_hits += shard_results.total_hits;
970-
for (auto& doc : shard_results.docs) {
971-
docs.push_back(&doc);
973+
for (size_t i = 0;; ++i) {
974+
bool added = false;
975+
for (auto& shard_results : results) {
976+
if (i < shard_results.docs.size()) {
977+
added = true;
978+
docs.push_back(&shard_results.docs[i]);
979+
}
972980
}
981+
if (!added)
982+
break;
973983
}
974984

975985
// Reorder and cut KNN results before applying SORT and LIMIT
@@ -995,6 +1005,14 @@ void SearchReply(const SearchParams& params,
9951005
PartialSort(absl::MakeSpan(docs), end, params.sort_option->order,
9961006
&SerializedSearchDoc::sort_score);
9971007

1008+
// Check if we haven't chosen too few documents due to cutoffs
1009+
size_t left = docs.size() - params.limit_offset;
1010+
size_t expected = std::min(limit, total_hits - params.limit_offset);
1011+
if (left < expected)
1012+
return false;
1013+
1014+
// TODO: Check sort correctness
1015+
9981016
const bool reply_with_ids_only = params.IdsOnly();
9991017
auto* rb = static_cast<RedisReplyBuilder*>(builder);
10001018
RedisReplyBuilder::ArrayScope scope{rb, reply_with_ids_only ? (limit + 1) : (limit * 2 + 1)};
@@ -1011,6 +1029,7 @@ void SearchReply(const SearchParams& params,
10111029

10121030
SendSerializedDoc(*docs[i], builder);
10131031
}
1032+
return true;
10141033
}
10151034

10161035
// Warms up the query parser to avoid first-call slowness
@@ -1279,23 +1298,37 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
12791298
atomic<bool> index_not_found{false};
12801299
vector<SearchResult> docs(shard_set->size());
12811300

1282-
cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
1283-
if (auto* index = es->search_indices()->GetIndex(index_name); index)
1284-
docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo);
1301+
bool succeeded = true;
1302+
do {
1303+
if (succeeded)
1304+
params->limit_serialization =
1305+
params->limit_offset +
1306+
std::min(params->limit_total, 2 * params->limit_total / shard_set->size());
12851307
else
1286-
index_not_found.store(true, memory_order_relaxed);
1287-
return OpStatus::OK;
1288-
});
1308+
params->limit_serialization = params->limit_total + params->limit_offset;
12891309

1290-
if (index_not_found.load())
1291-
return builder->SendError(string{index_name} + ": no such index");
1310+
cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
1311+
if (auto* index = es->search_indices()->GetIndex(index_name); index)
1312+
docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo);
1313+
else
1314+
index_not_found.store(true, memory_order_relaxed);
1315+
return OpStatus::OK;
1316+
});
12921317

1293-
for (const auto& res : docs) {
1294-
if (res.error)
1295-
return builder->SendError(*res.error);
1296-
}
1318+
if (index_not_found.load())
1319+
return builder->SendError(string{index_name} + ": no such index");
1320+
1321+
for (const auto& res : docs) {
1322+
if (res.error)
1323+
return builder->SendError(*res.error);
1324+
}
12971325

1298-
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(docs), builder);
1326+
bool did_succeed =
1327+
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(docs), builder);
1328+
CHECK(did_succeed);
1329+
DCHECK(succeeded || did_succeed);
1330+
succeeded = did_succeed;
1331+
} while (!succeeded);
12991332
}
13001333

13011334
void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
@@ -1331,6 +1364,7 @@ void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
13311364
std::vector<SearchResult> search_results(shards_count);
13321365
std::vector<absl::Duration> profile_results(shards_count);
13331366

1367+
params->limit_serialization = params->limit_offset + params->limit_total;
13341368
cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
13351369
auto* index = es->search_indices()->GetIndex(index_name);
13361370
if (!index) {

0 commit comments

Comments
 (0)