Skip to content

Commit bbaaaf9

Browse files
authored
Integration of Quantization Framework for Binary Quantization with Indexing Flow (#1996)
* Integration of Quantization Framework for Binary Quantization with Indexing Flow Signed-off-by: VIKASH TIWARI <viktari@amazon.com> * Integration With Quantization Config Signed-off-by: VIKASH TIWARI <viktari@amazon.com> --------- Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
1 parent 59c312b commit bbaaaf9

31 files changed

+1414
-173
lines changed

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java

+108-24
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
import org.apache.lucene.index.Sorter;
2525
import org.apache.lucene.util.IOUtils;
2626
import org.apache.lucene.util.RamUsageEstimator;
27+
import org.opensearch.knn.index.quantizationService.QuantizationService;
2728
import org.opensearch.knn.index.VectorDataType;
2829
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
2930
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
3031
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
32+
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
33+
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
3134

3235
import java.io.IOException;
3336
import java.util.ArrayList;
@@ -46,6 +49,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
4649
private final FlatVectorsWriter flatVectorsWriter;
4750
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
4851
private boolean finished;
52+
private final QuantizationService quantizationService = QuantizationService.getInstance();
4953

5054
/**
5155
* Add new field for indexing.
@@ -68,42 +72,24 @@ public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOExc
6872
*/
6973
@Override
7074
public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
71-
// simply write data in the flat file
7275
flatVectorsWriter.flush(maxDoc, sortMap);
7376
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
74-
final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo());
75-
final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
76-
vectorDataType,
77-
field.getDocsWithField(),
78-
field.getVectors()
77+
trainAndIndex(
78+
field.getFieldInfo(),
79+
(vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter),
80+
NativeIndexWriter::flushIndex,
81+
field
7982
);
80-
81-
NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues);
8283
}
8384
}
8485

8586
@Override
8687
public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException {
8788
// This will ensure that we are merging the FlatIndex during force merge.
8889
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
89-
9090
// For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs
91-
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
92-
final KNNVectorValues<?> knnVectorValues;
93-
switch (fieldInfo.getVectorEncoding()) {
94-
case FLOAT32:
95-
final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
96-
knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats);
97-
break;
98-
case BYTE:
99-
final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
100-
knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes);
101-
break;
102-
default:
103-
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
104-
}
91+
trainAndIndex(fieldInfo, this::getKNNVectorValuesForMerge, NativeIndexWriter::mergeIndex, mergeState);
10592

106-
NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues);
10793
}
10894

10995
/**
@@ -146,4 +132,102 @@ public long ramBytesUsed() {
146132
.sum();
147133
}
148134

135+
/**
136+
* Builds the {@link KNNVectorValues} over the vectors held by a field writer; this is the retriever used on the flush path.
137+
*
138+
* @param vectorDataType The {@link VectorDataType} passed through to {@link KNNVectorValuesFactory}.
139+
* @param field The {@link NativeEngineFieldVectorsWriter} whose {@code getDocsWithField()} and {@code getVectors()} supply the data.
140+
* @param <T> The vector type expected by the caller; the factory result is cast to it unchecked.
141+
* @return The {@link KNNVectorValues} created by the factory for this field.
142+
*/
143+
private <T> KNNVectorValues<T> getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter<?> field) {
144+
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
145+
}
146+
147+
/**
148+
* Builds the {@link KNNVectorValues} for a field during a merge by merging the per-segment float or byte vector values.
149+
* Only FLOAT32 and BYTE encodings are handled; any other encoding raises an {@link IllegalStateException}.
150+
* @param vectorDataType The {@link VectorDataType} passed through to {@link KNNVectorValuesFactory}.
151+
* @param fieldInfo The {@link FieldInfo} whose vector encoding selects the float or byte merge path.
152+
* @param mergeState The {@link MergeState} handed to {@code MergedVectorValues} to produce the merged values.
153+
* @param <T> The vector type expected by the caller; the factory result is cast to it unchecked.
154+
* @return The {@link KNNVectorValues} wrapping the merged vector values of the field.
155+
* @throws IOException If an I/O error occurs while merging the vector values.
156+
*/
157+
private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
158+
final VectorDataType vectorDataType,
159+
final FieldInfo fieldInfo,
160+
final MergeState mergeState
161+
) throws IOException {
162+
switch (fieldInfo.getVectorEncoding()) {
163+
case FLOAT32:
164+
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
165+
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats);
166+
case BYTE:
167+
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
168+
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes);
169+
default:
170+
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
171+
}
172+
}
173+
174+
/**
175+
* Operation that builds and writes an index from {@link KNNVectorValues} through a {@link NativeIndexWriter}.
176+
* In this class it is bound to {@code NativeIndexWriter::flushIndex} on flush and {@code NativeIndexWriter::mergeIndex} on merge.
177+
* @param <T> The type of vectors being processed.
178+
*/
179+
@FunctionalInterface
180+
private interface IndexOperation<T> {
181+
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues) throws IOException;
182+
}
183+
184+
/**
185+
* Functional interface representing a method that retrieves {@link KNNVectorValues} based on
186+
* the vector data type, field information, and a caller-supplied context (merge state or field writer).
187+
*
188+
* @param <D> The type describing the vector data (e.g., {@link VectorDataType}).
189+
* @param <F> The type of the field metadata (e.g., Lucene's FieldInfo); single letters avoid shadowing that class.
190+
* @param <C> The type of the retrieval context (e.g., Lucene's MergeState on merge, the field writer on flush).
191+
* @param <R> The result of the retrieval, typically {@link KNNVectorValues}.
192+
*/
193+
@FunctionalInterface
194+
private interface VectorValuesRetriever<D, F, C, R> {
195+
R apply(D vectorDataType, F fieldInfo, C context) throws IOException;
196+
}
197+
198+
/**
199+
* Unified method for processing a field during either the flush or merge operation. When the field has quantization
200+
* params, a quantization state is trained on a dedicated set of vector values; the index is then built from a fresh set.
201+
*
202+
* @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
203+
* @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type,
204+
* field information, and additional context (e.g., merge state or field writer).
205+
* @param indexOperation A functional interface that performs the indexing operation using the retrieved
206+
* {@link KNNVectorValues}.
207+
* @param vectorProcessingContext The additional context required for retrieving the vector values:
208+
* the {@link NativeEngineFieldVectorsWriter} on flush, or the {@link MergeState} on merge.
209+
* @param <T> The type of vectors being processed.
210+
* @param <C> The type of the context needed for retrieving the vector values.
211+
* @throws IOException If an I/O error occurs during the processing.
212+
*/
213+
private <T, C> void trainAndIndex(
214+
final FieldInfo fieldInfo,
215+
final VectorValuesRetriever<VectorDataType, FieldInfo, C, KNNVectorValues<T>> vectorValuesRetriever,
216+
final IndexOperation<T> indexOperation,
217+
final C vectorProcessingContext
218+
) throws IOException {
219+
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
220+
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
221+
final NativeIndexWriter writer;
222+
if (quantizationParams != null) {
223+
final KNNVectorValues<T> trainingValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, vectorProcessingContext);
224+
final QuantizationState quantizationState = quantizationService.train(quantizationParams, trainingValues);
225+
writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
226+
} else {
227+
writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);
228+
}
229+
230+
final KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, vectorProcessingContext);
231+
indexOperation.buildAndWrite(writer, knnVectorValues);
232+
}
149233
}

