Skip to content

Commit 13df0cf

Browse files
rakrirakri
andauthored
Rakri/cosine bug fix (#450)
* compiles, but need to verify * fixed windows compiler warning * minor typo * added cosine unit test with unnormalized data * minor typo in user prompt cosine/l2 * cosine was already supported in groundtruth, edited the message to say so * clang-format --------- Co-authored-by: rakri <[email protected]>
1 parent 58de98d commit 13df0cf

File tree

8 files changed

+106
-39
lines changed

8 files changed

+106
-39
lines changed

.github/actions/generate-random/action.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,21 @@ runs:
99
1010
echo "Generating random vectors for index"
1111
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
12+
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_unnorm.bin -D 10 -N 10000 --rand_scaling 2.0
1213
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
1314
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
1415
1516
echo "Generating random vectors for query"
1617
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
18+
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_unnorm.bin -D 10 -N 1000 --rand_scaling 2.0
1719
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
1820
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
1921
2022
echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
2123
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
2224
dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
2325
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
26+
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_unnorm.bin --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --K 100
2427
2528
echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
2629
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100

.github/workflows/disk-pq.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ jobs:
3434
run: |
3535
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
3636
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
37+
- name: build and search disk index (one shot graph build, cosine, no diskPQ) (float)
38+
if: success() || failure()
39+
run: |
40+
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
41+
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
3742
- name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
3843
if: success() || failure()
3944
run: |
@@ -66,6 +71,11 @@ jobs:
6671
run: |
6772
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
6873
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
74+
- name: build and search disk index (sharded graph build, cosine, no diskPQ) (float)
75+
if: success() || failure()
76+
run: |
77+
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
78+
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
6979
- name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
7080
run: |
7181
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006

apps/build_disk_index.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ int main(int argc, char **argv)
107107
metric = diskann::Metric::L2;
108108
else if (dist_fn == std::string("mips"))
109109
metric = diskann::Metric::INNER_PRODUCT;
110+
else if (dist_fn == std::string("cosine"))
111+
metric = diskann::Metric::COSINE;
110112
else
111113
{
112114
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;

apps/utils/compute_groundtruth.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,8 @@ int main(int argc, char **argv)
499499
desc.add_options()("help,h", "Print information on arguments");
500500

501501
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
502-
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips>");
502+
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
503+
"distance function <l2/mips/cosine>");
503504
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
504505
"File containing the base vectors in binary format");
505506
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),

apps/utils/rand_data_gen.cpp

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,31 @@
1111

1212
namespace po = boost::program_options;
1313

14-
int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, float norm)
14+
int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm,
15+
float rand_scale)
1516
{
1617
auto vec = new float[ndims];
1718

1819
std::random_device rd{};
1920
std::mt19937 gen{rd()};
2021
std::normal_distribution<> normal_rand{0, 1};
22+
std::uniform_real_distribution<> unif_dis(1.0, rand_scale);
2123

2224
for (size_t i = 0; i < npts; i++)
2325
{
2426
float sum = 0;
27+
float scale = 1.0f;
28+
if (rand_scale > 1.0f)
29+
scale = (float)unif_dis(gen);
2530
for (size_t d = 0; d < ndims; ++d)
26-
vec[d] = (float)normal_rand(gen);
27-
for (size_t d = 0; d < ndims; ++d)
28-
sum += vec[d] * vec[d];
29-
for (size_t d = 0; d < ndims; ++d)
30-
vec[d] = vec[d] * norm / std::sqrt(sum);
31+
vec[d] = scale * (float)normal_rand(gen);
32+
if (normalization)
33+
{
34+
for (size_t d = 0; d < ndims; ++d)
35+
sum += vec[d] * vec[d];
36+
for (size_t d = 0; d < ndims; ++d)
37+
vec[d] = vec[d] * norm / std::sqrt(sum);
38+
}
3139

3240
writer.write((char *)vec, ndims * sizeof(float));
3341
}
@@ -104,8 +112,8 @@ int main(int argc, char **argv)
104112
{
105113
std::string data_type, output_file;
106114
size_t ndims, npts;
107-
float norm;
108-
115+
float norm, rand_scaling;
116+
bool normalization = false;
109117
try
110118
{
111119
po::options_description desc{"Arguments"};
@@ -117,7 +125,11 @@ int main(int argc, char **argv)
117125
"File name for saving the random vectors");
118126
desc.add_options()("ndims,D", po::value<uint64_t>(&ndims)->required(), "Dimensoinality of the vector");
119127
desc.add_options()("npts,N", po::value<uint64_t>(&npts)->required(), "Number of vectors");
120-
desc.add_options()("norm", po::value<float>(&norm)->required(), "Norm of the vectors");
128+
desc.add_options()("norm", po::value<float>(&norm)->default_value(-1.0f),
129+
"Norm of the vectors (if not specified, vectors are not normalized)");
130+
desc.add_options()("rand_scaling", po::value<float>(&rand_scaling)->default_value(1.0f),
131+
"Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from "
132+
"[1, rand_scale]. Only applicable for floating point data");
121133
po::variables_map vm;
122134
po::store(po::parse_command_line(argc, argv, desc), vm);
123135
if (vm.count("help"))
@@ -139,9 +151,20 @@ int main(int argc, char **argv)
139151
return -1;
140152
}
141153

142-
if (norm <= 0.0)
154+
if (norm > 0.0)
155+
{
156+
normalization = true;
157+
}
158+
159+
if (rand_scaling < 1.0)
160+
{
161+
std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl;
162+
return -1;
163+
}
164+
165+
if ((rand_scaling > 1.0) && (normalization == true))
143166
{
144-
std::cerr << "Error: Norm must be a positive number" << std::endl;
167+
std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl;
145168
return -1;
146169
}
147170

@@ -155,6 +178,11 @@ int main(int argc, char **argv)
155178
<< std::endl;
156179
return -1;
157180
}
181+
if (rand_scaling > 1.0)
182+
{
183+
std::cout << "Data scaling only supported for floating point data." << std::endl;
184+
return -1;
185+
}
158186
}
159187

