diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java index ae41e07cf..33098042c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java @@ -74,50 +74,36 @@ public void writeToRepository( initializeVectorValues(knnVectorValues); long vectorBlobLength = (long) knnVectorValues.bytesPerVector() * totalLiveDocs; - if (blobContainer instanceof AsyncMultiStreamBlobContainer) { + if (blobContainer instanceof AsyncMultiStreamBlobContainer asyncBlobContainer) { // First initiate vectors upload log.debug("Repository {} Supports Parallel Blob Upload", repository); // WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to // BuildIndexParams - WriteContext writeContext = new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION) - .streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType)) - .fileSize(vectorBlobLength) - .failIfAlreadyExists(true) - .writePriority(WritePriority.NORMAL) - // TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as - // there is no underlying file upon which we can create the checksum. We should be able to create a - // checksum still by iterating through once, however this will be an expensive operation. - .uploadFinalizer((bool) -> {}) - .doRemoteDataIntegrityCheck(false) - .expectedChecksum(null) - .build(); + WriteContext writeContext = createWriteContext(blobName, vectorBlobLength, knnVectorValuesSupplier, vectorDataType); AtomicReference exception = new AtomicReference<>(); final CountDownLatch latch = new CountDownLatch(1); - ((AsyncMultiStreamBlobContainer) blobContainer).asyncBlobUpload( - writeContext, - new LatchedActionListener<>(new ActionListener<>() { - @Override - public void onResponse(Void unused) { - log.debug( - "Parallel vector upload succeeded for blob {} with size {}", - blobName + VECTOR_BLOB_FILE_EXTENSION, - vectorBlobLength - ); - } - - @Override - public void onFailure(Exception e) { - log.error( - "Parallel vector upload failed for blob {} with size {}", - blobName + VECTOR_BLOB_FILE_EXTENSION, - vectorBlobLength, - e - ); - exception.set(e); - } - }, latch) - ); + asyncBlobContainer.asyncBlobUpload(writeContext, new LatchedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Void unused) { + log.debug( + "Parallel vector upload succeeded for blob {} with size {}", + blobName + VECTOR_BLOB_FILE_EXTENSION, + vectorBlobLength + ); + } + + @Override + public void onFailure(Exception e) { + log.error( + "Parallel vector upload failed for blob {} with size {}", + blobName + VECTOR_BLOB_FILE_EXTENSION, + vectorBlobLength, + e + ); + exception.set(e); + } + }, latch)); // Then upload doc id blob before waiting on vector uploads // TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance. @@ -130,9 +116,14 @@ public void onFailure(Exception e) { } else { log.debug("Repository {} Does Not Support Parallel Blob Upload", repository); // Write Vectors - InputStream vectorStream = new BufferedInputStream(new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType)); - log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION); - blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true); + try ( + InputStream vectorStream = new BufferedInputStream( + new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType) + ) + ) { + log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION); + blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true); + } // Then write doc ids writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer); } @@ -154,14 +145,15 @@ private void writeDocIds( String blobName, BlobContainer blobContainer ) throws IOException { - InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues)); - log.debug( - "Writing {} bytes for {} docs ids to {}", - vectorBlobLength, - totalLiveDocs * Integer.BYTES, - blobName + DOC_ID_FILE_EXTENSION - ); - blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true); + try (InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues))) { + log.debug( + "Writing {} bytes for {} docs ids to {}", + vectorBlobLength, + totalLiveDocs * Integer.BYTES, + blobName + DOC_ID_FILE_EXTENSION + ); + blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true); + } } /** @@ -215,6 +207,30 @@ private CheckedTriFunction> knnVectorValuesSupplier, + VectorDataType vectorDataType + ) { + return new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION) + .streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType)) + .fileSize(vectorBlobLength) + .failIfAlreadyExists(true) + .writePriority(WritePriority.NORMAL) + .uploadFinalizer((bool) -> {}) + .build(); + } + @Override public void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException { if (path == null || path.isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java index c1a398701..a8125605b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java @@ -13,6 +13,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; @@ -25,6 +26,7 @@ class DocIdInputStream extends InputStream { // Doc ids are 4 byte integers, byte read() only returns a single byte, so we will need to track the byte position within a doc id. // For simplicity, and to maintain the byte ordering, we use a buffer with size of 1 int. private ByteBuffer currentBuffer; + private final AtomicBoolean closed = new AtomicBoolean(false); /** * Use to represent the doc ids of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used only with {@link org.opensearch.common.blobstore.BlobContainer#writeBlob}. @@ -41,6 +43,7 @@ public DocIdInputStream(KNNVectorValues knnVectorValues) throws IOException { @Override public int read() throws IOException { + checkClosed(); if (currentBuffer == null) { return -1; } @@ -59,6 +62,7 @@ public int read() throws IOException { @Override public int read(byte[] b, int off, int len) throws IOException { + checkClosed(); if (currentBuffer == null) { return -1; } @@ -77,6 +81,23 @@ public int read(byte[] b, int off, int len) throws IOException { return bytesToRead; } + /** + * Marks this stream as closed + * @throws IOException + */ + @Override + public void close() throws IOException { + super.close(); + currentBuffer = null; + closed.set(true); + } + + private void checkClosed() throws IOException { + if (closed.get()) { + throw new IOException("Stream closed"); + } + } + /** * Advances to the next doc, and then refills the buffer with the new doc. * @throws IOException diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java index c46677e80..ba738e594 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java @@ -17,6 +17,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.knn.index.VectorDataType.BINARY; import static org.opensearch.knn.index.VectorDataType.BYTE; @@ -36,6 +37,7 @@ class VectorValuesInputStream extends InputStream { private final int bytesPerVector; private long bytesRemaining; private final VectorDataType vectorDataType; + private final AtomicBoolean closed = new AtomicBoolean(false); /** * Used to represent a part of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used with @@ -84,6 +86,7 @@ public VectorValuesInputStream(KNNVectorValues knnVectorValues, VectorDataTyp @Override public int read() throws IOException { + checkClosed(); if (bytesRemaining <= 0 || currentBuffer == null) { return -1; } @@ -103,6 +106,7 @@ public int read() throws IOException { @Override public int read(byte[] b, int off, int len) throws IOException { + checkClosed(); if (bytesRemaining <= 0 || currentBuffer == null) { return -1; } @@ -132,9 +136,27 @@ public int read(byte[] b, int off, int len) throws IOException { */ @Override public long skip(long n) throws IOException { + checkClosed(); throw new UnsupportedOperationException("VectorValuesInputStream does not support skip"); } + /** + * Marks this stream as closed + * @throws IOException + */ + @Override + public void close() throws IOException { + super.close(); + currentBuffer = null; + closed.set(true); + } + + private void checkClosed() throws IOException { + if (closed.get()) { + throw new IOException("Stream closed"); + } + } + /** * Advances n bytes forward in the knnVectorValues. * Note: {@link KNNVectorValues#advance} is not supported when we are merging segments, so we do not use it here. diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java index e401603ea..ddb78f0ef 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java @@ -35,9 +35,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION; import static org.opensearch.knn.index.remote.KNNRemoteConstants.VECTOR_BLOB_FILE_EXTENSION; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; public class DefaultVectorRepositoryAccessorTests extends RemoteIndexBuildTests { @@ -109,7 +109,7 @@ public void testRepositoryInteractionWithBlobContainer() throws IOException, Int /** * Test that when an exception is thrown during asyncBlobUpload, the exception is rethrown. */ - public void testAsyncUploadThrowsException() throws InterruptedException, IOException { + public void testAsyncUploadThrowsException() throws IOException { RepositoriesService repositoriesService = mock(RepositoriesService.class); BlobStoreRepository mockRepository = mock(BlobStoreRepository.class); BlobPath testBasePath = new BlobPath().add("testBasePath"); diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java index 904eb2c92..46916b4da 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java @@ -127,7 +127,7 @@ public void readBlobAsync(String s, ActionListener actionListener) @Override public boolean remoteIntegrityCheckSupported() { - return false; + return true; } @Override