src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java

+27-11
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,32 @@ public static DefaultIndexBuildStrategy getInstance() {
3939
return INSTANCE;
4040
}
4141

42+
/**
43+
* Builds and writes a k-NN index using the provided vector values and index parameters. This method handles both
44+
* quantized and non-quantized vectors, transferring them off-heap before building the index using native JNI services.
45+
*
46+
* <p>The method first iterates over the vector values to calculate the necessary bytes per vector. If quantization is
47+
* enabled, the vectors are quantized before being transferred off-heap. Once all vectors are transferred, they are
48+
* flushed and used to build the index. The index is then written to the specified path using JNI calls.</p>
49+
*
50+
* @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index.
51+
* @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed.
52+
* @throws IOException If an I/O error occurs during the process of building and writing the index.
53+
*/
4254
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
43-
iterateVectorValuesOnce(knnVectorValues); // to get bytesPerVector
44-
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector());
55+
// Needed to make sure we don't get 0 dimensions while initializing index
56+
iterateVectorValuesOnce(knnVectorValues);
57+
IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo);
58+
59+
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
4560
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
61+
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());
4662

47-
final List<Integer> tranferredDocIds = new ArrayList<>();
4863
while (knnVectorValues.docId() != NO_MORE_DOCS) {
64+
Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup);
4965
// append is true here so off heap memory buffer isn't overwritten
50-
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
51-
tranferredDocIds.add(knnVectorValues.docId());
66+
vectorTransfer.transfer(vector, true);
67+
transferredDocIds.add(knnVectorValues.docId());
5268
knnVectorValues.nextDoc();
5369
}
5470
vectorTransfer.flush(true);
@@ -60,24 +76,24 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
6076
if (params.containsKey(MODEL_ID)) {
6177
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
6278
JNIService.createIndexFromTemplate(
63-
intListToArray(tranferredDocIds),
79+
intListToArray(transferredDocIds),
6480
vectorAddress,
65-
knnVectorValues.dimension(),
81+
indexBuildSetup.getDimensions(),
6682
indexInfo.getIndexPath(),
6783
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
68-
indexInfo.getParameters(),
84+
params,
6985
indexInfo.getKnnEngine()
7086
);
7187
return null;
7288
});
7389
} else {
7490
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
7591
JNIService.createIndex(
76-
intListToArray(tranferredDocIds),
92+
intListToArray(transferredDocIds),
7793
vectorAddress,
78-
knnVectorValues.dimension(),
94+
indexBuildSetup.getDimensions(),
7995
indexInfo.getIndexPath(),
80-
indexInfo.getParameters(),
96+
params,
8197
indexInfo.getKnnEngine()
8298
);
8399
return null;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.codec.nativeindex;
7+
8+
import lombok.AllArgsConstructor;
9+
import lombok.Getter;
10+
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
11+
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
12+
13+
/**
14+
* IndexBuildSetup encapsulates the configuration and parameters required for building an index.
15+
* This includes the size of each vector, the dimensions of the vectors, and any quantization-related
16+
* settings such as the output and state of quantization. Lombok generates the all-args constructor and the getters.
17+
*/
18+
@Getter
19+
@AllArgsConstructor
20+
final class IndexBuildSetup {
21+
/**
22+
* The number of bytes per vector; the build strategy divides the vector streaming memory limit by it to size transfers.
23+
*/
24+
private final int bytesPerVector;
25+
26+
/**
27+
* The vector dimension that the build strategy passes to the native index-creation call.
28+
*/
29+
private final int dimensions;
30+
31+
/**
32+
* The quantization output that will hold the quantized vector.
33+
*/
34+
private final QuantizationOutput quantizationOutput;
35+
36+
/**
37+
* The state of quantization, which may include parameters and trained models.
38+
*/
39+
private final QuantizationState quantizationState;
40+
}

0 commit comments

Comments
 (0)