Skip to content

Commit 8d1669c

Browse files
jed326Jay Deng
authored and
Jay Deng
committed
Add integrity checking to VectorRepositoryAccessor
Signed-off-by: Jay Deng <jayd0104@gmail.com>
1 parent b77b6b6 commit 8d1669c

File tree

5 files changed

+125
-40
lines changed

5 files changed

+125
-40
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1111
* [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554)
1212
* [Remote Vector Index Build] Introduce Client Skeleton + basic Build Request implementation [#2560](https://github.com/opensearch-project/k-NN/pull/2560)
1313
* Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345]
14+
* [Remote Vector Index Build] Add integrity checking to VectorRepositoryAccessor [#2578](https://github.com/opensearch-project/k-NN/pull/2578)
1415
### Enhancements
1516
* Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509)
1617
### Bug Fixes

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java

+85-37
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.index.codec.nativeindex.remote;
77

8+
import com.google.common.annotations.VisibleForTesting;
89
import lombok.AllArgsConstructor;
910
import lombok.extern.log4j.Log4j2;
1011
import org.opensearch.action.LatchedActionListener;
@@ -32,6 +33,8 @@
3233
import java.util.concurrent.CountDownLatch;
3334
import java.util.concurrent.atomic.AtomicReference;
3435
import java.util.function.Supplier;
36+
import java.util.zip.CRC32;
37+
import java.util.zip.CheckedInputStream;
3538

3639
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
3740
import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION;
@@ -74,50 +77,42 @@ public void writeToRepository(
7477
initializeVectorValues(knnVectorValues);
7578
long vectorBlobLength = (long) knnVectorValues.bytesPerVector() * totalLiveDocs;
7679

77-
if (blobContainer instanceof AsyncMultiStreamBlobContainer) {
80+
if (blobContainer instanceof AsyncMultiStreamBlobContainer asyncBlobContainer) {
7881
// First initiate vectors upload
7982
log.debug("Repository {} Supports Parallel Blob Upload", repository);
8083
// WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to
8184
// BuildIndexParams
82-
WriteContext writeContext = new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
83-
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
84-
.fileSize(vectorBlobLength)
85-
.failIfAlreadyExists(true)
86-
.writePriority(WritePriority.NORMAL)
87-
// TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as
88-
// there is no underlying file upon which we can create the checksum. We should be able to create a
89-
// checksum still by iterating through once, however this will be an expensive operation.
90-
.uploadFinalizer((bool) -> {})
91-
.doRemoteDataIntegrityCheck(false)
92-
.expectedChecksum(null)
93-
.build();
85+
WriteContext writeContext = createWriteContext(
86+
blobName,
87+
vectorBlobLength,
88+
knnVectorValuesSupplier,
89+
vectorDataType,
90+
asyncBlobContainer.remoteIntegrityCheckSupported()
91+
);
9492

9593
AtomicReference<Exception> exception = new AtomicReference<>();
9694
final CountDownLatch latch = new CountDownLatch(1);
97-
((AsyncMultiStreamBlobContainer) blobContainer).asyncBlobUpload(
98-
writeContext,
99-
new LatchedActionListener<>(new ActionListener<>() {
100-
@Override
101-
public void onResponse(Void unused) {
102-
log.debug(
103-
"Parallel vector upload succeeded for blob {} with size {}",
104-
blobName + VECTOR_BLOB_FILE_EXTENSION,
105-
vectorBlobLength
106-
);
107-
}
108-
109-
@Override
110-
public void onFailure(Exception e) {
111-
log.error(
112-
"Parallel vector upload failed for blob {} with size {}",
113-
blobName + VECTOR_BLOB_FILE_EXTENSION,
114-
vectorBlobLength,
115-
e
116-
);
117-
exception.set(e);
118-
}
119-
}, latch)
120-
);
95+
asyncBlobContainer.asyncBlobUpload(writeContext, new LatchedActionListener<>(new ActionListener<>() {
96+
@Override
97+
public void onResponse(Void unused) {
98+
log.debug(
99+
"Parallel vector upload succeeded for blob {} with size {}",
100+
blobName + VECTOR_BLOB_FILE_EXTENSION,
101+
vectorBlobLength
102+
);
103+
}
104+
105+
@Override
106+
public void onFailure(Exception e) {
107+
log.error(
108+
"Parallel vector upload failed for blob {} with size {}",
109+
blobName + VECTOR_BLOB_FILE_EXTENSION,
110+
vectorBlobLength,
111+
e
112+
);
113+
exception.set(e);
114+
}
115+
}, latch));
121116

122117
// Then upload doc id blob before waiting on vector uploads
123118
// TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance.
@@ -215,6 +210,59 @@ private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOExceptio
215210
});
216211
}
217212

