5
5
6
6
package org .opensearch .knn .index .codec .nativeindex .remote ;
7
7
8
+ import com .google .common .annotations .VisibleForTesting ;
8
9
import lombok .AllArgsConstructor ;
9
10
import lombok .extern .log4j .Log4j2 ;
10
11
import org .opensearch .action .LatchedActionListener ;
32
33
import java .util .concurrent .CountDownLatch ;
33
34
import java .util .concurrent .atomic .AtomicReference ;
34
35
import java .util .function .Supplier ;
36
+ import java .util .zip .CRC32 ;
37
+ import java .util .zip .CheckedInputStream ;
35
38
36
39
import static org .opensearch .knn .index .codec .util .KNNCodecUtil .initializeVectorValues ;
37
40
import static org .opensearch .knn .index .remote .KNNRemoteConstants .DOC_ID_FILE_EXTENSION ;
@@ -74,50 +77,42 @@ public void writeToRepository(
74
77
initializeVectorValues (knnVectorValues );
75
78
long vectorBlobLength = (long ) knnVectorValues .bytesPerVector () * totalLiveDocs ;
76
79
77
- if (blobContainer instanceof AsyncMultiStreamBlobContainer ) {
80
+ if (blobContainer instanceof AsyncMultiStreamBlobContainer asyncBlobContainer ) {
78
81
// First initiate vectors upload
79
82
log .debug ("Repository {} Supports Parallel Blob Upload" , repository );
80
83
// WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to
81
84
// BuildIndexParams
82
- WriteContext writeContext = new WriteContext .Builder ().fileName (blobName + VECTOR_BLOB_FILE_EXTENSION )
83
- .streamContextSupplier ((partSize ) -> getStreamContext (partSize , vectorBlobLength , knnVectorValuesSupplier , vectorDataType ))
84
- .fileSize (vectorBlobLength )
85
- .failIfAlreadyExists (true )
86
- .writePriority (WritePriority .NORMAL )
87
- // TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as
88
- // there is no underlying file upon which we can create the checksum. We should be able to create a
89
- // checksum still by iterating through once, however this will be an expensive operation.
90
- .uploadFinalizer ((bool ) -> {})
91
- .doRemoteDataIntegrityCheck (false )
92
- .expectedChecksum (null )
93
- .build ();
85
+ WriteContext writeContext = createWriteContext (
86
+ blobName ,
87
+ vectorBlobLength ,
88
+ knnVectorValuesSupplier ,
89
+ vectorDataType ,
90
+ asyncBlobContainer .remoteIntegrityCheckSupported ()
91
+ );
94
92
95
93
AtomicReference <Exception > exception = new AtomicReference <>();
96
94
final CountDownLatch latch = new CountDownLatch (1 );
97
- ((AsyncMultiStreamBlobContainer ) blobContainer ).asyncBlobUpload (
98
- writeContext ,
99
- new LatchedActionListener <>(new ActionListener <>() {
100
- @ Override
101
- public void onResponse (Void unused ) {
102
- log .debug (
103
- "Parallel vector upload succeeded for blob {} with size {}" ,
104
- blobName + VECTOR_BLOB_FILE_EXTENSION ,
105
- vectorBlobLength
106
- );
107
- }
108
-
109
- @ Override
110
- public void onFailure (Exception e ) {
111
- log .error (
112
- "Parallel vector upload failed for blob {} with size {}" ,
113
- blobName + VECTOR_BLOB_FILE_EXTENSION ,
114
- vectorBlobLength ,
115
- e
116
- );
117
- exception .set (e );
118
- }
119
- }, latch )
120
- );
95
+ asyncBlobContainer .asyncBlobUpload (writeContext , new LatchedActionListener <>(new ActionListener <>() {
96
+ @ Override
97
+ public void onResponse (Void unused ) {
98
+ log .debug (
99
+ "Parallel vector upload succeeded for blob {} with size {}" ,
100
+ blobName + VECTOR_BLOB_FILE_EXTENSION ,
101
+ vectorBlobLength
102
+ );
103
+ }
104
+
105
+ @ Override
106
+ public void onFailure (Exception e ) {
107
+ log .error (
108
+ "Parallel vector upload failed for blob {} with size {}" ,
109
+ blobName + VECTOR_BLOB_FILE_EXTENSION ,
110
+ vectorBlobLength ,
111
+ e
112
+ );
113
+ exception .set (e );
114
+ }
115
+ }, latch ));
121
116
122
117
// Then upload doc id blob before waiting on vector uploads
123
118
// TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance.
@@ -130,9 +125,14 @@ public void onFailure(Exception e) {
130
125
} else {
131
126
log .debug ("Repository {} Does Not Support Parallel Blob Upload" , repository );
132
127
// Write Vectors
133
- InputStream vectorStream = new BufferedInputStream (new VectorValuesInputStream (knnVectorValuesSupplier .get (), vectorDataType ));
134
- log .debug ("Writing {} bytes for {} docs to {}" , vectorBlobLength , totalLiveDocs , blobName + VECTOR_BLOB_FILE_EXTENSION );
135
- blobContainer .writeBlob (blobName + VECTOR_BLOB_FILE_EXTENSION , vectorStream , vectorBlobLength , true );
128
+ try (
129
+ InputStream vectorStream = new BufferedInputStream (
130
+ new VectorValuesInputStream (knnVectorValuesSupplier .get (), vectorDataType )
131
+ )
132
+ ) {
133
+ log .debug ("Writing {} bytes for {} docs to {}" , vectorBlobLength , totalLiveDocs , blobName + VECTOR_BLOB_FILE_EXTENSION );
134
+ blobContainer .writeBlob (blobName + VECTOR_BLOB_FILE_EXTENSION , vectorStream , vectorBlobLength , true );
135
+ }
136
136
// Then write doc ids
137
137
writeDocIds (knnVectorValuesSupplier .get (), vectorBlobLength , totalLiveDocs , blobName , blobContainer );
138
138
}
@@ -154,14 +154,15 @@ private void writeDocIds(
154
154
String blobName ,
155
155
BlobContainer blobContainer
156
156
) throws IOException {
157
- InputStream docStream = new BufferedInputStream (new DocIdInputStream (knnVectorValues ));
158
- log .debug (
159
- "Writing {} bytes for {} docs ids to {}" ,
160
- vectorBlobLength ,
161
- totalLiveDocs * Integer .BYTES ,
162
- blobName + DOC_ID_FILE_EXTENSION
163
- );
164
- blobContainer .writeBlob (blobName + DOC_ID_FILE_EXTENSION , docStream , totalLiveDocs * Integer .BYTES , true );
157
+ try (InputStream docStream = new BufferedInputStream (new DocIdInputStream (knnVectorValues ))) {
158
+ log .debug (
159
+ "Writing {} bytes for {} docs ids to {}" ,
160
+ vectorBlobLength ,
161
+ totalLiveDocs * Integer .BYTES ,
162
+ blobName + DOC_ID_FILE_EXTENSION
163
+ );
164
+ blobContainer .writeBlob (blobName + DOC_ID_FILE_EXTENSION , docStream , totalLiveDocs * Integer .BYTES , true );
165
+ }
165
166
}
166
167
167
168
/**
@@ -215,6 +216,65 @@ private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOExceptio
215
216
});
216
217
}
217
218
219
    /**
     * Creates a {@link WriteContext} meant to be used by {@link AsyncMultiStreamBlobContainer#asyncBlobUpload}.
     * If the repository supports remote-side integrity checking, an expected CRC32 checksum is computed up front
     * (via {@link #getExpectedChecksum}) by streaming through the vector values once; otherwise no checksum is
     * attached and the integrity check is disabled on the context.
     *
     * @param blobName base blob name; {@code VECTOR_BLOB_FILE_EXTENSION} is appended to form the uploaded file name
     * @param vectorBlobLength total size in bytes of the vector blob to be uploaded
     * @param knnVectorValuesSupplier supplies a fresh {@link KNNVectorValues} iterator per use; invoked here only
     *                                when a checksum must be computed (a full extra pass over the vectors)
     * @param vectorDataType data type of the vectors, used to interpret the vector byte stream
     * @param supportsIntegrityCheck whether the target {@link AsyncMultiStreamBlobContainer} supports remote data
     *                               integrity verification
     * @return a fully configured {@link WriteContext} for the parallel vector blob upload
     * @throws IOException if reading the vector values while computing the expected checksum fails
     */
    private WriteContext createWriteContext(
        String blobName,
        long vectorBlobLength,
        Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
        VectorDataType vectorDataType,
        boolean supportsIntegrityCheck
    ) throws IOException {
        return new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
            .streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
            .fileSize(vectorBlobLength)
            .failIfAlreadyExists(true)
            .writePriority(WritePriority.NORMAL)
            .doRemoteDataIntegrityCheck(supportsIntegrityCheck)
            .uploadFinalizer((bool) -> {})
            // Checksum is only worth the extra pass over the vectors when the repository can actually verify it
            .expectedChecksum(supportsIntegrityCheck ? getExpectedChecksum(knnVectorValuesSupplier.get(), vectorDataType) : null)
            .build();
    }
246
+
247
+ /**
248
+ * Calculates a checksum on the given {@link KNNVectorValues}, representing all the vector data for the index build operation.
249
+ * This is done by creating a {@link VectorValuesInputStream} which is wrapped by a {@link CheckedInputStream} and then reading all the data through the stream to calculate the checksum.
250
+ * Note: This does add some overhead to the vector blob upload, as we are reading through the KNNVectorValues an additional time. If instead of taking an expected checksum up front
251
+ * the WriteContext accepted an expectedChecksumSupplier, we could calculate the checksum as the stream is being uploaded and use that same value to compare, however this is pending
252
+ * a change in OpenSearch core.
253
+ *
254
+ * @param knnVectorValues
255
+ * @param vectorDataType
256
+ * @return
257
+ * @throws IOException
258
+ */
259
+ @ VisibleForTesting
260
+ long getExpectedChecksum (KNNVectorValues <?> knnVectorValues , VectorDataType vectorDataType ) throws IOException {
261
+ try (
262
+ CheckedInputStream checkedStream = new CheckedInputStream (
263
+ new VectorValuesInputStream (knnVectorValues , vectorDataType ),
264
+ new CRC32 ()
265
+ )
266
+ ) {
267
+ // VectorValuesInputStream#read only reads 1 vector max at a time, so no point making this buffer any larger than that
268
+ initializeVectorValues (knnVectorValues );
269
+ int bufferSize = knnVectorValues .bytesPerVector ();
270
+ final byte [] buffer = new byte [bufferSize ];
271
+ // Checksum is computed by reading through the CheckedInputStream
272
+ while (checkedStream .read (buffer , 0 , bufferSize ) != -1 ) {
273
+ }
274
+ return checkedStream .getChecksum ().getValue ();
275
+ }
276
+ }
277
+
218
278
@ Override
219
279
public void readFromRepository (String path , IndexOutputWithBuffer indexOutputWithBuffer ) throws IOException {
220
280
if (path == null || path .isEmpty ()) {
0 commit comments