Skip to content

Commit fbff2dd

Browse files
jed326 (Jay Deng)
authored and
Jay Deng
committed
Add integrity checking to VectorRepositoryAccessor
Signed-off-by: Jay Deng <jayd0104@gmail.com>
1 parent b77b6b6 commit fbff2dd

File tree

7 files changed

+191
-51
lines changed

7 files changed

+191
-51
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1111
* [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554)
1212
* [Remote Vector Index Build] Introduce Client Skeleton + basic Build Request implementation [#2560](https://github.com/opensearch-project/k-NN/pull/2560)
1313
* Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345]
14+
* [Remote Vector Index Build] Add integrity checking to VectorRepositoryAccessor [#2578](https://github.com/opensearch-project/k-NN/pull/2578)
1415
### Enhancements
1516
* Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509)
1617
### Bug Fixes

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java

+108-48
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.index.codec.nativeindex.remote;
77

8+
import com.google.common.annotations.VisibleForTesting;
89
import lombok.AllArgsConstructor;
910
import lombok.extern.log4j.Log4j2;
1011
import org.opensearch.action.LatchedActionListener;
@@ -32,6 +33,8 @@
3233
import java.util.concurrent.CountDownLatch;
3334
import java.util.concurrent.atomic.AtomicReference;
3435
import java.util.function.Supplier;
36+
import java.util.zip.CRC32;
37+
import java.util.zip.CheckedInputStream;
3538

3639
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
3740
import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION;
@@ -74,50 +77,42 @@ public void writeToRepository(
7477
initializeVectorValues(knnVectorValues);
7578
long vectorBlobLength = (long) knnVectorValues.bytesPerVector() * totalLiveDocs;
7679

77-
if (blobContainer instanceof AsyncMultiStreamBlobContainer) {
80+
if (blobContainer instanceof AsyncMultiStreamBlobContainer asyncBlobContainer) {
7881
// First initiate vectors upload
7982
log.debug("Repository {} Supports Parallel Blob Upload", repository);
8083
// WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to
8184
// BuildIndexParams
82-
WriteContext writeContext = new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
83-
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
84-
.fileSize(vectorBlobLength)
85-
.failIfAlreadyExists(true)
86-
.writePriority(WritePriority.NORMAL)
87-
// TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as
88-
// there is no underlying file upon which we can create the checksum. We should be able to create a
89-
// checksum still by iterating through once, however this will be an expensive operation.
90-
.uploadFinalizer((bool) -> {})
91-
.doRemoteDataIntegrityCheck(false)
92-
.expectedChecksum(null)
93-
.build();
85+
WriteContext writeContext = createWriteContext(
86+
blobName,
87+
vectorBlobLength,
88+
knnVectorValuesSupplier,
89+
vectorDataType,
90+
asyncBlobContainer.remoteIntegrityCheckSupported()
91+
);
9492

9593
AtomicReference<Exception> exception = new AtomicReference<>();
9694
final CountDownLatch latch = new CountDownLatch(1);
97-
((AsyncMultiStreamBlobContainer) blobContainer).asyncBlobUpload(
98-
writeContext,
99-
new LatchedActionListener<>(new ActionListener<>() {
100-
@Override
101-
public void onResponse(Void unused) {
102-
log.debug(
103-
"Parallel vector upload succeeded for blob {} with size {}",
104-
blobName + VECTOR_BLOB_FILE_EXTENSION,
105-
vectorBlobLength
106-
);
107-
}
108-
109-
@Override
110-
public void onFailure(Exception e) {
111-
log.error(
112-
"Parallel vector upload failed for blob {} with size {}",
113-
blobName + VECTOR_BLOB_FILE_EXTENSION,
114-
vectorBlobLength,
115-
e
116-
);
117-
exception.set(e);
118-
}
119-
}, latch)
120-
);
95+
asyncBlobContainer.asyncBlobUpload(writeContext, new LatchedActionListener<>(new ActionListener<>() {
96+
@Override
97+
public void onResponse(Void unused) {
98+
log.debug(
99+
"Parallel vector upload succeeded for blob {} with size {}",
100+
blobName + VECTOR_BLOB_FILE_EXTENSION,
101+
vectorBlobLength
102+
);
103+
}
104+
105+
@Override
106+
public void onFailure(Exception e) {
107+
log.error(
108+
"Parallel vector upload failed for blob {} with size {}",
109+
blobName + VECTOR_BLOB_FILE_EXTENSION,
110+
vectorBlobLength,
111+
e
112+
);
113+
exception.set(e);
114+
}
115+
}, latch));
121116

122117
// Then upload doc id blob before waiting on vector uploads
123118
// TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance.
@@ -130,9 +125,14 @@ public void onFailure(Exception e) {
130125
} else {
131126
log.debug("Repository {} Does Not Support Parallel Blob Upload", repository);
132127
// Write Vectors
133-
InputStream vectorStream = new BufferedInputStream(new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType));
134-
log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION);
135-
blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true);
128+
try (
129+
InputStream vectorStream = new BufferedInputStream(
130+
new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType)
131+
)
132+
) {
133+
log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION);
134+
blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true);
135+
}
136136
// Then write doc ids
137137
writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer);
138138
}
@@ -154,14 +154,15 @@ private void writeDocIds(
154154
String blobName,
155155
BlobContainer blobContainer
156156
) throws IOException {
157-
InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues));
158-
log.debug(
159-
"Writing {} bytes for {} docs ids to {}",
160-
vectorBlobLength,
161-
totalLiveDocs * Integer.BYTES,
162-
blobName + DOC_ID_FILE_EXTENSION
163-
);
164-
blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true);
157+
try (InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues))) {
158+
log.debug(
159+
"Writing {} bytes for {} docs ids to {}",
160+
vectorBlobLength,
161+
totalLiveDocs * Integer.BYTES,
162+
blobName + DOC_ID_FILE_EXTENSION
163+
);
164+
blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true);
165+
}
165166
}
166167

