Skip to content

Commit d58d133

Browse files
authored
Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)
* Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter

  Signed-off-by: Wei Wang <weiwsde@gmail.com>

* fix typo error in test file

  Signed-off-by: Wei Wang <weiwsde@gmail.com>

---------

Signed-off-by: Wei Wang <weiwsde@gmail.com>
Signed-off-by: Wei Wang <93847013+weiwang118@users.noreply.github.com>
1 parent 1c4a7ca commit d58d133

File tree

6 files changed

+23
-21
lines changed

6 files changed

+23
-21
lines changed

CHANGELOG.md

+1
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2929
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
3030
- Add WithFieldName implementation to KNNQueryBuilder (#2398)[https://github.com/opensearch-project/k-NN/pull/2398]
3131
- Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
32+
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
3233
### Bug Fixes
3334
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
3435
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java

+5-10
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
import lombok.Getter;
1515
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
1616
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
17-
import org.apache.lucene.index.DocsWithFieldSet;
1817
import org.apache.lucene.index.FieldInfo;
1918
import org.apache.lucene.util.InfoStream;
2019
import org.apache.lucene.util.RamUsageEstimator;
@@ -43,9 +42,8 @@ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
4342
@Getter
4443
private final Map<Integer, T> vectors;
4544
private int lastDocID = -1;
46-
@Getter
47-
private final DocsWithFieldSet docsWithField;
4845
private final InfoStream infoStream;
46+
@Getter
4947
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
5048

5149
@SuppressWarnings("unchecked")
@@ -75,7 +73,6 @@ private NativeEngineFieldVectorsWriter(
7573
this.fieldInfo = fieldInfo;
7674
this.infoStream = infoStream;
7775
vectors = new HashMap<>();
78-
this.docsWithField = new DocsWithFieldSet();
7976
this.flatFieldVectorsWriter = flatFieldVectorsWriter;
8077
}
8178

@@ -101,7 +98,6 @@ public void addValue(int docID, T vectorValue) throws IOException {
10198
// ensuring that vector is provided to flatFieldWriter.
10299
flatFieldVectorsWriter.addValue(docID, vectorValue);
103100
vectors.put(docID, vectorValue);
104-
docsWithField.add(docID);
105101
lastDocID = docID;
106102
}
107103

@@ -121,10 +117,9 @@ public T copyValue(T vectorValue) {
121117
*/
122118
@Override
123119
public long ramBytesUsed() {
124-
return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF
125-
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance(
126-
Integer.class
127-
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter
128-
.ramBytesUsed();
120+
return SHALLOW_SIZE + flatFieldVectorsWriter.getDocsWithFieldSet().ramBytesUsed() + (long) this.vectors.size()
121+
* (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size()
122+
* RamUsageEstimator.shallowSizeOfInstance(Integer.class) + (long) vectors.size() * fieldInfo.getVectorDimension()
123+
* fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter.ramBytesUsed();
129124
}
130125
}

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
100100
}
101101
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
102102
vectorDataType,
103-
field.getDocsWithField(),
103+
field.getFlatFieldVectorsWriter().getDocsWithFieldSet(),
104104
field.getVectors()
105105
);
106106
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);

src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java

+2
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313

1414
import lombok.SneakyThrows;
1515
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
16+
import org.apache.lucene.index.DocsWithFieldSet;
1617
import org.apache.lucene.index.FieldInfo;
1718
import org.apache.lucene.index.VectorEncoding;
1819
import org.apache.lucene.util.InfoStream;
@@ -115,6 +116,7 @@ public void testRamByteUsed_whenValidInput_thenSuccess() {
115116
Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2);
116117
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
117118
Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L);
119+
Mockito.when(mockedFlatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(new DocsWithFieldSet());
118120
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
119121
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
120122
// Testing for value > 0 as we don't have a concrete way to find out the expected bytes. This can be OS dependent too.

src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java

+11-9
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ public void testFlush() {
161161
throw new RuntimeException(e);
162162
}
163163

164-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
164+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
165165
knnVectorValuesFactoryMockedStatic.when(
166166
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
167167
).thenReturn(expectedVectorValues.get(i));
@@ -250,7 +250,7 @@ public void testFlush_WithQuantization() {
250250
throw new RuntimeException(e);
251251
}
252252

253-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
253+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
254254
knnVectorValuesFactoryMockedStatic.when(
255255
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
256256
).thenReturn(expectedVectorValues.get(i));
@@ -352,7 +352,7 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled
352352
throw new RuntimeException(e);
353353
}
354354

355-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
355+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
356356
knnVectorValuesFactoryMockedStatic.when(
357357
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
358358
).thenReturn(expectedVectorValues.get(i));
@@ -429,7 +429,7 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter
429429
throw new RuntimeException(e);
430430
}
431431

432-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
432+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
433433
knnVectorValuesFactoryMockedStatic.when(
434434
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
435435
).thenReturn(expectedVectorValues.get(i));
@@ -507,7 +507,7 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr
507507
throw new RuntimeException(e);
508508
}
509509

510-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
510+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
511511
knnVectorValuesFactoryMockedStatic.when(
512512
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
513513
).thenReturn(expectedVectorValues.get(i));
@@ -593,7 +593,7 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
593593
throw new RuntimeException(e);
594594
}
595595

596-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
596+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
597597
knnVectorValuesFactoryMockedStatic.when(
598598
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
599599
).thenReturn(expectedVectorValues.get(i));
@@ -683,7 +683,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
683683
throw new RuntimeException(e);
684684
}
685685

686-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
686+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
687687
knnVectorValuesFactoryMockedStatic.when(
688688
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
689689
).thenReturn(expectedVectorValues.get(i));
@@ -786,7 +786,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
786786
throw new RuntimeException(e);
787787
}
788788

789-
DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
789+
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
790790
knnVectorValuesFactoryMockedStatic.when(
791791
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
792792
).thenReturn(expectedVectorValues.get(i));
@@ -848,11 +848,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<
848848

849849
private <T> NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map<Integer, T> vectors) {
850850
NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class);
851+
FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class);
851852
DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet();
852853
vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add);
853854
when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo);
854855
when(fieldVectorsWriter.getVectors()).thenReturn(vectors);
855-
when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet);
856+
when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter);
857+
when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet);
856858
return fieldVectorsWriter;
857859
}
858860
}

src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java

+3-1
Original file line number | Diff line number | Diff line change
@@ -370,11 +370,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<
370370

371371
private <T> NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map<Integer, T> vectors) {
372372
NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class);
373+
FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class);
373374
DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet();
374375
vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add);
375376
when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo);
376377
when(fieldVectorsWriter.getVectors()).thenReturn(vectors);
377-
when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet);
378+
when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter);
379+
when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet);
378380
return fieldVectorsWriter;
379381
}
380382
}

0 commit comments

Comments
 (0)