160188
try
@@ -177,7 +205,7 @@ int main(int argc, char **argv)
177205
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
178206
if (data_type == std::string("float"))
179207
{
180-
ret = block_write_float(writer, ndims, cblk_size, norm);
208+
ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling);
181209
}
182210
else if (data_type == std::string("int8"))
183211
{

src/disk_utils.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,11 +1129,12 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11291129
return -1;
11301130
}
11311131

1132-
if (!std::is_same<T, float>::value && compareMetric == diskann::Metric::INNER_PRODUCT)
1132+
if (!std::is_same<T, float>::value &&
1133+
(compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE))
11331134
{
11341135
std::stringstream stream;
1135-
stream << "DiskANN currently only supports floating point data for Max "
1136-
"Inner Product Search. "
1136+
stream << "Disk-index build currently only supports floating point data for Max "
1137+
"Inner Product Search/ cosine similarity. "
11371138
<< std::endl;
11381139
throw diskann::ANNException(stream.str(), -1);
11391140
}
@@ -1195,6 +1196,10 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11951196
std::string disk_pq_pivots_path = index_prefix_path + "_disk.index_pq_pivots.bin";
11961197
// optional, used if disk index must store pq data
11971198
std::string disk_pq_compressed_vectors_path = index_prefix_path + "_disk.index_pq_compressed.bin";
1199+
std::string prepped_base =
1200+
index_prefix_path +
1201+
"_prepped_base.bin"; // temp file for storing pre-processed base file for cosine/ mips metrics
1202+
bool created_temp_file_for_processed_data = false;
11981203

11991204
// output a new base file which contains extra dimension with sqrt(1 -
12001205
// ||x||^2/M^2) for every x, M is max norm of all points. Extra space on
@@ -1205,14 +1210,26 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
12051210
std::cout << "Using Inner Product search, so need to pre-process base "
12061211
"data into temp file. Please ensure there is additional "
12071212
"(n*(d+1)*4) bytes for storing pre-processed base vectors, "
1208-
"apart from the intermin indices and final index."
1213+
"apart from the interim indices created by DiskANN and the final index."
12091214
<< std::endl;
1210-
std::string prepped_base = index_prefix_path + "_prepped_base.bin";
12111215
data_file_to_use = prepped_base;
12121216
float max_norm_of_base = diskann::prepare_base_for_inner_products<T>(base_file, prepped_base);
12131217
std::string norm_file = disk_index_path + "_max_base_norm.bin";
12141218
diskann::save_bin<float>(norm_file, &max_norm_of_base, 1, 1);
12151219
diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for inner product") << std::endl;
1220+
created_temp_file_for_processed_data = true;
1221+
}
1222+
else if (compareMetric == diskann::Metric::COSINE)
1223+
{
1224+
Timer timer;
1225+
std::cout << "Normalizing data for cosine to temporary file, please ensure there is additional "
1226+
"(n*d*4) bytes for storing normalized base vectors, "
1227+
"apart from the interim indices created by DiskANN and the final index."
1228+
<< std::endl;
1229+
data_file_to_use = prepped_base;
1230+
diskann::normalize_data_file(base_file, prepped_base);
1231+
diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for cosine") << std::endl;
1232+
created_temp_file_for_processed_data = true;
12161233
}
12171234

