Skip to content

Commit 6b46325

Browse files
jed326 (Jay Deng) authored and Jay Deng committed
Remove integrity checking TODO and leave to the vendor implementation
Signed-off-by: Jay Deng <jayd0104@gmail.com>
1 parent b77b6b6 commit 6b46325

File tree

5 files changed: +110 -51 lines changed

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java

+64 -48
Original file line number | Diff line number | Diff line change
@@ -74,50 +74,36 @@ public void writeToRepository(
7474
initializeVectorValues(knnVectorValues);
7575
long vectorBlobLength = (long) knnVectorValues.bytesPerVector() * totalLiveDocs;
7676

77-
if (blobContainer instanceof AsyncMultiStreamBlobContainer) {
77+
if (blobContainer instanceof AsyncMultiStreamBlobContainer asyncBlobContainer) {
7878
// First initiate vectors upload
7979
log.debug("Repository {} Supports Parallel Blob Upload", repository);
8080
// WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to
8181
// BuildIndexParams
82-
WriteContext writeContext = new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
83-
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
84-
.fileSize(vectorBlobLength)
85-
.failIfAlreadyExists(true)
86-
.writePriority(WritePriority.NORMAL)
87-
// TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as
88-
// there is no underlying file upon which we can create the checksum. We should be able to create a
89-
// checksum still by iterating through once, however this will be an expensive operation.
90-
.uploadFinalizer((bool) -> {})
91-
.doRemoteDataIntegrityCheck(false)
92-
.expectedChecksum(null)
93-
.build();
82+
WriteContext writeContext = createWriteContext(blobName, vectorBlobLength, knnVectorValuesSupplier, vectorDataType);
9483

9584
AtomicReference<Exception> exception = new AtomicReference<>();
9685
final CountDownLatch latch = new CountDownLatch(1);
97-
((AsyncMultiStreamBlobContainer) blobContainer).asyncBlobUpload(
98-
writeContext,
99-
new LatchedActionListener<>(new ActionListener<>() {
100-
@Override
101-
public void onResponse(Void unused) {
102-
log.debug(
103-
"Parallel vector upload succeeded for blob {} with size {}",
104-
blobName + VECTOR_BLOB_FILE_EXTENSION,
105-
vectorBlobLength
106-
);
107-
}
108-
109-
@Override
110-
public void onFailure(Exception e) {
111-
log.error(
112-
"Parallel vector upload failed for blob {} with size {}",
113-
blobName + VECTOR_BLOB_FILE_EXTENSION,
114-
vectorBlobLength,
115-
e
116-
);
117-
exception.set(e);
118-
}
119-
}, latch)
120-
);
86+
asyncBlobContainer.asyncBlobUpload(writeContext, new LatchedActionListener<>(new ActionListener<>() {
87+
@Override
88+
public void onResponse(Void unused) {
89+
log.debug(
90+
"Parallel vector upload succeeded for blob {} with size {}",
91+
blobName + VECTOR_BLOB_FILE_EXTENSION,
92+
vectorBlobLength
93+
);
94+
}
95+
96+
@Override
97+
public void onFailure(Exception e) {
98+
log.error(
99+
"Parallel vector upload failed for blob {} with size {}",
100+
blobName + VECTOR_BLOB_FILE_EXTENSION,
101+
vectorBlobLength,
102+
e
103+
);
104+
exception.set(e);
105+
}
106+
}, latch));
121107

122108
// Then upload doc id blob before waiting on vector uploads
123109
// TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance.
@@ -130,9 +116,14 @@ public void onFailure(Exception e) {
130116
} else {
131117
log.debug("Repository {} Does Not Support Parallel Blob Upload", repository);
132118
// Write Vectors
133-
InputStream vectorStream = new BufferedInputStream(new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType));
134-
log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION);
135-
blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true);
119+
try (
120+
InputStream vectorStream = new BufferedInputStream(
121+
new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType)
122+
)
123+
) {
124+
log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION);
125+
blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true);
126+
}
136127
// Then write doc ids
137128
writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer);
138129
}
@@ -154,14 +145,15 @@ private void writeDocIds(
154145
String blobName,
155146
BlobContainer blobContainer
156147
) throws IOException {
157-
InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues));
158-
log.debug(
159-
"Writing {} bytes for {} docs ids to {}",
160-
vectorBlobLength,
161-
totalLiveDocs * Integer.BYTES,
162-
blobName + DOC_ID_FILE_EXTENSION
163-
);
164-
blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true);
148+
try (InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues))) {
149+
log.debug(
150+
"Writing {} bytes for {} docs ids to {}",
151+
vectorBlobLength,
152+
totalLiveDocs * Integer.BYTES,
153+
blobName + DOC_ID_FILE_EXTENSION
154+
);
155+
blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true);
156+
}
165157
}
166158

