24
24
import org .apache .lucene .index .Sorter ;
25
25
import org .apache .lucene .util .IOUtils ;
26
26
import org .apache .lucene .util .RamUsageEstimator ;
27
+ import org .opensearch .knn .index .quantizationService .QuantizationService ;
27
28
import org .opensearch .knn .index .VectorDataType ;
28
29
import org .opensearch .knn .index .codec .nativeindex .NativeIndexWriter ;
29
30
import org .opensearch .knn .index .vectorvalues .KNNVectorValues ;
30
31
import org .opensearch .knn .index .vectorvalues .KNNVectorValuesFactory ;
32
+ import org .opensearch .knn .quantization .models .quantizationParams .QuantizationParams ;
33
+ import org .opensearch .knn .quantization .models .quantizationState .QuantizationState ;
31
34
32
35
import java .io .IOException ;
33
36
import java .util .ArrayList ;
@@ -46,6 +49,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
46
49
// Flat-vector writer; flush() and mergeOneField() call it before doing any native index work.
private final FlatVectorsWriter flatVectorsWriter ;
47
50
// Per-field writers; flush() iterates this list and indexes each entry.
private final List <NativeEngineFieldVectorsWriter <?>> fields = new ArrayList <>();
48
51
// NOTE(review): not read or written in the visible hunks - presumably a finish/close guard; confirm.
private boolean finished ;
52
// Supplies the per-field quantization params and runs training; used by trainAndIndex().
+ private final QuantizationService quantizationService = QuantizationService .getInstance ();
49
53
50
54
/**
51
55
* Add new field for indexing.
@@ -68,42 +72,24 @@ public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOExc
68
72
*/
69
73
// Flushes the flat vectors first, then runs trainAndIndex() once per buffered field writer.
// NOTE(review): lines prefixed with "-" are the pre-change implementation shown by the diff view;
// lines prefixed with "+" are its replacement.
@ Override
70
74
public void flush (int maxDoc , final Sorter .DocMap sortMap ) throws IOException {
71
- // simply write data in the flat file
72
75
// Delegate the flat-vector flush before building any native index.
flatVectorsWriter .flush (maxDoc , sortMap );
73
76
// Each field supplies its own FieldInfo and acts as the retrieval context for its buffered vectors;
// the index is written through NativeIndexWriter::flushIndex.
for (final NativeEngineFieldVectorsWriter <?> field : fields ) {
74
- final VectorDataType vectorDataType = extractVectorDataType ( field . getFieldInfo ());
75
- final KNNVectorValues <?> knnVectorValues = KNNVectorValuesFactory . getVectorValues (
76
- vectorDataType ,
77
- field . getDocsWithField () ,
78
- field . getVectors ()
77
+ trainAndIndex (
78
+ field . getFieldInfo (),
79
+ ( vectorDataType , fieldInfo , fieldVectorsWriter ) -> getKNNVectorValues ( vectorDataType , fieldVectorsWriter ) ,
80
+ NativeIndexWriter :: flushIndex ,
81
+ field
79
82
);
80
-
81
- NativeIndexWriter .getWriter (field .getFieldInfo (), segmentWriteState ).flushIndex (knnVectorValues );
82
83
}
83
84
}
84
85
85
86
// Merges one field: the flat vectors are merged first, then trainAndIndex() rebuilds the native index
// from the merged values.
// NOTE(review): lines prefixed with "-" are the pre-change implementation shown by the diff view;
// lines prefixed with "+" are its replacement.
@ Override
86
87
public void mergeOneField (final FieldInfo fieldInfo , final MergeState mergeState ) throws IOException {
87
88
// This ensures that the FlatIndex is merged during force merge.
88
89
flatVectorsWriter .mergeOneField (fieldInfo , mergeState );
89
-
90
90
// For merge, read the merged vector values for this field and index them again; the index is written through
// NativeIndexWriter::mergeIndex.
91
- final VectorDataType vectorDataType = extractVectorDataType (fieldInfo );
92
- final KNNVectorValues <?> knnVectorValues ;
93
- switch (fieldInfo .getVectorEncoding ()) {
94
- case FLOAT32 :
95
- final FloatVectorValues mergedFloats = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
96
- knnVectorValues = KNNVectorValuesFactory .getVectorValues (vectorDataType , mergedFloats );
97
- break ;
98
- case BYTE :
99
- final ByteVectorValues mergedBytes = MergedVectorValues .mergeByteVectorValues (fieldInfo , mergeState );
100
- knnVectorValues = KNNVectorValuesFactory .getVectorValues (vectorDataType , mergedBytes );
101
- break ;
102
- default :
103
- throw new IllegalStateException ("Unsupported vector encoding [" + fieldInfo .getVectorEncoding () + "]" );
104
- }
91
+ trainAndIndex (fieldInfo , this ::getKNNVectorValuesForMerge , NativeIndexWriter ::mergeIndex , mergeState );
105
92
106
- NativeIndexWriter .getWriter (fieldInfo , segmentWriteState ).mergeIndex (knnVectorValues );
107
93
}
108
94
109
95
/**
@@ -146,4 +132,102 @@ public long ramBytesUsed() {
146
132
.sum ();
147
133
}
148
134
135
+ /**
136
+ * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer.
137
+ *
138
+ * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
139
+ * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors.
140
+ * @param <T> The type of vectors being processed.
141
+ * @return The {@link KNNVectorValues} associated with the field.
142
+ */
143
+ private <T > KNNVectorValues <T > getKNNVectorValues (final VectorDataType vectorDataType , final NativeEngineFieldVectorsWriter <?> field ) {
144
+ return (KNNVectorValues <T >) KNNVectorValuesFactory .getVectorValues (vectorDataType , field .getDocsWithField (), field .getVectors ());
145
+ }
146
+
147
+ /**
148
+ * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type.
149
+ *
150
+ * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
151
+ * @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
152
+ * @param mergeState The {@link MergeState} representing the state of the merge operation.
153
+ * @param <T> The type of vectors being processed.
154
+ * @return The {@link KNNVectorValues} associated with the field during the merge.
155
+ * @throws IOException If an I/O error occurs during the retrieval.
156
+ */
157
+ private <T > KNNVectorValues <T > getKNNVectorValuesForMerge (
158
+ final VectorDataType vectorDataType ,
159
+ final FieldInfo fieldInfo ,
160
+ final MergeState mergeState
161
+ ) throws IOException {
162
+ switch (fieldInfo .getVectorEncoding ()) {
163
+ case FLOAT32 :
164
+ FloatVectorValues mergedFloats = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
165
+ return (KNNVectorValues <T >) KNNVectorValuesFactory .getVectorValues (vectorDataType , mergedFloats );
166
+ case BYTE :
167
+ ByteVectorValues mergedBytes = MergedVectorValues .mergeByteVectorValues (fieldInfo , mergeState );
168
+ return (KNNVectorValues <T >) KNNVectorValuesFactory .getVectorValues (vectorDataType , mergedBytes );
169
+ default :
170
+ throw new IllegalStateException ("Unsupported vector encoding [" + fieldInfo .getVectorEncoding () + "]" );
171
+ }
172
+ }
173
+
174
+ /**
175
+ * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}.
176
+ *
177
+ * @param <T> The type of vectors being processed.
178
+ */
179
+ @ FunctionalInterface
180
+ private interface IndexOperation <T > {
181
+ void buildAndWrite (NativeIndexWriter writer , KNNVectorValues <T > knnVectorValues ) throws IOException ;
182
+ }
183
+
184
+ /**
185
+ * Functional interface representing a method that retrieves {@link KNNVectorValues} based on
186
+ * the vector data type, field information, and the merge state.
187
+ *
188
+ * @param <DataType> The type of the data representing the vector (e.g., {@link VectorDataType}).
189
+ * @param <FieldInfo> The metadata about the field.
190
+ * @param <MergeState> The state of the merge operation.
191
+ * @param <Result> The result of the retrieval, typically {@link KNNVectorValues}.
192
+ */
193
+ @ FunctionalInterface
194
+ private interface VectorValuesRetriever <DataType , FieldInfo , MergeState , Result > {
195
+ Result apply (DataType vectorDataType , FieldInfo fieldInfo , MergeState mergeState ) throws IOException ;
196
+ }
197
+
198
+ /**
199
+ * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values
200
+ * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed.
201
+ *
202
+ * @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
203
+ * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type,
204
+ * field information, and additional context (e.g., merge state or field writer).
205
+ * @param indexOperation A functional interface that performs the indexing operation using the retrieved
206
+ * {@link KNNVectorValues}.
207
+ * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
208
+ * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information
209
+ * @param <T> The type of vectors being processed.
210
+ * @param <C> The type of the context needed for retrieving the vector values.
211
+ * @throws IOException If an I/O error occurs during the processing.
212
+ */
213
+ private <T , C > void trainAndIndex (
214
+ final FieldInfo fieldInfo ,
215
+ final VectorValuesRetriever <VectorDataType , FieldInfo , C , KNNVectorValues <T >> vectorValuesRetriever ,
216
+ final IndexOperation <T > indexOperation ,
217
+ final C VectorProcessingContext
218
+ ) throws IOException {
219
+ final VectorDataType vectorDataType = extractVectorDataType (fieldInfo );
220
+ KNNVectorValues <T > knnVectorValues = vectorValuesRetriever .apply (vectorDataType , fieldInfo , VectorProcessingContext );
221
+ QuantizationParams quantizationParams = quantizationService .getQuantizationParams (fieldInfo );
222
+ QuantizationState quantizationState = null ;
223
+ if (quantizationParams != null ) {
224
+ quantizationState = quantizationService .train (quantizationParams , knnVectorValues );
225
+ }
226
+ NativeIndexWriter writer = (quantizationParams != null )
227
+ ? NativeIndexWriter .getWriter (fieldInfo , segmentWriteState , quantizationState )
228
+ : NativeIndexWriter .getWriter (fieldInfo , segmentWriteState );
229
+
230
+ knnVectorValues = vectorValuesRetriever .apply (vectorDataType , fieldInfo , VectorProcessingContext );
231
+ indexOperation .buildAndWrite (writer , knnVectorValues );
232
+ }
149
233
}
0 commit comments