1818
#include <algorithm>

#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h> // BitstringWriter
#include <faiss/utils/utils.h>
2425
extern "C" {

// general matrix multiplication
// Fortran BLAS binding: computes C = alpha * op(A) * op(B) + beta * C with
// column-major storage; transa/transb select op() ("T"ransposed / "N"ot).
// NOTE(review): FINTEGER is presumably #define'd earlier in this file to the
// platform BLAS integer type — confirm against the top of the file.
int sgemm_(
        const char* transa,
        const char* transb,
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const float* alpha,
        const float* a,
        FINTEGER* lda,
        const float* b,
        FINTEGER* ldb,
        float* beta,
        float* c,
        FINTEGER* ldc);
}
44+
2545namespace {
2646
2747// c and a and b can overlap
@@ -31,6 +51,12 @@ void fvec_add(size_t d, const float* a, const float* b, float* c) {
3151 }
3252}
3353
// Add the scalar b to every component of a, writing the result into c.
// c may alias a.
void fvec_add(size_t d, const float* a, float b, float* c) {
    for (size_t j = 0; j != d; ++j) {
        c[j] = a[j] + b;
    }
}
59+
3460} // namespace
3561
3662namespace faiss {
@@ -48,6 +74,7 @@ void AdditiveQuantizer::set_derived_values() {
4874 is_byte_aligned = false ;
4975 }
5076 }
77+ total_codebook_size = codebook_offsets[M];
5178 // convert bits to bytes
5279 code_size = (tot_bits + 7 ) / 8 ;
5380}
@@ -93,4 +120,151 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
93120
94121AdditiveQuantizer::~AdditiveQuantizer () {}
95122
/***************************************************************************
 * Support for fast distance computations and search with additive quantizer
 ***************************************************************************/
126+
127+ void AdditiveQuantizer::compute_centroid_norms (float * norms) const {
128+ size_t ntotal = (size_t )1 << tot_bits;
129+ // TODO: make tree of partial sums
130+ #pragma omp parallel
131+ {
132+ std::vector<float > tmp (d);
133+ #pragma omp for
134+ for (int64_t i = 0 ; i < ntotal; i++) {
135+ decode_64bit (i, tmp.data ());
136+ norms[i] = fvec_norm_L2sqr (tmp.data (), d);
137+ }
138+ }
139+ }
140+
/** Decode one code packed in the low tot_bits bits of `bits` into the
 * d-dimensional vector xi. The decoded vector is the sum of one codebook
 * entry per sub-quantizer; sub-codes are consumed LSB-first, nbits[m] bits
 * for sub-quantizer m. */
void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
    for (int m = 0; m < M; m++) {
        // extract the low nbits[m] bits as the index into codebook m
        idx_t idx = bits & (((size_t)1 << nbits[m]) - 1);
        bits >>= nbits[m];
        const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
        if (m == 0) {
            // first term initializes xi
            memcpy(xi, c, sizeof(*xi) * d);
        } else {
            // accumulate the remaining codebook entries
            fvec_add(d, xi, c, xi);
        }
    }
}
153+
/** Compute look-up tables of inner products between the n query vectors xq
 * (size n * d) and all codebook centroids. On output, LUT has size
 * n * total_codebook_size, one contiguous table of total_codebook_size
 * entries per query. */
