|
1 | 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
2 | 2 | # Licensed under the MIT license. |
3 | 3 |
|
| 4 | +import os |
4 | 5 | import warnings |
5 | 6 |
|
6 | 7 | import numpy as np |
|
21 | 22 | _assert, |
22 | 23 | _assert_2d, |
23 | 24 | _assert_dtype, |
| 25 | + _assert_existing_directory, |
24 | 26 | _assert_is_nonnegative_uint32, |
25 | 27 | _assert_is_positive_uint32, |
26 | 28 | _castable_dtype_or_raise, |
27 | 29 | _ensure_index_metadata, |
28 | 30 | _valid_metric, |
29 | 31 | _valid_index_prefix, |
| 32 | + _write_index_metadata |
30 | 33 | ) |
31 | 34 | from ._diskannpy import defaults |
32 | 35 |
|
@@ -158,6 +161,7 @@ def __init__( |
158 | 161 | """ |
159 | 162 |
|
160 | 163 | dap_metric = _valid_metric(distance_metric) |
| 164 | + self._dap_metric = dap_metric |
161 | 165 | _assert_dtype(vector_dtype) |
162 | 166 | _assert_is_positive_uint32(dimensions, "dimensions") |
163 | 167 |
|
@@ -199,6 +203,7 @@ def __init__( |
199 | 203 | search_threads=search_threads, |
200 | 204 | concurrent_consolidation=concurrent_consolidation |
201 | 205 | ) |
| 206 | + self._points_deleted = False |
202 | 207 |
|
203 | 208 | def search( |
204 | 209 | self, query: VectorLike, k_neighbors: int, complexity: int |
@@ -293,16 +298,31 @@ def batch_search( |
293 | 298 | num_threads=num_threads, |
294 | 299 | ) |
295 | 300 |
|
296 | | - def save(self, save_path: str, compact_before_save: bool = True): |
| 301 | + def save(self, save_path: str, index_prefix: str = "ann"): |
297 | 302 | """ |
298 | 303 | Saves this index to file. |
299 | 304 | :param save_path: The path to save these index files to. |
300 | 305 | :type save_path: str |
301 | | - :param compact_before_save: |
| 306 | + :param index_prefix: The prefix to use for the index files. Default is "ann". |
| 307 | + :type index_prefix: str |
302 | 308 | """ |
303 | 309 | if save_path == "": |
304 | 310 | raise ValueError("save_path cannot be empty") |
305 | | - self._index.save(save_path=save_path, compact_before_save=compact_before_save) |
| 311 | + if index_prefix == "": |
| 312 | + raise ValueError("index_prefix cannot be empty") |
| 313 | + _assert_existing_directory(save_path, "save_path") |
| 314 | + save_path = os.path.join(save_path, index_prefix) |
| 315 | + if self._points_deleted is True: |
| 316 | + warnings.warn( |
| 317 | + "DynamicMemoryIndex.save() currently requires DynamicMemoryIndex.consolidate_delete() to be called " |
| 318 | + "prior to save when items have been marked for deletion. This is being done automatically now, though" |
| 319 | + "it will increase the time it takes to save; on large sets of data it can take a substantial amount of " |
| 320 | + "time. In the future, we will implement a faster save with unconsolidated deletes, but for now this is " |
| 321 | + "required." |
| 322 | + ) |
| 323 | + self._index.consolidate_delete() |
| 324 | + self._index.save(save_path=save_path, compact_before_save=True) # we do not yet support uncompacted saves |
| 325 | + _write_index_metadata(save_path, self._vector_dtype, self._dap_metric, self._index.num_points(), self._dimensions) |
306 | 326 |
|
307 | 327 | def insert(self, vector: VectorLike, vector_id: VectorIdentifier): |
308 | 328 | """ |
@@ -349,10 +369,12 @@ def mark_deleted(self, vector_id: VectorIdentifier): |
349 | 369 | :type vector_id: int |
350 | 370 | """ |
351 | 371 | _assert_is_positive_uint32(vector_id, "vector_id") |
| 372 | + self._points_deleted = True |
352 | 373 | self._index.mark_deleted(np.uintc(vector_id)) |
353 | 374 |
|
354 | 375 | def consolidate_delete(self): |
355 | 376 | """ |
356 | 377 | This method actually restructures the DiskANN index to remove the items that have been marked for deletion. |
357 | 378 | """ |
358 | 379 | self._index.consolidate_delete() |
| 380 | + self._points_deleted = False |
0 commit comments