Skip to content

Commit 58de98d

Browse files
authored
add16bytes tag type (#506)
* add 16 bytes tag type * clean up code * format doc * fix compile issue * fix compile issue * revert change * format doc * separate static search and streaming search * clean up code * resolve comment * format doc * fix test * resolve comment
1 parent 5cf0360 commit 58de98d

File tree

9 files changed

+169
-79
lines changed

9 files changed

+169
-79
lines changed

apps/search_memory_index.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
163163
for (int64_t i = 0; i < (int64_t)query_num; i++)
164164
{
165165
auto qs = std::chrono::high_resolution_clock::now();
166-
if (filtered_search)
166+
if (filtered_search && !tags)
167167
{
168168
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
169169

@@ -179,8 +179,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
179179
}
180180
else if (tags)
181181
{
182-
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
183-
query_result_tags.data() + i * recall_at, nullptr, res);
182+
if (!filtered_search)
183+
{
184+
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
185+
query_result_tags.data() + i * recall_at, nullptr, res);
186+
}
187+
else
188+
{
189+
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
190+
191+
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
192+
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
193+
}
194+
184195
for (int64_t r = 0; r < (int64_t)recall_at; r++)
185196
{
186197
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];

include/abstract_index.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ class AbstractIndex
6262
// Initialize space for res_vectors before calling.
6363
template <typename data_type, typename tag_type>
6464
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
65-
float *distances, std::vector<data_type *> &res_vectors);
65+
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
66+
const std::string filter_label = "");
6667

6768
// Added search overload that takes L as parameter, so that we
6869
// can customize L on a per-query basis without tampering with "Parameters"
@@ -120,7 +121,8 @@ class AbstractIndex
120121
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
121122
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
122123
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
123-
float *distances, DataVector &res_vectors) = 0;
124+
float *distances, DataVector &res_vectors, bool use_filters = false,
125+
const std::string filter_label = "") = 0;
124126
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
125127
virtual void _set_universal_label(const LabelType universal_label) = 0;
126128
};

include/index.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
136136

137137
// Initialize space for res_vectors before calling.
138138
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
139-
float *distances, std::vector<T *> &res_vectors);
139+
float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
140+
const std::string filter_label = "");
140141

141142
// Filter support search
142143
template <typename IndexType>
@@ -226,7 +227,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
226227
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;
227228

228229
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
229-
float *distances, DataVector &res_vectors) override;
230+
float *distances, DataVector &res_vectors, bool use_filters = false,
231+
const std::string filter_label = "") override;
230232

231233
virtual void _set_universal_label(const LabelType universal_label) override;
232234

include/natural_number_map.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ template <typename Key, typename Value> class natural_number_map
2626
{
2727
public:
2828
static_assert(std::is_trivial<Key>::value, "Key must be a trivial type");
29-
// Some of the class member prototypes are done with this assumption to
30-
// minimize verbosity since it's the only use case.
31-
static_assert(std::is_trivial<Value>::value, "Value must be a trivial type");
3229

3330
// Represents a reference to a element in the map. Used while iterating
3431
// over map entries.

include/tag_uint128.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#pragma once
2+
#include <cstdint>
3+
#include <type_traits>
4+
5+
namespace diskann
6+
{
7+
#pragma pack(push, 1)
8+
9+
struct tag_uint128
10+
{
11+
std::uint64_t _data1 = 0;
12+
std::uint64_t _data2 = 0;
13+
14+
bool operator==(const tag_uint128 &other) const
15+
{
16+
return _data1 == other._data1 && _data2 == other._data2;
17+
}
18+
19+
bool operator==(std::uint64_t other) const
20+
{
21+
return _data1 == other && _data2 == 0;
22+
}
23+
24+
tag_uint128 &operator=(const tag_uint128 &other)
25+
{
26+
_data1 = other._data1;
27+
_data2 = other._data2;
28+
29+
return *this;
30+
}
31+
32+
tag_uint128 &operator=(std::uint64_t other)
33+
{
34+
_data1 = other;
35+
_data2 = 0;
36+
37+
return *this;
38+
}
39+
};
40+
41+
#pragma pack(pop)
42+
} // namespace diskann
43+
44+
namespace std
45+
{
46+
// Hash 128 input bits down to 64 bits of output.
47+
// This is intended to be a reasonably good hash function.
48+
inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high)
49+
{
50+
// Murmur-inspired hashing.
51+
const std::uint64_t kMul = 0x9ddfea08eb382d69ULL;
52+
std::uint64_t a = (low ^ high) * kMul;
53+
a ^= (a >> 47);
54+
std::uint64_t b = (high ^ a) * kMul;
55+
b ^= (b >> 47);
56+
b *= kMul;
57+
return b;
58+
}
59+
60+
template <> struct hash<diskann::tag_uint128>
61+
{
62+
size_t operator()(const diskann::tag_uint128 &key) const noexcept
63+
{
64+
return Hash128to64(key._data1, key._data2); // map -0 to 0
65+
}
66+
};
67+
68+
} // namespace std

