Skip to content

Commit 8eab15e

Browse files
mdouzefacebook-github-bot
authored andcommitted
LUT based search for additive quantizers (#1908)
Summary: Pull Request resolved: #1908 To search the best combination of codebooks, the method that was implemented so far is via a beam search. It is possible to make this faster for a query vector q by precomputing look-up tables in the form of LUT_m = <q, cent_m> where cent_m is the set of centroids for quantizer m=0..M-1. The LUT can then be used as inner_prod = sum_m LUT_m[c_m] and L2_distance = norm_q + norm_db - 2 * inner_prod This diff implements this computation by: - adding the LUT precomputation - storing an exhaustive table of all centroid norms (when using L2) This is only practical for small additive quantizers, eg. when a residual vector quantizer is used as coarse quantizer (ResidualCoarseQuantizer). This diff is based on AdditiveQuantizer diff because it applies equally to other quantizers (eg. the LSQ). Reviewed By: sc268 Differential Revision: D28467746 fbshipit-source-id: 82611fe1e4908c290204d4de866338c622ae4148
1 parent 0825eaf commit 8eab15e

File tree

6 files changed

+341
-19
lines changed

6 files changed

+341
-19
lines changed

faiss/IndexResidual.cpp

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
*/
77

88
#include <faiss/IndexResidual.h>
9-
#include "faiss/utils/utils.h"
109

1110
#include <algorithm>
1211
#include <cmath>
@@ -198,13 +197,35 @@ void ResidualCoarseQuantizer::add(idx_t, const float*) {
198197
FAISS_THROW_MSG("not applicable");
199198
}
200199

200+
void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
201+
centroid_norms.resize(0);
202+
beam_factor = new_beam_factor;
203+
if (new_beam_factor > 0) {
204+
FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
205+
return;
206+
}
207+
208+
if (metric_type == METRIC_L2) {
209+
centroid_norms.resize((size_t)1 << rq.tot_bits);
210+
rq.compute_centroid_norms(centroid_norms.data());
211+
}
212+
}
213+
201214
void ResidualCoarseQuantizer::search(
202215
idx_t n,
203216
const float* x,
204217
idx_t k,
205218
float* distances,
206219
idx_t* labels) const {
207-
FAISS_THROW_IF_NOT(beam_factor >= 1.0);
220+
if (beam_factor < 0) {
221+
if (metric_type == METRIC_INNER_PRODUCT) {
222+
rq.knn_exact_inner_product(n, x, k, distances, labels);
223+
} else if (metric_type == METRIC_L2) {
224+
FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
225+
rq.knn_exact_L2(n, x, k, distances, labels, centroid_norms.data());
226+
}
227+
return;
228+
}
208229

209230
int beam_size = int(k * beam_factor);
210231

@@ -249,28 +270,18 @@ void ResidualCoarseQuantizer::search(
249270
const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
250271
for (idx_t j = 0; j < k; j++) {
251272
idx_t l = 0;
273+
int shift = 0;
252274
for (int m = 0; m < rq.M; m++) {
253-
l = (l << rq.nbits[m]) | *codes_i++;
275+
l |= (*codes_i++) << shift;
276+
shift += rq.nbits[m];
254277
}
255278
labels[i * k + j] = l;
256279
}
257280
}
258281
}
259282

260283
void ResidualCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
261-
for (int m = 0; m < rq.M; m++) {
262-
int nbits = rq.nbits[m];
263-
idx_t l = key & ((idx_t(1) << nbits) - 1);
264-
key = key >> nbits;
265-
const float* c = rq.codebooks.data() + d * (rq.codebook_offsets[m] + l);
266-
if (m == 0) {
267-
memcpy(recons, c, sizeof(*c) * d);
268-
} else {
269-
for (int i = 0; i < d; i++) {
270-
recons[i] += c[i];
271-
}
272-
}
273-
}
284+
rq.decode_64bit(key, recons);
274285
}
275286