167168
/**
@@ -215,6 +216,65 @@ private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOExceptio
215216
});
216217
}
217218

219+
/**
220+
* Creates a {@link WriteContext} meant to be used by {@link AsyncMultiStreamBlobContainer#asyncBlobUpload}. If integrity checking is supported, calculates a checksum as well.
221+
* @param blobName
222+
* @param vectorBlobLength
223+
* @param knnVectorValuesSupplier
224+
* @param vectorDataType
225+
* @param supportsIntegrityCheck
226+
* @return
227+
* @throws IOException
228+
*/
229+
private WriteContext createWriteContext(
230+
String blobName,
231+
long vectorBlobLength,
232+
Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
233+
VectorDataType vectorDataType,
234+
boolean supportsIntegrityCheck
235+
) throws IOException {
236+
return new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
237+
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
238+
.fileSize(vectorBlobLength)
239+
.failIfAlreadyExists(true)
240+
.writePriority(WritePriority.NORMAL)
241+
.doRemoteDataIntegrityCheck(supportsIntegrityCheck)
242+
.uploadFinalizer((bool) -> {})
243+
.expectedChecksum(supportsIntegrityCheck ? getExpectedChecksum(knnVectorValuesSupplier.get(), vectorDataType) : null)
244+
.build();
245+
}
246+
247+
/**
248+
* Calculates a checksum on the given {@link KNNVectorValues}, representing all the vector data for the index build operation.
249+
* This is done by creating a {@link VectorValuesInputStream} which is wrapped by a {@link CheckedInputStream} and then reading all the data through the stream to calculate the checksum.
250+
* Note: This does add some overhead to the vector blob upload, as we are reading through the KNNVectorValues an additional time. If instead of taking an expected checksum up front
251+
* the WriteContext accepted an expectedChecksumSupplier, we could calculate the checksum as the stream is being uploaded and use that same value to compare, however this is pending
252+
* a change in OpenSearch core.
253+
*
254+
* @param knnVectorValues
255+
* @param vectorDataType
256+
* @return
257+
* @throws IOException
258+
*/
259+
@VisibleForTesting
260+
long getExpectedChecksum(KNNVectorValues<?> knnVectorValues, VectorDataType vectorDataType) throws IOException {
261+
try (
262+
CheckedInputStream checkedStream = new CheckedInputStream(
263+
new VectorValuesInputStream(knnVectorValues, vectorDataType),
264+
new CRC32()
265+
)
266+
) {
267+
// VectorValuesInputStream#read only reads 1 vector max at a time, so no point making this buffer any larger than that
268+
initializeVectorValues(knnVectorValues);
269+
int bufferSize = knnVectorValues.bytesPerVector();
270+
final byte[] buffer = new byte[bufferSize];
271+
// Checksum is computed by reading through the CheckedInputStream
272+
while (checkedStream.read(buffer, 0, bufferSize) != -1) {
273+
}
274+
return checkedStream.getChecksum().getValue();
275+
}
276+
}
277+
218278
@Override
219279
public void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException {
220280
if (path == null || path.isEmpty()) {

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java

+21
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.io.InputStream;
1414
import java.nio.ByteBuffer;
1515
import java.nio.ByteOrder;
16+
import java.util.concurrent.atomic.AtomicBoolean;
1617

1718
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
1819

@@ -25,6 +26,7 @@ class DocIdInputStream extends InputStream {
2526
// Doc ids are 4 byte integers, byte read() only returns a single byte, so we will need to track the byte position within a doc id.
2627
// For simplicity, and to maintain the byte ordering, we use a buffer with size of 1 int.
2728
private ByteBuffer currentBuffer;
29+
private final AtomicBoolean closed = new AtomicBoolean(false);
2830

2931
/**
3032
* Use to represent the doc ids of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used only with {@link org.opensearch.common.blobstore.BlobContainer#writeBlob}.
@@ -41,6 +43,7 @@ public DocIdInputStream(KNNVectorValues<?> knnVectorValues) throws IOException {
4143

4244
@Override
4345
public int read() throws IOException {
46+
checkClosed();
4447
if (currentBuffer == null) {
4548
return -1;
4649
}
@@ -59,6 +62,7 @@ public int read() throws IOException {
5962

6063
@Override
6164
public int read(byte[] b, int off, int len) throws IOException {
65+
checkClosed();
6266
if (currentBuffer == null) {
6367
return -1;
6468
}
@@ -77,6 +81,23 @@ public int read(byte[] b, int off, int len) throws IOException {
7781
return bytesToRead;
7882
}
7983

84+
/**
85+
* Marks this stream as closed
86+
* @throws IOException
87+
*/
88+
@Override
89+
public void close() throws IOException {
90+
super.close();
91+
currentBuffer = null;
92+
closed.set(true);
93+
}
94+
95+
private void checkClosed() throws IOException {
96+
if (closed.get()) {
97+
throw new IOException("Stream closed");
98+
}
99+
}
100+
80101
/**
81102
* Advances to the next doc, and then refills the buffer with the new doc.
82103
* @throws IOException

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java

+22
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.io.InputStream;
1818
import java.nio.ByteBuffer;
1919
import java.nio.ByteOrder;
20+
import java.util.concurrent.atomic.AtomicBoolean;
2021

2122
import static org.opensearch.knn.index.VectorDataType.BINARY;
2223
import static org.opensearch.knn.index.VectorDataType.BYTE;
@@ -36,6 +37,7 @@ class VectorValuesInputStream extends InputStream {
3637
private final int bytesPerVector;
3738
private long bytesRemaining;
3839
private final VectorDataType vectorDataType;
40+
private final AtomicBoolean closed = new AtomicBoolean(false);
3941

4042
/**
4143
* Used to represent a part of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used with
@@ -84,6 +86,7 @@ public VectorValuesInputStream(KNNVectorValues<?> knnVectorValues, VectorDataTyp
8486

8587
@Override
8688
public int read() throws IOException {
89+
checkClosed();
8790
if (bytesRemaining <= 0 || currentBuffer == null) {
8891
return -1;
8992
}
@@ -103,6 +106,7 @@ public int read() throws IOException {
103106

104107
@Override
105108
public int read(byte[] b, int off, int len) throws IOException {
109+
checkClosed();
106110
if (bytesRemaining <= 0 || currentBuffer == null) {
107111
return -1;
108112
}
@@ -132,9 +136,27 @@ public int read(byte[] b, int off, int len) throws IOException {
132136
*/
133137
@Override
134138
public long skip(long n) throws IOException {
139+
checkClosed();
135140
throw new UnsupportedOperationException("VectorValuesInputStream does not support skip");
136141
}
137142

143+
/**
144+
* Marks this stream as closed
145+
* @throws IOException
146+
*/
147+
@Override
148+
public void close() throws IOException {
149+
super.close();
150+
currentBuffer = null;
151+
closed.set(true);
152+
}
153+
154+
private void checkClosed() throws IOException {
155+
if (closed.get()) {
156+
throw new IOException("Stream closed");
157+
}
158+
}
159+
138160
/**
139161
* Advances n bytes forward in the knnVectorValues.
140162
* Note: {@link KNNVectorValues#advance} is not supported when we are merging segments, so we do not use it here.

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
import static org.mockito.Mockito.mock;
3636
import static org.mockito.Mockito.verify;
3737
import static org.mockito.Mockito.when;
38+
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
3839
import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION;
3940
import static org.opensearch.knn.index.remote.KNNRemoteConstants.VECTOR_BLOB_FILE_EXTENSION;
40-
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
4141

4242
public class DefaultVectorRepositoryAccessorTests extends RemoteIndexBuildTests {
4343

@@ -109,7 +109,7 @@ public void testRepositoryInteractionWithBlobContainer() throws IOException, Int
109109
/**
110110
* Test that when an exception is thrown during asyncBlobUpload, the exception is rethrown.
111111
*/
112-
public void testAsyncUploadThrowsException() throws InterruptedException, IOException {
112+
public void testAsyncUploadThrowsException() throws IOException {
113113
RepositoriesService repositoriesService = mock(RepositoriesService.class);
114114
BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
115115
BlobPath testBasePath = new BlobPath().add("testBasePath");

0 commit comments

Comments
 (0)