167159
/**
@@ -215,6 +207,30 @@ private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOExceptio
215207
});
216208
}
217209

210+
/**
211+
* Creates a {@link WriteContext} meant to be used by {@link AsyncMultiStreamBlobContainer#asyncBlobUpload}.
212+
* Note: Integrity checking is left up to the vendor repository and SDK implementations.
213+
* @param blobName
214+
* @param vectorBlobLength
215+
* @param knnVectorValuesSupplier
216+
* @param vectorDataType
217+
* @return
218+
*/
219+
private WriteContext createWriteContext(
220+
String blobName,
221+
long vectorBlobLength,
222+
Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
223+
VectorDataType vectorDataType
224+
) {
225+
return new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
226+
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
227+
.fileSize(vectorBlobLength)
228+
.failIfAlreadyExists(true)
229+
.writePriority(WritePriority.NORMAL)
230+
.uploadFinalizer((bool) -> {})
231+
.build();
232+
}
233+
218234
@Override
219235
public void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException {
220236
if (path == null || path.isEmpty()) {

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java

+21
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import java.io.InputStream;
1414
import java.nio.ByteBuffer;
1515
import java.nio.ByteOrder;
16+
import java.util.concurrent.atomic.AtomicBoolean;
1617

1718
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
1819

@@ -25,6 +26,7 @@ class DocIdInputStream extends InputStream {
2526
// Doc ids are 4 byte integers, byte read() only returns a single byte, so we will need to track the byte position within a doc id.
2627
// For simplicity, and to maintain the byte ordering, we use a buffer with size of 1 int.
2728
private ByteBuffer currentBuffer;
29+
private final AtomicBoolean closed = new AtomicBoolean(false);
2830

2931
/**
3032
* Use to represent the doc ids of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used only with {@link org.opensearch.common.blobstore.BlobContainer#writeBlob}.
@@ -41,6 +43,7 @@ public DocIdInputStream(KNNVectorValues<?> knnVectorValues) throws IOException {
4143

4244
@Override
4345
public int read() throws IOException {
46+
checkClosed();
4447
if (currentBuffer == null) {
4548
return -1;
4649
}
@@ -59,6 +62,7 @@ public int read() throws IOException {
5962

6063
@Override
6164
public int read(byte[] b, int off, int len) throws IOException {
65+
checkClosed();
6266
if (currentBuffer == null) {
6367
return -1;
6468
}
@@ -77,6 +81,23 @@ public int read(byte[] b, int off, int len) throws IOException {
7781
return bytesToRead;
7882
}
7983

84+
/**
85+
* Marks this stream as closed
86+
* @throws IOException
87+
*/
88+
@Override
89+
public void close() throws IOException {
90+
super.close();
91+
currentBuffer = null;
92+
closed.set(true);
93+
}
94+
95+
private void checkClosed() throws IOException {
96+
if (closed.get()) {
97+
throw new IOException("Stream closed");
98+
}
99+
}
100+
80101
/**
81102
* Advances to the next doc, and then refills the buffer with the new doc.
82103
* @throws IOException

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java

+22
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import java.io.InputStream;
1818
import java.nio.ByteBuffer;
1919
import java.nio.ByteOrder;
20+
import java.util.concurrent.atomic.AtomicBoolean;
2021

2122
import static org.opensearch.knn.index.VectorDataType.BINARY;
2223
import static org.opensearch.knn.index.VectorDataType.BYTE;
@@ -36,6 +37,7 @@ class VectorValuesInputStream extends InputStream {
3637
private final int bytesPerVector;
3738
private long bytesRemaining;
3839
private final VectorDataType vectorDataType;
40+
private final AtomicBoolean closed = new AtomicBoolean(false);
3941

4042
/**
4143
* Used to represent a part of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used with
@@ -84,6 +86,7 @@ public VectorValuesInputStream(KNNVectorValues<?> knnVectorValues, VectorDataTyp
8486

8587
@Override
8688
public int read() throws IOException {
89+
checkClosed();
8790
if (bytesRemaining <= 0 || currentBuffer == null) {
8891
return -1;
8992
}
@@ -103,6 +106,7 @@ public int read() throws IOException {
103106

104107
@Override
105108
public int read(byte[] b, int off, int len) throws IOException {
109+
checkClosed();
106110
if (bytesRemaining <= 0 || currentBuffer == null) {
107111
return -1;
108112
}
@@ -132,9 +136,27 @@ public int read(byte[] b, int off, int len) throws IOException {
132136
*/
133137
@Override
134138
public long skip(long n) throws IOException {
139+
checkClosed();
135140
throw new UnsupportedOperationException("VectorValuesInputStream does not support skip");
136141
}
137142

143+
/**
144+
* Marks this stream as closed
145+
* @throws IOException
146+
*/
147+
@Override
148+
public void close() throws IOException {
149+
super.close();
150+
currentBuffer = null;
151+
closed.set(true);
152+
}
153+
154+
private void checkClosed() throws IOException {
155+
if (closed.get()) {
156+
throw new IOException("Stream closed");
157+
}
158+
}
159+
138160
/**
139161
* Advances n bytes forward in the knnVectorValues.
140162
* Note: {@link KNNVectorValues#advance} is not supported when we are merging segments, so we do not use it here.

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java

+2 -2
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@
3535
import static org.mockito.Mockito.mock;
3636
import static org.mockito.Mockito.verify;
3737
import static org.mockito.Mockito.when;
38+
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
3839
import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION;
3940
import static org.opensearch.knn.index.remote.KNNRemoteConstants.VECTOR_BLOB_FILE_EXTENSION;
40-
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
4141

4242
public class DefaultVectorRepositoryAccessorTests extends RemoteIndexBuildTests {
4343

@@ -109,7 +109,7 @@ public void testRepositoryInteractionWithBlobContainer() throws IOException, Int
109109
/**
110110
* Test that when an exception is thrown during asyncBlobUpload, the exception is rethrown.
111111
*/
112-
public void testAsyncUploadThrowsException() throws InterruptedException, IOException {
112+
public void testAsyncUploadThrowsException() throws IOException {
113113
RepositoriesService repositoriesService = mock(RepositoriesService.class);
114114
BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
115115
BlobPath testBasePath = new BlobPath().add("testBasePath");

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java

+1 -1
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ public void readBlobAsync(String s, ActionListener<ReadContext> actionListener)
127127

128128
@Override
129129
public boolean remoteIntegrityCheckSupported() {
130-
return false;
130+
return true;
131131
}
132132

133133
@Override

0 commit comments

Comments
 (0)