@@ -1129,11 +1129,12 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11291129 return -1 ;
11301130 }
11311131
1132- if (!std::is_same<T, float >::value && compareMetric == diskann::Metric::INNER_PRODUCT)
1132+ if (!std::is_same<T, float >::value &&
1133+ (compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE))
11331134 {
11341135 std::stringstream stream;
1135- stream << " DiskANN currently only supports floating point data for Max "
1136- " Inner Product Search. "
1136+ stream << " Disk-index build currently only supports floating point data for Max "
1137+ " Inner Product Search/ cosine similarity . "
11371138 << std::endl;
11381139 throw diskann::ANNException (stream.str (), -1 );
11391140 }
@@ -1195,6 +1196,10 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11951196 std::string disk_pq_pivots_path = index_prefix_path + " _disk.index_pq_pivots.bin" ;
11961197 // optional, used if disk index must store pq data
11971198 std::string disk_pq_compressed_vectors_path = index_prefix_path + " _disk.index_pq_compressed.bin" ;
1199+ std::string prepped_base =
1200+ index_prefix_path +
1201+ " _prepped_base.bin" ; // temp file for storing pre-processed base file for cosine/ mips metrics
1202+ bool created_temp_file_for_processed_data = false ;
11981203
11991204 // output a new base file which contains extra dimension with sqrt(1 -
12001205 // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on
@@ -1205,14 +1210,26 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
12051210 std::cout << " Using Inner Product search, so need to pre-process base "
12061211 " data into temp file. Please ensure there is additional "
12071212 " (n*(d+1)*4) bytes for storing pre-processed base vectors, "
1208- " apart from the intermin indices and final index."
1213+ " apart from the interim indices created by DiskANN and the final index."
12091214 << std::endl;
1210- std::string prepped_base = index_prefix_path + " _prepped_base.bin" ;
12111215 data_file_to_use = prepped_base;
12121216 float max_norm_of_base = diskann::prepare_base_for_inner_products<T>(base_file, prepped_base);
12131217 std::string norm_file = disk_index_path + " _max_base_norm.bin" ;
12141218 diskann::save_bin<float >(norm_file, &max_norm_of_base, 1 , 1 );
12151219 diskann::cout << timer.elapsed_seconds_for_step (" preprocessing data for inner product" ) << std::endl;
1220+ created_temp_file_for_processed_data = true ;
1221+ }
1222+ else if (compareMetric == diskann::Metric::COSINE)
1223+ {
1224+ Timer timer;
1225+ std::cout << " Normalizing data for cosine to temporary file, please ensure there is additional "
1226+ " (n*d*4) bytes for storing normalized base vectors, "
1227+ " apart from the interim indices created by DiskANN and the final index."
1228+ << std::endl;
1229+ data_file_to_use = prepped_base;
1230+ diskann::normalize_data_file (base_file, prepped_base);
1231+ diskann::cout << timer.elapsed_seconds_for_step (" preprocessing data for cosine" ) << std::endl;
1232+ created_temp_file_for_processed_data = true ;
12161233 }
12171234
12181235 uint32_t R = (uint32_t )atoi (param_list[0 ].c_str ());
@@ -1304,7 +1321,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13041321#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
13051322 MallocExtension::instance ()->ReleaseFreeMemory ();
13061323#endif
1307-
1324+ // Whether it is cosine or inner product, we still L2 metric due to the pre-processing.
13081325 timer.reset ();
13091326 diskann::build_merged_vamana_index<T, LabelT>(data_file_to_use.c_str (), diskann::Metric::L2, L, R, p_val,
13101327 indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
@@ -1345,7 +1362,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13451362 std::remove (augmented_labels_file.c_str ());
13461363 std::remove (labels_file_to_use.c_str ());
13471364 }
1348-
1365+ if (created_temp_file_for_processed_data)
1366+ std::remove (prepped_base.c_str ());
13491367 std::remove (mem_index_path.c_str ());
13501368 if (use_disk_pq)
13511369 std::remove (disk_pq_compressed_vectors_path.c_str ());
0 commit comments