213+
/**
214+
* Creates a {@link WriteContext} meant to be used by {@link AsyncMultiStreamBlobContainer#asyncBlobUpload}. If integrity checking is supported, calculates a checksum as well.
215+
* @param blobName
216+
* @param vectorBlobLength
217+
* @param knnVectorValuesSupplier
218+
* @param vectorDataType
219+
* @param supportsIntegrityCheck
220+
* @return
221+
* @throws IOException
222+
*/
223+
private WriteContext createWriteContext(
224+
String blobName,
225+
long vectorBlobLength,
226+
Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
227+
VectorDataType vectorDataType,
228+
boolean supportsIntegrityCheck
229+
) throws IOException {
230+
return new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
231+
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
232+
.fileSize(vectorBlobLength)
233+
.failIfAlreadyExists(true)
234+
.writePriority(WritePriority.NORMAL)
235+
.doRemoteDataIntegrityCheck(supportsIntegrityCheck)
236+
.uploadFinalizer((bool) -> {})
237+
.expectedChecksum(supportsIntegrityCheck ? getExpectedChecksum(knnVectorValuesSupplier.get(), vectorDataType) : null)
238+
.build();
239+
}
240+
241+
/**
242+
* Calculates a checksum on the given {@link KNNVectorValues}, representing all the vector data for the index build operation.
243+
* This is done by creating a {@link VectorValuesInputStream} which is wrapped by a {@link CheckedInputStream} and then reading all the data through the stream to calculate the checksum.
244+
* Note: This does add some overhead to the vector blob upload, as we are reading through the KNNVectorValues an additional time. If instead of taking an expected checksum up front
245+
* the WriteContext accepted an expectedChecksumSupplier, we could calculate the checksum as the stream is being uploaded and use that same value to compare, however this is pending
246+
* a change in OpenSearch core.
247+
*
248+
* @param knnVectorValues
249+
* @param vectorDataType
250+
* @return
251+
* @throws IOException
252+
*/
253+
@VisibleForTesting
254+
long getExpectedChecksum(KNNVectorValues<?> knnVectorValues, VectorDataType vectorDataType) throws IOException {
255+
CheckedInputStream checkedStream = new CheckedInputStream(
256+
new VectorValuesInputStream(knnVectorValues, vectorDataType),
257+
new CRC32()
258+
);
259+
int CHUNK_SIZE = 16 * 1024;
260+
final byte[] buffer = new byte[CHUNK_SIZE];
261+
while (checkedStream.read(buffer, 0, CHUNK_SIZE) != -1) {
262+
}
263+
return checkedStream.getChecksum().getValue();
264+
}
265+
218266
@Override
219267
public void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException {
220268
if (path == null || path.isEmpty()) {

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
import static org.mockito.Mockito.mock;
3636
import static org.mockito.Mockito.verify;
3737
import static org.mockito.Mockito.when;
38+
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
3839
import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION;
3940
import static org.opensearch.knn.index.remote.KNNRemoteConstants.VECTOR_BLOB_FILE_EXTENSION;
40-
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
4141

4242
public class DefaultVectorRepositoryAccessorTests extends RemoteIndexBuildTests {
4343

@@ -109,7 +109,7 @@ public void testRepositoryInteractionWithBlobContainer() throws IOException, Int
109109
/**
110110
* Test that when an exception is thrown during asyncBlobUpload, the exception is rethrown.
111111
*/
112-
public void testAsyncUploadThrowsException() throws InterruptedException, IOException {
112+
public void testAsyncUploadThrowsException() throws IOException {
113113
RepositoriesService repositoriesService = mock(RepositoriesService.class);
114114
BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
115115
BlobPath testBasePath = new BlobPath().add("testBasePath");

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/KnnVectorValuesInputStreamTests.java

+36
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
package org.opensearch.knn.index.codec.nativeindex.remote;
77

88
import org.apache.lucene.search.DocIdSetIterator;
9+
import org.opensearch.index.IndexSettings;
910
import org.opensearch.knn.KNNTestCase;
1011
import org.opensearch.knn.index.VectorDataType;
1112
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1213
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
1314
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
15+
import org.opensearch.repositories.blobstore.BlobStoreRepository;
1416

1517
import java.io.IOException;
1618
import java.io.InputStream;
@@ -20,7 +22,10 @@
2022
import java.util.ArrayList;
2123
import java.util.List;
2224
import java.util.function.Supplier;
25+
import java.util.zip.CRC32;
26+
import java.util.zip.Checksum;
2327

28+
import static org.mockito.Mockito.mock;
2429
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
2530
import static org.opensearch.knn.index.vectorvalues.TestVectorValues.getRandomByteVector;
2631
import static org.opensearch.knn.index.vectorvalues.TestVectorValues.getRandomVector;
@@ -263,6 +268,37 @@ public void testDocIdInputStreamReadByte() throws IOException {
263268
assertArrayEquals(bufferRead.array(), bufferReadByByte.array());
264269
}
265270

271+
/**
272+
* Test that calculating the checksum in parts yields the same result as calculating the checksum on the whole stream
273+
*/
274+
public void testVectorValuesChecksum() throws IOException {
275+
final int NUM_DOCS = randomIntBetween(100, 1000);
276+
final int NUM_DIMENSION = randomIntBetween(1, 1000);
277+
278+
List<float[]> vectorValues = getRandomFloatVectors(NUM_DOCS, NUM_DIMENSION);
279+
final Supplier<TestVectorValues.PreDefinedFloatVectorValues> randomVectorValuesSupplier =
280+
() -> new TestVectorValues.PreDefinedFloatVectorValues(vectorValues);
281+
282+
final Supplier<KNNVectorValues<float[]>> knnVectorValuesSupplier = () -> KNNVectorValuesFactory.getVectorValues(
283+
VectorDataType.FLOAT,
284+
randomVectorValuesSupplier.get()
285+
);
286+
287+
// Get checksum from VectorRepositoryAccessor
288+
DefaultVectorRepositoryAccessor vectorRepositoryAccessor = new DefaultVectorRepositoryAccessor(
289+
mock(BlobStoreRepository.class),
290+
mock(IndexSettings.class)
291+
);
292+
long expectedChecksum = vectorRepositoryAccessor.getExpectedChecksum(knnVectorValuesSupplier.get(), VectorDataType.FLOAT);
293+
294+
// Get checksum by reading the entire stream
295+
InputStream inputStream = new VectorValuesInputStream(knnVectorValuesSupplier.get(), VectorDataType.FLOAT);
296+
Checksum actualChecksum = new CRC32();
297+
actualChecksum.update(inputStream.readAllBytes());
298+
299+
assertEquals(expectedChecksum, actualChecksum.getValue());
300+
}
301+
266302
private List<float[]> getRandomFloatVectors(int numDocs, int dimension) {
267303
ArrayList<float[]> vectorValues = new ArrayList<>();
268304
for (int i = 0; i < numDocs; i++) {

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public void readBlobAsync(String s, ActionListener<ReadContext> actionListener)
127127

128128
@Override
129129
public boolean remoteIntegrityCheckSupported() {
130-
return false;
130+
return true;
131131
}
132132

133133
@Override

0 commit comments

Comments
 (0)