Skip to content

Commit 2b303d9

Browse files
Add IVF changes to support Faiss byte vector (#2002)
* Add HNSW changes to support Faiss byte vector Signed-off-by: Naveen Tatikonda <navtat@amazon.com> * Address Review Comments Signed-off-by: Naveen Tatikonda <navtat@amazon.com> * Add IVF changes to support Faiss byte vector Signed-off-by: Naveen Tatikonda <navtat@amazon.com> * Address Review Comments Signed-off-by: Naveen Tatikonda <navtat@amazon.com> --------- Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
1 parent 9dbe7de commit 2b303d9

21 files changed

+626
-92
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
2020
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
2121
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
22+
* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002)
2223
### Enhancements
2324
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
2425
### Bug Fixes

jni/include/faiss_wrapper.h

+13
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ namespace knn_jni {
3636
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
3737
jobject parametersJ);
3838

39+
// Create a index with ids and byte vectors. Instead of creating a new index, this function creates the index
40+
// based off of the template index passed in. The index is serialized to indexPathJ.
41+
void CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
42+
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
43+
jobject parametersJ);
44+
3945
// Load an index from indexPathJ into memory.
4046
//
4147
// Return a pointer to the loaded index
@@ -110,6 +116,13 @@ namespace knn_jni {
110116
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
111117
jlong trainVectorsPointerJ);
112118

