Skip to content

Commit 4a57e89

Browse files
authored
Adding Filtered Index support to Python bindings (#482)
* Halfway approach to the new indexfactory, but it doesn't have the same featureset as the old way. Committing this for posterity but reverting my changes ultimately * Revert "Halfway approach to the new indexfactory, but it doesn't have the same featureset as the old way. Committing this for posterity but reverting my changes ultimately" This reverts commit 03dccb5. * Adding filtered search. API is going to change still. * Further enhancements to the new filter capability in the static memory index. * Ran automatic formatting * Fixing my logic and ensuring the unit tests pass. * Setting this up as a rc build first * list[list[Hashable]] -> list[list[str]] * Adding halfway to a solution where we query for more items than exist in the filter set. We need to replicate this behavior across all indices though - dynamic, static disk and memory w/o filters, etc * Removing the import of Hashable too
1 parent 179927e commit 4a57e89

13 files changed

+290
-39
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta"
1111

1212
[project]
1313
name = "diskannpy"
14-
version = "0.6.1"
14+
version = "0.7.0rc1"
1515

1616
description = "DiskANN Python extension module"
1717
readme = "python/README.md"

python/include/builder.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ template <typename DT, typename TagT = DynamicIdType, typename LabelT = filterT>
2020
void build_memory_index(diskann::Metric metric, const std::string &vector_bin_path,
2121
const std::string &index_output_path, uint32_t graph_degree, uint32_t complexity,
2222
float alpha, uint32_t num_threads, bool use_pq_build,
23-
size_t num_pq_bytes, bool use_opq, uint32_t filter_complexity,
24-
bool use_tags = false);
23+
size_t num_pq_bytes, bool use_opq, bool use_tags = false,
24+
const std::string& filter_labels_file = "", const std::string& universal_label = "",
25+
uint32_t filter_complexity = 0);
2526

2627
}

python/include/static_memory_index.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ template <typename DT> class StaticMemoryIndex
2626
NeighborsAndDistances<StaticIdType> search(py::array_t<DT, py::array::c_style | py::array::forcecast> &query,
2727
uint64_t knn, uint64_t complexity);
2828

29+
NeighborsAndDistances<StaticIdType> search_with_filter(
30+
py::array_t<DT, py::array::c_style | py::array::forcecast> &query, uint64_t knn, uint64_t complexity,
31+
filterT filter);
32+
2933
NeighborsAndDistances<StaticIdType> batch_search(
3034
py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, uint64_t num_queries, uint64_t knn,
3135
uint64_t complexity, uint32_t num_threads);

