@@ -320,6 +320,96 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
320
320
faiss::write_index_binary (&idMap, indexPathCpp.c_str ());
321
321
}
322
322
323
+ void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate (knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
324
+ jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
325
+ jbyteArray templateIndexJ, jobject parametersJ) {
326
+ if (idsJ == nullptr ) {
327
+ throw std::runtime_error (" IDs cannot be null" );
328
+ }
329
+
330
+ if (vectorsAddressJ <= 0 ) {
331
+ throw std::runtime_error (" VectorsAddress cannot be less than 0" );
332
+ }
333
+
334
+ if (dimJ <= 0 ) {
335
+ throw std::runtime_error (" Vectors dimensions cannot be less than or equal to 0" );
336
+ }
337
+
338
+ if (indexPathJ == nullptr ) {
339
+ throw std::runtime_error (" Index path cannot be null" );
340
+ }
341
+
342
+ if (templateIndexJ == nullptr ) {
343
+ throw std::runtime_error (" Template index cannot be null" );
344
+ }
345
+
346
+ // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
347
+ auto parametersCpp = jniUtil->ConvertJavaMapToCppMap (env, parametersJ);
348
+ if (parametersCpp.find (knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end ()) {
349
+ auto threadCount = jniUtil->ConvertJavaObjectToCppInteger (env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
350
+ omp_set_num_threads (threadCount);
351
+ }
352
+ jniUtil->DeleteLocalRef (env, parametersJ);
353
+
354
+ // Read data set
355
+ // Read vectors from memory address
356
+ auto *inputVectors = reinterpret_cast <std::vector<int8_t >*>(vectorsAddressJ);
357
+ int dim = (int )dimJ;
358
+ int numVectors = (int ) (inputVectors->size () / (uint64_t ) dim);
359
+ int numIds = jniUtil->GetJavaIntArrayLength (env, idsJ);
360
+
361
+ if (numIds != numVectors) {
362
+ throw std::runtime_error (" Number of IDs does not match number of vectors" );
363
+ }
364
+
365
+ // Get vector of bytes from jbytearray
366
+ int indexBytesCount = jniUtil->GetJavaBytesArrayLength (env, templateIndexJ);
367
+ jbyte * indexBytesJ = jniUtil->GetByteArrayElements (env, templateIndexJ, nullptr );
368
+
369
+ faiss::VectorIOReader vectorIoReader;
370
+ for (int i = 0 ; i < indexBytesCount; i++) {
371
+ vectorIoReader.data .push_back ((uint8_t ) indexBytesJ[i]);
372
+ }
373
+ jniUtil->ReleaseByteArrayElements (env, templateIndexJ, indexBytesJ, JNI_ABORT);
374
+
375
+ // Create faiss index
376
+ std::unique_ptr<faiss::Index> indexWriter;
377
+ indexWriter.reset (faiss::read_index (&vectorIoReader, 0 ));
378
+
379
+ auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector (env, idsJ);
380
+ faiss::IndexIDMap idMap = faiss::IndexIDMap (indexWriter.get ());
381
+
382
+ // Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike.
383
+ // Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255
384
+ int batchSize = 1000 ;
385
+ std::vector <float > inputFloatVectors (batchSize * dim);
386
+ std::vector <int64_t > floatVectorsIds (batchSize);
387
+ int id = 0 ;
388
+ auto iter = inputVectors->begin ();
389
+
390
+ for (int id = 0 ; id < numVectors; id += batchSize) {
391
+ if (numVectors - id < batchSize) {
392
+ batchSize = numVectors - id;
393
+ }
394
+
395
+ for (int i = 0 ; i < batchSize; ++i) {
396
+ floatVectorsIds[i] = ids[id + i];
397
+ for (int j = 0 ; j < dim; ++j, ++iter) {
398
+ inputFloatVectors[i * dim + j] = static_cast <float >(*iter);
399
+ }
400
+ }
401
+ idMap.add_with_ids (batchSize, inputFloatVectors.data (), floatVectorsIds.data ());
402
+ }
403
+
404
+ // Releasing the vectorsAddressJ memory as that is not required once we have created the index.
405
+ // This is not the ideal approach, please refer this gh issue for long term solution:
406
+ // https://github.com/opensearch-project/k-NN/issues/1600
407
+ delete inputVectors;
408
+ // Write the index to disk
409
+ std::string indexPathCpp (jniUtil->ConvertJavaStringToCppString (env, indexPathJ));
410
+ faiss::write_index (&idMap, indexPathCpp.c_str ());
411
+ }
412
+
323
413
jlong knn_jni::faiss_wrapper::LoadIndex (knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
324
414
if (indexPathJ == nullptr ) {
325
415
throw std::runtime_error (" Index path cannot be null" );
@@ -782,6 +872,73 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
782
872
return ret;
783
873
}
784
874
875
+ jbyteArray knn_jni::faiss_wrapper::TrainByteIndex (knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
876
+ jint dimensionJ, jlong trainVectorsPointerJ) {
877
+ // First, we need to build the index
878
+ if (parametersJ == nullptr ) {
879
+ throw std::runtime_error (" Parameters cannot be null" );
880
+ }
881
+
882
+ auto parametersCpp = jniUtil->ConvertJavaMapToCppMap (env, parametersJ);
883
+
884
+ jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow (parametersCpp, knn_jni::SPACE_TYPE);
885
+ std::string spaceTypeCpp (jniUtil->ConvertJavaObjectToCppString (env, spaceTypeJ));
886
+ faiss::MetricType metric = TranslateSpaceToMetric (spaceTypeCpp);
887
+
888
+ // Create faiss index
889
+ jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow (parametersCpp, knn_jni::INDEX_DESCRIPTION);
890
+ std::string indexDescriptionCpp (jniUtil->ConvertJavaObjectToCppString (env, indexDescriptionJ));
891
+
892
+ std::unique_ptr<faiss::Index> indexWriter;
893
+ indexWriter.reset (faiss::index_factory ((int ) dimensionJ, indexDescriptionCpp.c_str (), metric));
894
+
895
+ // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
896
+ if (parametersCpp.find (knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end ()) {
897
+ auto threadCount = jniUtil->ConvertJavaObjectToCppInteger (env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
898
+ omp_set_num_threads (threadCount);
899
+ }
900
+
901
+ // Add extra parameters that cant be configured with the index factory
902
+ if (parametersCpp.find (knn_jni::PARAMETERS) != parametersCpp.end ()) {
903
+ jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
904
+ auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap (env, subParametersJ);
905
+ SetExtraParameters (jniUtil, env, subParametersCpp, indexWriter.get ());
906
+ jniUtil->DeleteLocalRef (env, subParametersJ);
907
+ }
908
+
909
+ // Train index if needed
910
+ auto *trainingVectorsPointerCpp = reinterpret_cast <std::vector<int8_t >*>(trainVectorsPointerJ);
911
+ int numVectors = trainingVectorsPointerCpp->size ()/(int ) dimensionJ;
912
+
913
+ auto iter = trainingVectorsPointerCpp->begin ();
914
+ std::vector <float > trainingFloatVectors (numVectors * dimensionJ);
915
+ for (int i=0 ; i < numVectors * dimensionJ; ++i, ++iter) {
916
+ trainingFloatVectors[i] = static_cast <float >(*iter);
917
+ }
918
+
919
+ if (!indexWriter->is_trained ) {
920
+ InternalTrainIndex (indexWriter.get (), numVectors, trainingFloatVectors.data ());
921
+ }
922
+ jniUtil->DeleteLocalRef (env, parametersJ);
923
+
924
+ // Now that indexWriter is trained, we just load the bytes into an array and return
925
+ faiss::VectorIOWriter vectorIoWriter;
926
+ faiss::write_index (indexWriter.get (), &vectorIoWriter);
927
+
928
+ // Wrap in smart pointer
929
+ std::unique_ptr<jbyte[]> jbytesBuffer;
930
+ jbytesBuffer.reset (new jbyte[vectorIoWriter.data .size ()]);
931
+ int c = 0 ;
932
+ for (auto b : vectorIoWriter.data ) {
933
+ jbytesBuffer[c++] = (jbyte) b;
934
+ }
935
+
936
+ jbyteArray ret = jniUtil->NewByteArray (env, vectorIoWriter.data .size ());
937
+ jniUtil->SetByteArrayRegion (env, ret, 0 , vectorIoWriter.data .size (), jbytesBuffer.get ());
938
+ return ret;
939
+ }
940
+
941
+
785
942
faiss::MetricType TranslateSpaceToMetric (const std::string& spaceType) {
786
943
if (spaceType == knn_jni::L2) {
787
944
return faiss::METRIC_L2;
0 commit comments