119+
// Create an empty byte index defined by the values in the Java map, parametersJ. Train the index with
120+
// the byte vectors located at trainVectorsPointerJ.
121+
//
122+
// Return the serialized representation
123+
jbyteArray TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
124+
jlong trainVectorsPointerJ);
125+
113126
/*
114127
* Perform a range search with filter against the index located in memory at indexPointerJ.
115128
*

jni/include/org_opensearch_knn_jni_FaissService.h

+16
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
112112
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
113113
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);
114114

115+
/*
116+
* Class: org_opensearch_knn_jni_FaissService
117+
* Method: createByteIndexFromTemplate
118+
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
119+
*/
120+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate
121+
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);
122+
115123
/*
116124
* Class: org_opensearch_knn_jni_FaissService
117125
* Method: loadIndex
@@ -216,6 +224,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
216224
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
217225
(JNIEnv *, jclass, jobject, jint, jlong);
218226

227+
/*
228+
* Class: org_opensearch_knn_jni_FaissService
229+
* Method: trainByteIndex
230+
* Signature: (Ljava/util/Map;IJ)[B
231+
*/
232+
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex
233+
(JNIEnv *, jclass, jobject, jint, jlong);
234+
219235
/*
220236
* Class: org_opensearch_knn_jni_FaissService
221237
* Method: transferVectors

jni/src/faiss_wrapper.cpp

+157
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,96 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
320320
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
321321
}
322322

323+
void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
324+
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
325+
jbyteArray templateIndexJ, jobject parametersJ) {
326+
if (idsJ == nullptr) {
327+
throw std::runtime_error("IDs cannot be null");
328+
}
329+
330+
if (vectorsAddressJ <= 0) {
331+
throw std::runtime_error("VectorsAddress cannot be less than 0");
332+
}
333+
334+
if(dimJ <= 0) {
335+
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
336+
}
337+
338+
if (indexPathJ == nullptr) {
339+
throw std::runtime_error("Index path cannot be null");
340+
}
341+
342+
if (templateIndexJ == nullptr) {
343+
throw std::runtime_error("Template index cannot be null");
344+
}
345+
346+
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
347+
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
348+
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
349+
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
350+
omp_set_num_threads(threadCount);
351+
}
352+
jniUtil->DeleteLocalRef(env, parametersJ);
353+
354+
// Read data set
355+
// Read vectors from memory address
356+
auto *inputVectors = reinterpret_cast<std::vector<int8_t>*>(vectorsAddressJ);
357+
int dim = (int)dimJ;
358+
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
359+
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
360+
361+
if (numIds != numVectors) {
362+
throw std::runtime_error("Number of IDs does not match number of vectors");
363+
}
364+
365+
// Get vector of bytes from jbytearray
366+
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
367+
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);
368+
369+
faiss::VectorIOReader vectorIoReader;
370+
for (int i = 0; i < indexBytesCount; i++) {
371+
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
372+
}
373+
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);
374+
375+
// Create faiss index
376+
std::unique_ptr<faiss::Index> indexWriter;
377+
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));
378+
379+
auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
380+
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
381+
382+
// Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike.
383+
// Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255
384+
int batchSize = 1000;
385+
std::vector <float> inputFloatVectors(batchSize * dim);
386+
std::vector <int64_t> floatVectorsIds(batchSize);
387+
int id = 0;
388+
auto iter = inputVectors->begin();
389+
390+
for (int id = 0; id < numVectors; id += batchSize) {
391+
if (numVectors - id < batchSize) {
392+
batchSize = numVectors - id;
393+
}
394+
395+
for (int i = 0; i < batchSize; ++i) {
396+
floatVectorsIds[i] = ids[id + i];
397+
for (int j = 0; j < dim; ++j, ++iter) {
398+
inputFloatVectors[i * dim + j] = static_cast<float>(*iter);
399+
}
400+
}
401+
idMap.add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data());
402+
}
403+
404+
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
405+
// This is not the ideal approach, please refer this gh issue for long term solution:
406+
// https://github.com/opensearch-project/k-NN/issues/1600
407+
delete inputVectors;
408+
// Write the index to disk
409+
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
410+
faiss::write_index(&idMap, indexPathCpp.c_str());
411+
}
412+
323413
jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
324414
if (indexPathJ == nullptr) {
325415
throw std::runtime_error("Index path cannot be null");
@@ -782,6 +872,73 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
782872
return ret;
783873
}
784874

875+
jbyteArray knn_jni::faiss_wrapper::TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
876+
jint dimensionJ, jlong trainVectorsPointerJ) {
877+
// First, we need to build the index
878+
if (parametersJ == nullptr) {
879+
throw std::runtime_error("Parameters cannot be null");
880+
}
881+
882+
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
883+
884+
jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
885+
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
886+
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);
887+
888+
// Create faiss index
889+
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
890+
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
891+
892+
std::unique_ptr<faiss::Index> indexWriter;
893+
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));
894+
895+
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
896+
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
897+
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
898+
omp_set_num_threads(threadCount);
899+
}
900+
901+
// Add extra parameters that cant be configured with the index factory
902+
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
903+
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
904+
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
905+
SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get());
906+
jniUtil->DeleteLocalRef(env, subParametersJ);
907+
}
908+
909+
// Train index if needed
910+
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<int8_t>*>(trainVectorsPointerJ);
911+
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
912+
913+
auto iter = trainingVectorsPointerCpp->begin();
914+
std::vector <float> trainingFloatVectors(numVectors * dimensionJ);
915+
for(int i=0; i < numVectors * dimensionJ; ++i, ++iter) {
916+
trainingFloatVectors[i] = static_cast<float>(*iter);
917+
}
918+
919+
if(!indexWriter->is_trained) {
920+
InternalTrainIndex(indexWriter.get(), numVectors, trainingFloatVectors.data());
921+
}
922+
jniUtil->DeleteLocalRef(env, parametersJ);
923+
924+
// Now that indexWriter is trained, we just load the bytes into an array and return
925+
faiss::VectorIOWriter vectorIoWriter;
926+
faiss::write_index(indexWriter.get(), &vectorIoWriter);
927+
928+
// Wrap in smart pointer
929+
std::unique_ptr<jbyte[]> jbytesBuffer;
930+
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
931+
int c = 0;
932+
for (auto b : vectorIoWriter.data) {
933+
jbytesBuffer[c++] = (jbyte) b;
934+
}
935+
936+
jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
937+
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
938+
return ret;
939+
}
940+
941+
785942
faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
786943
if (spaceType == knn_jni::L2) {
787944
return faiss::METRIC_L2;

jni/src/org_opensearch_knn_jni_FaissService.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
192192
}
193193
}
194194

195+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate(JNIEnv * env, jclass cls,
196+
jintArray idsJ,
197+
jlong vectorsAddressJ,
198+
jint dimJ,
199+
jstring indexPathJ,
200+
jbyteArray templateIndexJ,
201+
jobject parametersJ)
202+
{
203+
try {
204+
knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
205+
} catch (...) {
206+
jniUtil.CatchCppExceptionAndThrowJava(env);
207+
}
208+
}
209+
195210
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
196211
{
197212
try {
@@ -335,6 +350,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar
335350
return nullptr;
336351
}
337352

353+
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainByteIndex(JNIEnv * env, jclass cls,
354+
jobject parametersJ,
355+
jint dimensionJ,
356+
jlong trainVectorsPointerJ)
357+
{
358+
try {
359+
return knn_jni::faiss_wrapper::TrainByteIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
360+
} catch (...) {
361+
jniUtil.CatchCppExceptionAndThrowJava(env);
362+
}
363+
return nullptr;
364+
}
365+
338366
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls,
339367
jlong vectorsPointerJ,
340368
jobjectArray vectorsJ)

jni/tests/faiss_wrapper_test.cpp

+81
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,55 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
230230
std::remove(indexPath.c_str());
231231
}
232232

233+
TEST(FaissCreateByteIndexFromTemplateTest, BasicAssertions) {
234+
// Define the data
235+
faiss::idx_t numIds = 100;
236+
std::vector<faiss::idx_t> ids;
237+
auto *vectors = new std::vector<int8_t>();
238+
int dim = 8;
239+
vectors->reserve(dim * numIds);
240+
for (int64_t i = 0; i < numIds; ++i) {
241+
ids.push_back(i);
242+
for (int j = 0; j < dim; ++j) {
243+
vectors->push_back(test_util::RandomInt(-128, 127));
244+
}
245+
}
246+
247+
std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
248+
faiss::MetricType metricType = faiss::METRIC_L2;
249+
std::string method = "HNSW32,SQ8_direct_signed";
250+
251+
std::unique_ptr<faiss::Index> createdIndex(
252+
test_util::FaissCreateIndex(dim, method, metricType));
253+
auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get());
254+
255+
// Setup jni
256+
JNIEnv *jniEnv = nullptr;
257+
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
258+
259+
EXPECT_CALL(mockJNIUtil,
260+
GetJavaObjectArrayLength(
261+
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
262+
.WillRepeatedly(Return(vectors->size()));
263+
264+
std::string spaceType = knn_jni::L2;
265+
std::unordered_map<std::string, jobject> parametersMap;
266+
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;
267+
268+
knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(
269+
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
270+
(jlong)vectors, dim, (jstring)&indexPath,
271+
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
272+
(jobject) &parametersMap
273+
);
274+
275+
// Make sure index can be loaded
276+
std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));
277+
278+
// Clean up
279+
std::remove(indexPath.c_str());
280+
}
281+
233282
TEST(FaissLoadIndexTest, BasicAssertions) {
234283
// Define the data
235284
faiss::idx_t numIds = 100;
@@ -717,6 +766,38 @@ TEST(FaissTrainIndexTest, BasicAssertions) {
717766
ASSERT_TRUE(trainedIndex->is_trained);
718767
}
719768

769+
TEST(FaissTrainByteIndexTest, BasicAssertions) {
770+
// Define the index configuration
771+
int dim = 2;
772+
std::string spaceType = knn_jni::L2;
773+
std::string index_description = "IVF4,SQ8_direct_signed";
774+
775+
std::unordered_map<std::string, jobject> parametersMap;
776+
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;
777+
parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject) &index_description;
778+
779+
// Define training data
780+
int numTrainingVectors = 256;
781+
std::vector<int8_t> trainingVectors = test_util::RandomByteVectors(dim, numTrainingVectors, -128, 127);
782+
783+
// Setup jni
784+
JNIEnv *jniEnv = nullptr;
785+
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
786+
787+
// Perform training
788+
std::unique_ptr<std::vector<uint8_t>> trainedIndexSerialization(
789+
reinterpret_cast<std::vector<uint8_t> *>(
790+
knn_jni::faiss_wrapper::TrainByteIndex(
791+
&mockJNIUtil, jniEnv, (jobject) &parametersMap, dim,
792+
reinterpret_cast<jlong>(&trainingVectors))));
793+
794+
std::unique_ptr<faiss::Index> trainedIndex(
795+
test_util::FaissLoadFromSerializedIndex(trainedIndexSerialization.get()));
796+
797+
// Confirm that training succeeded
798+
ASSERT_TRUE(trainedIndex->is_trained);
799+
}
800+
720801
TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
721802
// Define the data
722803
faiss::idx_t numIds = 200;

jni/tests/test_util.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,14 @@ std::vector<float> test_util::RandomVectors(int dim, int64_t numVectors, float m
447447
return vectors;
448448
}
449449

450+
std::vector<int8_t> test_util::RandomByteVectors(int dim, int64_t numVectors, int min, int max) {
451+
std::vector<int8_t> vectors(dim*numVectors);
452+
for (int64_t i = 0; i < dim*numVectors; i++) {
453+
vectors[i] = test_util::RandomInt(min, max);
454+
}
455+
return vectors;
456+
}
457+
450458
std::vector<int64_t> test_util::Range(int64_t numElements) {
451459
std::vector<int64_t> rangeVector(numElements);
452460
for (int64_t i = 0; i < numElements; i++) {

jni/tests/test_util.h

+2
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ namespace test_util {
173173

174174
std::vector<float> RandomVectors(int dim, int64_t numVectors, float min, float max);
175175

176+
std::vector<int8_t> RandomByteVectors(int dim, int64_t numVectors, int min, int max);
177+
176178
std::vector<int64_t> Range(int64_t numElements);
177179

178180
// returns the number of 64 bit words it would take to hold numBits

0 commit comments

Comments
 (0)