276287
void ResidualCoarseQuantizer::reset() {

faiss/IndexResidual.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,15 @@ struct ResidualCoarseQuantizer : Index {
101101
ResidualQuantizer rq;
102102

103103
/// factor between the beam size and the search k
104+
/// if negative, use exact search-to-centroid
104105
float beam_factor;
105106

107+
/// norms of centroids, useful for knn-search
108+
std::vector<float> centroid_norms;
109+
110+
/// computes centroid norms if required
111+
void set_beam_factor(float new_beam_factor);
112+
106113
/** Constructor.
107114
*
108115
* @param d dimensionality of the input vectors

faiss/impl/AdditiveQuantizer.cpp

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,30 @@
1818

1919
#include <algorithm>
2020

21+
#include <faiss/utils/Heap.h>
2122
#include <faiss/utils/distances.h>
2223
#include <faiss/utils/hamming.h> // BitstringWriter
2324
#include <faiss/utils/utils.h>
2425

26+
extern "C" {
27+
28+
// general matrix multiplication
29+
int sgemm_(
30+
const char* transa,
31+
const char* transb,
32+
FINTEGER* m,
33+
FINTEGER* n,
34+
FINTEGER* k,
35+
const float* alpha,
36+
const float* a,
37+
FINTEGER* lda,
38+
const float* b,
39+
FINTEGER* ldb,
40+
float* beta,
41+
float* c,
42+
FINTEGER* ldc);
43+
}
44+
2545
namespace {
2646

2747
// c and a and b can overlap
@@ -31,6 +51,12 @@ void fvec_add(size_t d, const float* a, const float* b, float* c) {
3151
}
3252
}
3353

54+
void fvec_add(size_t d, const float* a, float b, float* c) {
55+
for (size_t i = 0; i < d; i++) {
56+
c[i] = a[i] + b;
57+
}
58+
}
59+
3460
} // namespace
3561

3662
namespace faiss {
@@ -48,6 +74,7 @@ void AdditiveQuantizer::set_derived_values() {
4874
is_byte_aligned = false;
4975
}
5076
}
77+
total_codebook_size = codebook_offsets[M];
5178
// convert bits to bytes
5279
code_size = (tot_bits + 7) / 8;
5380
}
@@ -93,4 +120,151 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
93120

94121
AdditiveQuantizer::~AdditiveQuantizer() {}
95122

123+
/****************************************************************************
124+
* Support for fast distance computations and search with additive quantizer
125+
****************************************************************************/
126+
127+
void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
128+
size_t ntotal = (size_t)1 << tot_bits;
129+
// TODO: make tree of partial sums
130+
#pragma omp parallel
131+
{
132+
std::vector<float> tmp(d);
133+
#pragma omp for
134+
for (int64_t i = 0; i < ntotal; i++) {
135+
decode_64bit(i, tmp.data());
136+
norms[i] = fvec_norm_L2sqr(tmp.data(), d);
137+
}
138+
}
139+
}
140+
141+
void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
142+
for (int m = 0; m < M; m++) {
143+
idx_t idx = bits & (((size_t)1 << nbits[m]) - 1);
144+
bits >>= nbits[m];
145+
const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
146+
if (m == 0) {
147+
memcpy(xi, c, sizeof(*xi) * d);
148+
} else {
149+
fvec_add(d, xi, c, xi);
150+
}
151+
}
152+
}
153+
154+
void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
155+
const {
156+
// in all cases, it is large matrix multiplication
157+
158+
FINTEGER ncenti = total_codebook_size;
159+
FINTEGER di = d;
160+
FINTEGER nqi = n;
161+
float one = 1, zero = 0;
162+
163+
sgemm_("Transposed",
164+
"Not transposed",
165+
&ncenti,
166+
&nqi,
167+
&di,
168+
&one,
169+
codebooks.data(),
170+
&di,
171+
xq,
172+
&di,
173+
&zero,
174+
LUT,
175+
&ncenti);
176+
}
177+
178+
namespace {
179+
180+
void compute_inner_prod_with_LUT(
181+
const AdditiveQuantizer& aq,
182+
const float* LUT,
183+
float* ips) {
184+
size_t prev_size = 1;
185+
for (int m = 0; m < aq.M; m++) {
186+
const float* LUTm = LUT + aq.codebook_offsets[m];
187+
int nb = aq.nbits[m];
188+
size_t nc = (size_t)1 << nb;
189+
190+
if (m == 0) {
191+
memcpy(ips, LUT, sizeof(*ips) * nc);
192+
} else {
193+
for (int64_t i = nc - 1; i >= 0; i--) {
194+
float v = LUTm[i];
195+
fvec_add(prev_size, ips, v, ips + i * prev_size);
196+
}
197+
}
198+
prev_size *= nc;
199+
}
200+
}
201+
202+
} // anonymous namespace
203+
204+
void AdditiveQuantizer::knn_exact_inner_product(
205+
idx_t n,
206+
const float* xq,
207+
idx_t k,
208+
float* distances,
209+
idx_t* labels) const {
210+
std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
211+
compute_LUT(n, xq, LUT.get());
212+
size_t ntotal = (size_t)1 << tot_bits;
213+
214+
#pragma omp parallel if (n > 100)
215+
{
216+
std::vector<float> dis(ntotal);
217+
#pragma omp for
218+
for (idx_t i = 0; i < n; i++) {
219+
const float* LUTi = LUT.get() + i * total_codebook_size;
220+
compute_inner_prod_with_LUT(*this, LUTi, dis.data());
221+
float* distances_i = distances + i * k;
222+
idx_t* labels_i = labels + i * k;
223+
minheap_heapify(k, distances_i, labels_i);
224+
minheap_addn(k, distances_i, labels_i, dis.data(), nullptr, ntotal);
225+
minheap_reorder(k, distances_i, labels_i);
226+
}
227+
}
228+
}
229+
230+
void AdditiveQuantizer::knn_exact_L2(
231+
idx_t n,
232+
const float* xq,
233+
idx_t k,
234+
float* distances,
235+
idx_t* labels,
236+
const float* norms) const {
237+
std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
238+
compute_LUT(n, xq, LUT.get());
239+
std::unique_ptr<float[]> q_norms(new float[n]);
240+
fvec_norms_L2sqr(q_norms.get(), xq, d, n);
241+
size_t ntotal = (size_t)1 << tot_bits;
242+
243+
#pragma omp parallel if (n > 100)
244+
{
245+
std::vector<float> dis(ntotal);
246+
#pragma omp for
247+
for (idx_t i = 0; i < n; i++) {
248+
const float* LUTi = LUT.get() + i * total_codebook_size;
249+
float* distances_i = distances + i * k;
250+
idx_t* labels_i = labels + i * k;
251+
252+
compute_inner_prod_with_LUT(*this, LUTi, dis.data());
253+
254+
// update distances using
255+
// ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x,y>
256+
257+
maxheap_heapify(k, distances_i, labels_i);
258+
for (idx_t j = 0; j < ntotal; j++) {
259+
float disj = q_norms[i] + norms[j] - 2 * dis[j];
260+
if (disj < distances_i[0]) {
261+
heap_replace_top<CMax<float, int64_t>>(
262+
k, distances_i, labels_i, disj, j);
263+
}
264+
}
265+
maxheap_reorder(k, distances_i, labels_i);
266+
}
267+
}
268+
}
269+
96270
} // namespace faiss