python/src/_builder.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4+
import json
45
import os
56
import shutil
67
from pathlib import Path
@@ -174,8 +175,10 @@ def build_memory_index(
174175
num_pq_bytes: int = defaults.NUM_PQ_BYTES,
175176
use_opq: bool = defaults.USE_OPQ,
176177
vector_dtype: Optional[VectorDType] = None,
177-
filter_complexity: int = defaults.FILTER_COMPLEXITY,
178178
tags: Union[str, VectorIdentifierBatch] = "",
179+
filter_labels: Optional[list[list[str]]] = None,
180+
universal_label: str = "",
181+
filter_complexity: int = defaults.FILTER_COMPLEXITY,
179182
index_prefix: str = "ann",
180183
) -> None:
181184
"""
@@ -223,10 +226,20 @@ def build_memory_index(
223226
Default is `0`.
224227
- **use_opq**: Use optimized product quantization during build.
225228
- **vector_dtype**: Required if the provided `data` is of type `str`, else we use the `data.dtype` if np array.
226-
- **filter_complexity**: Complexity to use when using filters. Default is 0.
227-
- **tags**: A `str` representing a path to a pre-built tags file on disk, or a `numpy.ndarray` of uint32 ids
228-
corresponding to the ordinal position of the vectors provided to build the index. Defaults to "". **This value
229-
must be provided if you want to build a memory index intended for use with `diskannpy.DynamicMemoryIndex`**.
229+
- **tags**: Tags can be defined either as a path on disk to an existing .tags file, or provided as a np.array of
230+
the same length as the number of vectors. Tags are used to identify vectors in the index via your *own*
231+
numbering conventions, and are absolutely required for loading DynamicMemoryIndex indices `from_file`.
232+
- **filter_labels**: An optional, but exhaustive list of categories for each vector. This is used to filter
233+
search results by category. If provided, this must be a list of lists, where each inner list is a list of
234+
categories for the corresponding vector. For example, if you have 3 vectors, and the first vector belongs to
235+
categories "a" and "b", the second vector belongs to category "b", and the third vector belongs to no categories,
236+
you would provide `filter_labels=[["a", "b"], ["b"], []]`. If you do not want to provide categories for a
237+
particular vector, you can provide an empty list. If you do not want to provide categories for any vectors,
238+
you can provide `None` for this parameter (which is the default).
239+
- **universal_label**: An optional label that indicates that any vector carrying it should be included in *every* search
240+
in which it also meets the knn search criteria.
241+
- **filter_complexity**: Complexity to use when using filters. Default is 0. 0 is strictly invalid if you are
242+
using filters.
230243
- **index_prefix**: The prefix of the index files. Defaults to "ann".
231244
"""
232245
_assert(
@@ -245,6 +258,10 @@ def build_memory_index(
245258
_assert_is_nonnegative_uint32(num_pq_bytes, "num_pq_bytes")
246259
_assert_is_nonnegative_uint32(filter_complexity, "filter_complexity")
247260
_assert(index_prefix != "", "index_prefix cannot be an empty string")
261+
_assert(
262+
filter_labels is None or filter_complexity > 0,
263+
"if filter_labels is provided, filter_complexity must not be 0"
264+
)
248265

249266
index_path = Path(index_directory)
250267
_assert(
@@ -262,6 +279,11 @@ def build_memory_index(
262279
)
263280

264281
num_points, dimensions = vectors_metadata_from_file(vector_bin_path)
282+
if filter_labels is not None:
283+
_assert(
284+
len(filter_labels) == num_points,
285+
"filter_labels must be the same length as the number of points"
286+
)
265287

266288
if vector_dtype_actual == np.uint8:
267289
_builder = _native_dap.build_memory_uint8_index
@@ -272,6 +294,21 @@ def build_memory_index(
272294

273295
index_prefix_path = os.path.join(index_directory, index_prefix)
274296

297+
filter_labels_file = ""
298+
if filter_labels is not None:
299+
label_counts = {}
300+
filter_labels_file = f"{index_prefix_path}_pylabels.txt"
301+
with open(filter_labels_file, "w") as labels_file:
302+
for labels in filter_labels:
303+
for label in labels:
304+
label_counts[label] = 1 if label not in label_counts else label_counts[label] + 1
305+
if len(labels) == 0:
306+
print("default", file=labels_file)
307+
else:
308+
print(",".join(labels), file=labels_file)
309+
with open(f"{index_prefix_path}_label_metadata.json", "w") as label_metadata_file:
310+
json.dump(label_counts, label_metadata_file, indent=True)
311+
275312
if isinstance(tags, str) and tags != "":
276313
use_tags = True
277314
shutil.copy(tags, index_prefix_path + ".tags")
@@ -299,8 +336,10 @@ def build_memory_index(
299336
use_pq_build=use_pq_build,
300337
num_pq_bytes=num_pq_bytes,
301338
use_opq=use_opq,
302-
filter_complexity=filter_complexity,
303339
use_tags=use_tags,
340+
filter_labels_file=filter_labels_file,
341+
universal_label=universal_label,
342+
filter_complexity=filter_complexity,
304343
)
305344

306345
_write_index_metadata(

python/src/_builder.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def build_memory_index(
4747
use_pq_build: bool,
4848
num_pq_bytes: int,
4949
use_opq: bool,
50-
label_file: str,
50+
tags: Union[str, VectorIdentifierBatch],
51+
filter_labels: Optional[list[list[str]]],
5152
universal_label: str,
5253
filter_complexity: int,
53-
tags: Optional[VectorIdentifierBatch],
54-
index_prefix: str,
54+
index_prefix: str
5555
) -> None: ...
5656
@overload
5757
def build_memory_index(
@@ -66,9 +66,9 @@ def build_memory_index(
6666
num_pq_bytes: int,
6767
use_opq: bool,
6868
vector_dtype: VectorDType,
69-
label_file: str,
69+
tags: Union[str, VectorIdentifierBatch],
70+
filter_labels_file: Optional[list[list[str]]],
7071
universal_label: str,
7172
filter_complexity: int,
72-
tags: Optional[str],
73-
index_prefix: str,
73+
index_prefix: str
7474
) -> None: ...

python/src/_common.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def _ensure_index_metadata(
211211
distance_metric: Optional[DistanceMetric],
212212
max_vectors: int,
213213
dimensions: Optional[int],
214+
warn_size_exceeded: bool = False,
214215
) -> Tuple[VectorDType, str, np.uint64, np.uint64]:
215216
possible_metadata = _read_index_metadata(index_path_and_prefix)
216217
if possible_metadata is None:
@@ -226,16 +227,17 @@ def _ensure_index_metadata(
226227
return vector_dtype, distance_metric, max_vectors, dimensions # type: ignore
227228
else:
228229
vector_dtype, distance_metric, num_vectors, dimensions = possible_metadata
229-
if max_vectors is not None and num_vectors > max_vectors:
230-
warnings.warn(
231-
"The number of vectors in the saved index exceeds the max_vectors parameter. "
232-
"max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
233-
)
234-
max_vectors = num_vectors
235-
if num_vectors == max_vectors:
236-
warnings.warn(
237-
"The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail."
238-
)
230+
if warn_size_exceeded:
231+
if max_vectors is not None and num_vectors > max_vectors:
232+
warnings.warn(
233+
"The number of vectors in the saved index exceeds the max_vectors parameter. "
234+
"max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
235+
)
236+
max_vectors = num_vectors
237+
if num_vectors == max_vectors:
238+
warnings.warn(
239+
"The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail."
240+
)
239241
return possible_metadata
240242

241243

python/src/_dynamic_memory_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def from_file(
144144
f"The file {tags_file} does not exist in {index_directory}",
145145
)
146146
vector_dtype, dap_metric, num_vectors, dimensions = _ensure_index_metadata(
147-
index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions
147+
index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions, warn_size_exceeded=True
148148
)
149149

150150
index = cls(

python/src/_static_memory_index.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4+
import json
45
import os
56
import warnings
67
from typing import Optional
@@ -43,6 +44,7 @@ def __init__(
4344
distance_metric: Optional[DistanceMetric] = None,
4445
vector_dtype: Optional[VectorDType] = None,
4546
dimensions: Optional[int] = None,
47+
enable_filters: bool = False
4648
):
4749
"""
4850
### Parameters
@@ -73,8 +75,22 @@ def __init__(
7375
- **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
7476
dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
7577
does not exist, you are required to provide it.
78+
- **enable_filters**: Indexes built with filters can also be used for filtered search.
7679
"""
7780
index_prefix = _valid_index_prefix(index_directory, index_prefix)
81+
self._labels_map = {}
82+
self._labels_metadata = {}
83+
if enable_filters:
84+
try:
85+
with open(index_prefix + "_labels_map.txt", "r") as labels_map_if:
86+
for line in labels_map_if:
87+
(key, val) = line.split("\t")
88+
self._labels_map[key] = int(val)
89+
with open(f"{index_prefix}_label_metadata.json", "r") as labels_metadata_if:
90+
self._labels_metadata = json.load(labels_metadata_if)
91+
except: # noqa: E722
92+
# exceptions are basically presumed to be either file not found or file not formatted correctly
93+
raise RuntimeException("Filter labels file was unable to be processed.")
7894
vector_dtype, metric, num_points, dims = _ensure_index_metadata(
7995
index_prefix,
8096
vector_dtype,
@@ -109,7 +125,7 @@ def __init__(
109125
)
110126

111127
def search(
112-
self, query: VectorLike, k_neighbors: int, complexity: int
128+
self, query: VectorLike, k_neighbors: int, complexity: int, filter_label: str = ""
113129
) -> QueryResponse:
114130
"""
115131
Searches the index by a single query vector.
@@ -121,13 +137,25 @@ def search(
121137
- **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
122138
increases accuracy at the cost of latency. Must be at least k_neighbors in size.
123139
"""
140+
if filter_label != "":
141+
if len(self._labels_map) == 0:
142+
raise ValueError(
143+
f"A filter label of {filter_label} was provided, but this class was not initialized with filters "
144+
"enabled, e.g. StaticDiskMemory(..., enable_filters=True)"
145+
)
146+
if filter_label not in self._labels_map:
147+
raise ValueError(
148+
f"A filter label of {filter_label} was provided, but the external(str)->internal(np.uint32) labels map "
149+
f"does not include that label."
150+
)
151+
k_neighbors = min(k_neighbors, self._labels_metadata[filter_label])
124152
_query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
125153
_assert(len(_query.shape) == 1, "query vector must be 1-d")
126154
_assert(
127155
_query.shape[0] == self._dimensions,
128156
f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
129157
f"query dimensionality: {_query.shape[0]}",
130-
)
158+
)
131159
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
132160
_assert_is_nonnegative_uint32(complexity, "complexity")
133161

@@ -136,9 +164,20 @@ def search(
136164
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
137165
)
138166
complexity = k_neighbors
139-
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
167+
168+
if filter_label == "":
169+
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
170+
else:
171+
filter = self._labels_map[filter_label]
172+
neighbors, distances = self._index.search_with_filter(
173+
query=query,
174+
knn=k_neighbors,
175+
complexity=complexity,
176+
filter=filter
177+
)
140178
return QueryResponse(identifiers=neighbors, distances=distances)
141179

180+
142181
def batch_search(
143182
self,
144183
queries: VectorLikeBatch,

python/src/builder.cpp

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,37 @@ template void build_disk_index<uint8_t>(diskann::Metric, const std::string &, co
3131
template void build_disk_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
3232
double, double, uint32_t, uint32_t);
3333

34+
template <typename T, typename TagT, typename LabelT>
35+
std::string prepare_filtered_label_map(diskann::Index<T, TagT, LabelT> &index, const std::string &index_output_path,
36+
const std::string &filter_labels_file, const std::string &universal_label)
37+
{
38+
std::string labels_file_to_use = index_output_path + "_label_formatted.txt";
39+
std::string mem_labels_int_map_file = index_output_path + "_labels_map.txt";
40+
convert_labels_string_to_int(filter_labels_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
41+
if (!universal_label.empty())
42+
{
43+
uint32_t unv_label_as_num = 0;
44+
index.set_universal_label(unv_label_as_num);
45+
}
46+
return labels_file_to_use;
47+
}
48+
49+
template std::string prepare_filtered_label_map<float>(diskann::Index<float, uint32_t, uint32_t> &, const std::string &,
50+
const std::string &, const std::string &);
51+
52+
template std::string prepare_filtered_label_map<int8_t>(diskann::Index<int8_t, uint32_t, uint32_t> &,
53+
const std::string &, const std::string &, const std::string &);
54+
55+
template std::string prepare_filtered_label_map<uint8_t>(diskann::Index<uint8_t, uint32_t, uint32_t> &,
56+
const std::string &, const std::string &, const std::string &);
57+
3458
template <typename T, typename TagT, typename LabelT>
3559
void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path,
3660
const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity,
3761
const float alpha, const uint32_t num_threads, const bool use_pq_build,
38-
const size_t num_pq_bytes, const bool use_opq, const uint32_t filter_complexity,
39-
const bool use_tags)
62+
const size_t num_pq_bytes, const bool use_opq, const bool use_tags,
63+
const std::string &filter_labels_file, const std::string &universal_label,
64+
const uint32_t filter_complexity)
4065
{
4166
diskann::IndexWriteParameters index_build_params = diskann::IndexWriteParametersBuilder(complexity, graph_degree)
4267
.with_filter_list_size(filter_complexity)
@@ -65,23 +90,44 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
6590
size_t tag_dims = 1;
6691
diskann::load_bin(tags_file, tags_data, data_num, tag_dims);
6792
std::vector<TagT> tags(tags_data, tags_data + data_num);
68-
index.build(vector_bin_path.c_str(), data_num, tags);
93+
if (filter_labels_file.empty())
94+
{
95+
index.build(vector_bin_path.c_str(), data_num, tags);
96+
}
97+
else
98+
{
99+
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
100+
universal_label);
101+
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags);
102+
}
69103
}
70104
else
71105
{
72-
index.build(vector_bin_path.c_str(), data_num);
106+
if (filter_labels_file.empty())
107+
{
108+
index.build(vector_bin_path.c_str(), data_num);
109+
}
110+
else
111+
{
112+
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
113+
universal_label);
114+
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num);
115+
}
73116
}
74117

75118
index.save(index_output_path.c_str());
76119
}
77120

78121
template void build_memory_index<float>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
79-
float, uint32_t, bool, size_t, bool, uint32_t, bool);
122+
float, uint32_t, bool, size_t, bool, bool, const std::string &,
123+
const std::string &, uint32_t);
80124

81125
template void build_memory_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
82-
float, uint32_t, bool, size_t, bool, uint32_t, bool);
126+
float, uint32_t, bool, size_t, bool, bool, const std::string &,
127+
const std::string &, uint32_t);
83128

84129
template void build_memory_index<uint8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
85-
float, uint32_t, bool, size_t, bool, uint32_t, bool);
130+
float, uint32_t, bool, size_t, bool, bool, const std::string &,
131+
const std::string &, uint32_t);
86132

87133
} // namespace diskannpy

python/src/diskann_bindings.cpp

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)