12181235
uint32_t R = (uint32_t)atoi(param_list[0].c_str());
@@ -1304,7 +1321,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13041321
#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
13051322
MallocExtension::instance()->ReleaseFreeMemory();
13061323
#endif
1307-
1324+
// Whether it is cosine or inner product, we still use the L2 metric here because of the pre-processing.
13081325
timer.reset();
13091326
diskann::build_merged_vamana_index<T, LabelT>(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
13101327
indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
@@ -1345,7 +1362,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13451362
std::remove(augmented_labels_file.c_str());
13461363
std::remove(labels_file_to_use.c_str());
13471364
}
1348-
1365+
if (created_temp_file_for_processed_data)
1366+
std::remove(prepped_base.c_str());
13491367
std::remove(mem_index_path.c_str());
13501368
if (use_disk_pq)
13511369
std::remove(disk_pq_compressed_vectors_path.c_str());

src/pq_flash_index.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ template <typename T, typename LabelT>
3232
PQFlashIndex<T, LabelT>::PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader, diskann::Metric m)
3333
: reader(fileReader), metric(m), _thread_data(nullptr)
3434
{
35+
diskann::Metric metric_to_invoke = m;
3536
if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT)
3637
{
3738
if (std::is_floating_point<T>::value)
3839
{
39-
diskann::cout << "Cosine metric chosen for (normalized) float data."
40-
"Changing distance to L2 to boost accuracy."
40+
diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed "
41+
"(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we "
42+
"shall invoke an l2 distance function."
4143
<< std::endl;
42-
metric = diskann::Metric::L2;
44+
metric_to_invoke = diskann::Metric::L2;
4345
}
4446
else
4547
{
@@ -49,8 +51,8 @@ PQFlashIndex<T, LabelT>::PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileRe
4951
}
5052
}
5153

52-
this->_dist_cmp.reset(diskann::get_distance_function<T>(metric));
53-
this->_dist_cmp_float.reset(diskann::get_distance_function<float>(metric));
54+
this->_dist_cmp.reset(diskann::get_distance_function<T>(metric_to_invoke));
55+
this->_dist_cmp_float.reset(diskann::get_distance_function<float>(metric_to_invoke));
5456
}
5557

5658
template <typename T, typename LabelT> PQFlashIndex<T, LabelT>::~PQFlashIndex()
@@ -1292,20 +1294,23 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
12921294
float *query_float = pq_query_scratch->aligned_query_float;
12931295
float *query_rotated = pq_query_scratch->rotated_query;
12941296

1295-
// if inner product, we laso normalize the query and set the last coordinate
1296-
// to 0 (this is the extra coordindate used to convert MIPS to L2 search)
1297-
if (metric == diskann::Metric::INNER_PRODUCT)
1297+
// Normalization step. For cosine, we simply normalize the query.
1298+
// for MIPS, we normalize the first d-1 dims and set the last dim to 0, since an extra coordinate was used to
1299+
// convert MIPS to L2 search
1300+
if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE)
12981301
{
1299-
for (size_t i = 0; i < this->_data_dim - 1; i++)
1302+
uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1);
1303+
for (size_t i = 0; i < inherent_dim; i++)
13001304
{
13011305
aligned_query_T[i] = query1[i];
13021306
query_norm += query1[i] * query1[i];
13031307
}
1304-
aligned_query_T[this->_data_dim - 1] = 0;
1308+
if (metric == diskann::Metric::INNER_PRODUCT)
1309+
aligned_query_T[this->_data_dim - 1] = 0;
13051310

13061311
query_norm = std::sqrt(query_norm);
13071312

1308-
for (size_t i = 0; i < this->_data_dim - 1; i++)
1313+
for (size_t i = 0; i < inherent_dim; i++)
13091314
{
13101315
aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm);
13111316
}

0 commit comments

Comments
 (0)