include/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ typedef int FileHandle;
2727
#include "windows_customizations.h"
2828
#include "tsl/robin_set.h"
2929
#include "types.h"
30+
#include "tag_uint128.h"
3031
#include <any>
3132

3233
#ifdef EXEC_ENV_OLS
@@ -1007,6 +1008,17 @@ void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf,
10071008

10081009
DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName);
10091010

1011+
inline std::string get_tag_string(std::uint64_t tag)
1012+
{
1013+
return std::to_string(tag);
1014+
}
1015+
1016+
inline std::string get_tag_string(const tag_uint128 &tag)
1017+
{
1018+
std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1);
1019+
return str;
1020+
}
1021+
10101022
}; // namespace diskann
10111023

10121024
struct PivotContainer

src/abstract_index.cpp

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search(const data_type *query, cons
2424

2525
template <typename data_type, typename tag_type>
2626
size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
27-
float *distances, std::vector<data_type *> &res_vectors)
27+
float *distances, std::vector<data_type *> &res_vectors, bool use_filters,
28+
const std::string filter_label)
2829
{
2930
auto any_query = std::any(query);
3031
auto any_tags = std::any(tags);
3132
auto any_res_vectors = DataVector(res_vectors);
32-
return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors);
33+
return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label);
3334
}
3435

3536
template <typename IndexType>
@@ -162,61 +163,53 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_w
162163
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
163164
float *distances);
164165

165-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(const float *query, const uint64_t K,
166-
const uint32_t L, int32_t *tags,
167-
float *distances,
168-
std::vector<float *> &res_vectors);
166+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(
167+
const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
168+
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
169169

170-
template DISKANN_DLLEXPORT size_t
171-
AbstractIndex::search_with_tags<uint8_t, int32_t>(const uint8_t *query, const uint64_t K, const uint32_t L,
172-
int32_t *tags, float *distances, std::vector<uint8_t *> &res_vectors);
170+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int32_t>(
171+
const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
172+
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
173173

174-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(const int8_t *query,
175-
const uint64_t K, const uint32_t L,
176-
int32_t *tags, float *distances,
177-
std::vector<int8_t *> &res_vectors);
174+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(
175+
const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
176+
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
178177

179-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(const float *query, const uint64_t K,
180-
const uint32_t L, uint32_t *tags,
181-
float *distances,
182-
std::vector<float *> &res_vectors);
178+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(
179+
const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
180+
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
183181

184182
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint32_t>(
185183
const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
186-
std::vector<uint8_t *> &res_vectors);
184+
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
187185

188-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(const int8_t *query,
189-
const uint64_t K, const uint32_t L,
190-
uint32_t *tags, float *distances,
191-
std::vector<int8_t *> &res_vectors);
186+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(
187+
const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
188+
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
192189

193-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(const float *query, const uint64_t K,
194-
const uint32_t L, int64_t *tags,
195-
float *distances,
196-
std::vector<float *> &res_vectors);
190+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(
191+
const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
192+
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
197193

198-
template DISKANN_DLLEXPORT size_t
199-
AbstractIndex::search_with_tags<uint8_t, int64_t>(const uint8_t *query, const uint64_t K, const uint32_t L,
200-
int64_t *tags, float *distances, std::vector<uint8_t *> &res_vectors);
194+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int64_t>(
195+
const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
196+
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
201197

202-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(const int8_t *query,
203-
const uint64_t K, const uint32_t L,
204-
int64_t *tags, float *distances,
205-
std::vector<int8_t *> &res_vectors);
198+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(
199+
const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
200+
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
206201

207-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(const float *query, const uint64_t K,
208-
const uint32_t L, uint64_t *tags,
209-
float *distances,
210-
std::vector<float *> &res_vectors);
202+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(
203+
const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
204+
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);
211205

212206
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint64_t>(
213207
const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
214-
std::vector<uint8_t *> &res_vectors);
208+
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);
215209

216-
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(const int8_t *query,
217-
const uint64_t K, const uint32_t L,
218-
uint64_t *tags, float *distances,
219-
std::vector<int8_t *> &res_vectors);
210+
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(
211+
const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
212+
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);
220213

221214
template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout<float>(const float *query, size_t K,
222215
size_t L, uint32_t *indices);

0 commit comments

Comments
 (0)