void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
        const {
    // in all cases, it is large matrix multiplication
    // (column-major BLAS view: LUT = codebooks^T * xq, hence "T" / "N")

    FINTEGER ncenti = total_codebook_size;
    FINTEGER di = d;
    FINTEGER nqi = n;
    float one = 1, zero = 0;

    sgemm_("Transposed",
           "Not transposed",
           &ncenti,
           &nqi,
           &di,
           &one,
           codebooks.data(),
           &di,
           xq,
           &di,
           &zero,
           LUT,
           &ncenti);
}
177+
namespace {

/* Compute the inner products between one query and all 2^aq.tot_bits
 * decodable codes, from the per-centroid LUT of that query.
 * ips must have room for 2^aq.tot_bits floats. The table is expanded in
 * place: after processing sub-quantizer m, the prefix
 * ips[0 .. 2^(nbits[0]+...+nbits[m])) holds the partial sums. */
void compute_inner_prod_with_LUT(
        const AdditiveQuantizer& aq,
        const float* LUT,
        float* ips) {
    size_t prev_size = 1;
    for (int m = 0; m < aq.M; m++) {
        const float* LUTm = LUT + aq.codebook_offsets[m];
        int nb = aq.nbits[m];
        size_t nc = (size_t)1 << nb;

        if (m == 0) {
            // seed the table with the first sub-quantizer's inner products
            memcpy(ips, LUT, sizeof(*ips) * nc);
        } else {
            // expand in place; iterate i downward so the source prefix
            // ips[0..prev_size) is read before slot i*prev_size overwrites it
            for (int64_t i = nc - 1; i >= 0; i--) {
                float v = LUTm[i];
                fvec_add(prev_size, ips, v, ips + i * prev_size);
            }
        }
        prev_size *= nc;
    }
}

} // anonymous namespace
203+
/** Exhaustive k-nearest-neighbor search for maximum inner product:
 * enumerates all 2^tot_bits decodable codes for each query.
 *
 * @param n          number of query vectors
 * @param xq         queries, size n * d
 * @param k          number of neighbors to return
 * @param distances  output inner products, size n * k, largest first
 * @param labels     output code ids, size n * k
 */
void AdditiveQuantizer::knn_exact_inner_product(
        idx_t n,
        const float* xq,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
    compute_LUT(n, xq, LUT.get());
    size_t ntotal = (size_t)1 << tot_bits;

// parallelize over queries only when there are enough of them
#pragma omp parallel if (n > 100)
    {
        // per-thread buffer of inner products with all ntotal codes
        std::vector<float> dis(ntotal);
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const float* LUTi = LUT.get() + i * total_codebook_size;
            compute_inner_prod_with_LUT(*this, LUTi, dis.data());
            float* distances_i = distances + i * k;
            idx_t* labels_i = labels + i * k;
            // keep the k largest inner products with a min-heap;
            // ids == nullptr: labels are presumably the sequential code
            // indices 0..ntotal-1 (Heap.h addn convention — confirm)
            minheap_heapify(k, distances_i, labels_i);
            minheap_addn(k, distances_i, labels_i, dis.data(), nullptr, ntotal);
            minheap_reorder(k, distances_i, labels_i);
        }
    }
}
229+
/** Exhaustive k-nearest-neighbor search in squared L2 distance:
 * enumerates all 2^tot_bits decodable codes for each query.
 *
 * @param n          number of query vectors
 * @param xq         queries, size n * d
 * @param k          number of neighbors to return
 * @param distances  output squared L2 distances, size n * k, smallest first
 * @param labels     output code ids, size n * k
 * @param norms      precomputed squared norms of all 2^tot_bits centroids
 *                   (presumably from compute_centroid_norms — confirm)
 */
void AdditiveQuantizer::knn_exact_L2(
        idx_t n,
        const float* xq,
        idx_t k,
        float* distances,
        idx_t* labels,
        const float* norms) const {
    std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
    compute_LUT(n, xq, LUT.get());
    std::unique_ptr<float[]> q_norms(new float[n]);
    fvec_norms_L2sqr(q_norms.get(), xq, d, n);
    size_t ntotal = (size_t)1 << tot_bits;

// parallelize over queries only when there are enough of them
#pragma omp parallel if (n > 100)
    {
        // per-thread buffer of inner products with all ntotal codes
        std::vector<float> dis(ntotal);
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const float* LUTi = LUT.get() + i * total_codebook_size;
            float* distances_i = distances + i * k;
            idx_t* labels_i = labels + i * k;

            compute_inner_prod_with_LUT(*this, LUTi, dis.data());

            // update distances using
            // ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x,y>

            // keep the k smallest distances with a max-heap
            maxheap_heapify(k, distances_i, labels_i);
            for (idx_t j = 0; j < ntotal; j++) {
                float disj = q_norms[i] + norms[j] - 2 * dis[j];
                if (disj < distances_i[0]) {
                    heap_replace_top<CMax<float, int64_t>>(
                            k, distances_i, labels_i, disj, j);
                }
            }
            maxheap_reorder(k, distances_i, labels_i);
        }
    }
}
269+
96270} // namespace faiss
0 commit comments