faiss/impl/AdditiveQuantizer.h

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <cstdint>
1111
#include <vector>
1212

13+
#include <faiss/Index.h>
14+
1315
namespace faiss {
1416

1517
/** Abstract structure for additive quantizers
@@ -26,8 +28,9 @@ struct AdditiveQuantizer {
2628

2729
// derived values
2830
std::vector<size_t> codebook_offsets;
29-
size_t code_size; ///< code size in bytes
30-
size_t tot_bits; ///< total number of bits
31+
size_t code_size; ///< code size in bytes
32+
size_t tot_bits; ///< total number of bits
33+
size_t total_codebook_size; ///< size of the codebook in vectors
3134
bool is_byte_aligned;
3235

3336
bool verbose; ///< verbose during training?
@@ -66,6 +69,46 @@ struct AdditiveQuantizer {
6669
*/
6770
void decode(const uint8_t* codes, float* x, size_t n) const;
6871

72+
/****************************************************************************
73+
* Support for exhaustive distance computations with the centroids.
74+
* Hence, the number of elements that can be enumerated is not too large.
75+
****************************************************************************/
76+
using idx_t = Index::idx_t;
77+
78+
/// decoding function for a code in a 64-bit word
79+
void decode_64bit(idx_t n, float* x) const;
80+
81+
/** Compute inner-product look-up tables. Used in the centroid search
82+
* functions.
83+
*
84+
* @param xq query vector, size (n, d)
85+
* @param LUT look-up table, size (n, total_codebook_size)
86+
*/
87+
void compute_LUT(size_t n, const float* xq, float* LUT) const;
88+
89+
/// exact IP search
90+
void knn_exact_inner_product(
91+
idx_t n,
92+
const float* xq,
93+
idx_t k,
94+
float* distances,
95+
idx_t* labels) const;
96+
97+
/** For L2 search we need the L2 norms of the centroids
98+
*
99+
* @param norms output norms table, size total_codebook_size
100+
*/
101+
void compute_centroid_norms(float* norms) const;
102+
103+
/** Exact L2 search, with precomputed norms */
104+
void knn_exact_L2(
105+
idx_t n,
106+
const float* xq,
107+
idx_t k,
108+
float* distances,
109+
idx_t* labels,
110+
const float* centroid_norms) const;
111+
69112
virtual ~AdditiveQuantizer();
70113
};
71114

faiss/impl/index_read.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ Index* read_index(IOReader* f, int io_flags) {
490490
read_index_header(idxr, f);
491491
read_ResidualQuantizer(&idxr->rq, f);
492492
READ1(idxr->beam_factor);
493+
idxr->set_beam_factor(idxr->beam_factor);
493494
idx = idxr;
494495
} else if (h == fourcc("IvFl") || h == fourcc("IvFL")) { // legacy
495496
IndexIVFFlat* ivfl = new IndexIVFFlat();

0 commit comments

Comments
 (0)