Skip to content

Commit 5873add

Browse files
authored
Add vector data upload implementation to RemoteIndexBuildStrategy (#2550)
Signed-off-by: Jay Deng <jayd0104@gmail.com>
1 parent d5b2982 commit 5873add

16 files changed

+1344
-101
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66

77
## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD)
88
### Features
9+
* [Remote Vector Index Build] Introduce Remote Native Index Build feature flag, settings, and initial skeleton [#2525](https://github.com/opensearch-project/k-NN/pull/2525)
10+
* [Remote Vector Index Build] Implement vector data upload and vector data size threshold setting [#2550](https://github.com/opensearch-project/k-NN/pull/2550)
911
### Enhancements
1012
* Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509)
1113
### Bug Fixes

src/main/java/org/opensearch/knn/index/KNNSettings.java

+19-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.opensearch.OpenSearchParseException;
1313
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
1414
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse;
15-
import org.opensearch.transport.client.Client;
1615
import org.opensearch.cluster.metadata.IndexMetadata;
1716
import org.opensearch.cluster.service.ClusterService;
1817
import org.opensearch.common.Booleans;
@@ -29,6 +28,7 @@
2928
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
3029
import org.opensearch.monitor.jvm.JvmInfo;
3130
import org.opensearch.monitor.os.OsProbe;
31+
import org.opensearch.transport.client.Client;
3232

3333
import java.security.InvalidParameterException;
3434
import java.util.Arrays;
@@ -99,6 +99,7 @@ public class KNNSettings {
9999
public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled";
100100
public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled";
101101
public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo";
102+
public static final String KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD = "index.knn.remote_index_build.size_threshold";
102103

103104
/**
104105
* Default setting values
@@ -129,6 +130,8 @@ public class KNNSettings {
129130
// 10% of the JVM heap
130131
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60;
131132
public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false;
133+
// TODO: Tune this default value based on benchmarking
134+
public static final ByteSizeValue KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_DEFAULT_VALUE = new ByteSizeValue(50, ByteSizeUnit.MB);
132135

133136
/**
134137
* Settings Definition
@@ -397,6 +400,15 @@ public class KNNSettings {
397400
*/
398401
public static final Setting<String> KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope);
399402

403+
/**
404+
* Index level setting which indicates the size threshold above which remote vector builds will be enabled.
405+
*/
406+
public static final Setting<ByteSizeValue> KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING = Setting.byteSizeSetting(
407+
KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD,
408+
KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_DEFAULT_VALUE,
409+
Dynamic,
410+
IndexScope
411+
);
400412
/**
401413
* Dynamic settings
402414
*/
@@ -584,6 +596,10 @@ private Setting<?> getSetting(String key) {
584596
return KNN_REMOTE_VECTOR_REPO_SETTING;
585597
}
586598

599+
if (KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD.equals(key)) {
600+
return KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING;
601+
}
602+
587603
throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
588604
}
589605

@@ -611,7 +627,8 @@ public List<Setting<?>> getSettings() {
611627
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING,
612628
KNN_DERIVED_SOURCE_ENABLED_SETTING,
613629
KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING,
614-
KNN_REMOTE_VECTOR_REPO_SETTING
630+
KNN_REMOTE_VECTOR_REPO_SETTING,
631+
KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING
615632
);
616633
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
617634
.collect(Collectors.toList());

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java

+22-7
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@
77

88
import org.apache.lucene.index.FieldInfo;
99
import org.opensearch.index.IndexSettings;
10+
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
1011
import org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy;
1112
import org.opensearch.knn.index.engine.KNNEngine;
13+
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1214
import org.opensearch.repositories.RepositoriesService;
1315

16+
import java.io.IOException;
1417
import java.util.function.Supplier;
1518

1619
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
1720
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
21+
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
1822

1923
/**
2024
* Creates the {@link NativeIndexBuildStrategy}
@@ -34,11 +38,18 @@ public NativeIndexBuildStrategyFactory(Supplier<RepositoriesService> repositorie
3438
}
3539

3640
/**
37-
* Creates or returns the desired {@link NativeIndexBuildStrategy} implementation. Intended to be used by {@link NativeIndexWriter}
38-
* @param fieldInfo
39-
* @return
41+
* @param fieldInfo Field related attributes/info
42+
* @param totalLiveDocs Number of documents with the vector field. This values comes from {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#flush}
43+
* and {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#mergeOneField}
44+
* @param knnVectorValues An instance of {@link KNNVectorValues} which is used to evaluate the size threshold KNN_REMOTE_VECTOR_BUILD_THRESHOLD
45+
* @return The {@link NativeIndexBuildStrategy} to be used. Intended to be used by {@link NativeIndexWriter}
46+
* @throws IOException
4047
*/
41-
public NativeIndexBuildStrategy getBuildStrategy(final FieldInfo fieldInfo) {
48+
public NativeIndexBuildStrategy getBuildStrategy(
49+
final FieldInfo fieldInfo,
50+
final int totalLiveDocs,
51+
final KNNVectorValues<?> knnVectorValues
52+
) throws IOException {
4253
final KNNEngine knnEngine = extractKNNEngine(fieldInfo);
4354
boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID);
4455
boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine;
@@ -47,11 +58,15 @@ public NativeIndexBuildStrategy getBuildStrategy(final FieldInfo fieldInfo) {
4758
? MemOptimizedNativeIndexBuildStrategy.getInstance()
4859
: DefaultIndexBuildStrategy.getInstance();
4960

50-
if (repositoriesServiceSupplier != null
61+
initializeVectorValues(knnVectorValues);
62+
long vectorBlobLength = ((long) knnVectorValues.bytesPerVector()) * totalLiveDocs;
63+
64+
if (KNNFeatureFlags.isKNNRemoteVectorBuildEnabled()
65+
&& repositoriesServiceSupplier != null
5166
&& indexSettings != null
5267
&& knnEngine.supportsRemoteIndexBuild()
53-
&& RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings)) {
54-
return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy);
68+
&& RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings, vectorBlobLength)) {
69+
return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy, indexSettings);
5570
} else {
5671
return strategy;
5772
}

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public class NativeIndexWriter {
5858

5959
private final SegmentWriteState state;
6060
private final FieldInfo fieldInfo;
61-
private final NativeIndexBuildStrategy indexBuilder;
61+
private final NativeIndexBuildStrategyFactory indexBuilderFactory;
6262
@Nullable
6363
private final QuantizationState quantizationState;
6464

@@ -148,6 +148,11 @@ private void buildAndWriteIndex(final Supplier<KNNVectorValues<?>> knnVectorValu
148148
knnVectorValuesSupplier,
149149
totalLiveDocs
150150
);
151+
NativeIndexBuildStrategy indexBuilder = indexBuilderFactory.getBuildStrategy(
152+
fieldInfo,
153+
totalLiveDocs,
154+
knnVectorValuesSupplier.get()
155+
);
151156
indexBuilder.buildAndWriteIndex(nativeIndexParams);
152157
CodecUtil.writeFooter(output);
153158
}
@@ -316,6 +321,6 @@ private static NativeIndexWriter createWriter(
316321
@Nullable final QuantizationState quantizationState,
317322
NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory
318323
) {
319-
return new NativeIndexWriter(state, fieldInfo, nativeIndexBuildStrategyFactory.getBuildStrategy(fieldInfo), quantizationState);
324+
return new NativeIndexWriter(state, fieldInfo, nativeIndexBuildStrategyFactory, quantizationState);
320325
}
321326
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.codec.nativeindex.remote;
7+
8+
import lombok.AllArgsConstructor;
9+
import lombok.extern.log4j.Log4j2;
10+
import org.opensearch.action.LatchedActionListener;
11+
import org.opensearch.common.CheckedTriFunction;
12+
import org.opensearch.common.StreamContext;
13+
import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer;
14+
import org.opensearch.common.blobstore.BlobContainer;
15+
import org.opensearch.common.blobstore.BlobPath;
16+
import org.opensearch.common.blobstore.stream.write.WriteContext;
17+
import org.opensearch.common.blobstore.stream.write.WritePriority;
18+
import org.opensearch.common.io.InputStreamContainer;
19+
import org.opensearch.core.action.ActionListener;
20+
import org.opensearch.index.IndexSettings;
21+
import org.opensearch.knn.index.VectorDataType;
22+
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
23+
import org.opensearch.repositories.blobstore.BlobStoreRepository;
24+
25+
import java.io.BufferedInputStream;
26+
import java.io.IOException;
27+
import java.io.InputStream;
28+
import java.util.concurrent.CountDownLatch;
29+
import java.util.concurrent.atomic.AtomicReference;
30+
import java.util.function.Supplier;
31+
32+
import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.DOC_ID_FILE_EXTENSION;
33+
import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.VECTORS_PATH;
34+
import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.VECTOR_BLOB_FILE_EXTENSION;
35+
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
36+
37+
@Log4j2
38+
@AllArgsConstructor
39+
public class DefaultVectorRepositoryAccessor implements VectorRepositoryAccessor {
40+
private final BlobStoreRepository repository;
41+
private final IndexSettings indexSettings;
42+
43+
/**
44+
* If the repository implements {@link AsyncMultiStreamBlobContainer}, then parallel uploads will be used. Parallel uploads are backed by a {@link WriteContext}, for which we have a custom
45+
* {@link org.opensearch.common.blobstore.stream.write.StreamContextSupplier} implementation.
46+
*
47+
* @see DefaultVectorRepositoryAccessor#getStreamContext
48+
* @see DefaultVectorRepositoryAccessor#getTransferPartStreamSupplier
49+
*
50+
* @param blobName Base name of the blobs we are writing, excluding file extensions
51+
* @param totalLiveDocs Number of documents we are processing. This is used to compute the size of the blob we are writing
52+
* @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY)
53+
* @param knnVectorValuesSupplier Supplier for {@link KNNVectorValues}
54+
* @throws IOException
55+
* @throws InterruptedException
56+
*/
57+
@Override
58+
public void writeToRepository(
59+
String blobName,
60+
int totalLiveDocs,
61+
VectorDataType vectorDataType,
62+
Supplier<KNNVectorValues<?>> knnVectorValuesSupplier
63+
) throws IOException, InterruptedException {
64+
assert repository != null;
65+
// Get the blob container based on blobName and the repo base path. This is where the blobs will be written to.
66+
BlobPath path = repository.basePath().add(indexSettings.getUUID() + VECTORS_PATH);
67+
BlobContainer blobContainer = repository.blobStore().blobContainer(path);
68+
69+
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
70+
initializeVectorValues(knnVectorValues);
71+
long vectorBlobLength = (long) knnVectorValues.bytesPerVector() * totalLiveDocs;
72+
73+
if (blobContainer instanceof AsyncMultiStreamBlobContainer) {
74+
// First initiate vectors upload
75+
log.debug("Repository {} Supports Parallel Blob Upload", repository);
76+
// WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to
77+
// BuildIndexParams
78+
WriteContext writeContext = new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION)
79+
.streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType))
80+
.fileSize(vectorBlobLength)
81+
.failIfAlreadyExists(true)
82+
.writePriority(WritePriority.NORMAL)
83+
// TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as
84+
// there is no underlying file upon which we can create the checksum. We should be able to create a
85+
// checksum still by iterating through once, however this will be an expensive operation.
86+
.uploadFinalizer((bool) -> {})
87+
.doRemoteDataIntegrityCheck(false)
88+
.expectedChecksum(null)
89+
.build();
90+
91+
AtomicReference<Exception> exception = new AtomicReference<>();
92+
final CountDownLatch latch = new CountDownLatch(1);
93+
((AsyncMultiStreamBlobContainer) blobContainer).asyncBlobUpload(
94+
writeContext,
95+
new LatchedActionListener<>(new ActionListener<>() {
96+
@Override
97+
public void onResponse(Void unused) {
98+
log.debug(
99+
"Parallel vector upload succeeded for blob {} with size {}",
100+
blobName + VECTOR_BLOB_FILE_EXTENSION,
101+
vectorBlobLength
102+
);
103+
}
104+
105+
@Override
106+
public void onFailure(Exception e) {
107+
log.error(
108+
"Parallel vector upload failed for blob {} with size {}",
109+
blobName + VECTOR_BLOB_FILE_EXTENSION,
110+
vectorBlobLength,
111+
e
112+
);
113+
exception.set(e);
114+
}
115+
}, latch)
116+
);
117+
118+
// Then upload doc id blob before waiting on vector uploads
119+
// TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance.
120+
// Note: We do not use the parallel upload API here as the doc id blob will be much smaller than the vector blob
121+
writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer);
122+
latch.await();
123+
if (exception.get() != null) {
124+
throw new IOException(exception.get());
125+
}
126+
} else {
127+
log.debug("Repository {} Does Not Support Parallel Blob Upload", repository);
128+
// Write Vectors
129+
InputStream vectorStream = new BufferedInputStream(new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType));
130+
log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION);
131+
blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true);
132+
// Then write doc ids
133+
writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer);
134+
}
135+
}
136+
137+
/**
138+
* Helper method for uploading doc ids to repository, as it's re-used in both parallel and sequential upload cases
139+
* @param knnVectorValues
140+
* @param vectorBlobLength
141+
* @param totalLiveDocs
142+
* @param blobName
143+
* @param blobContainer
144+
* @throws IOException
145+
*/
146+
private void writeDocIds(
147+
KNNVectorValues<?> knnVectorValues,
148+
long vectorBlobLength,
149+
long totalLiveDocs,
150+
String blobName,
151+
BlobContainer blobContainer
152+
) throws IOException {
153+
InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues));
154+
log.debug(
155+
"Writing {} bytes for {} docs ids to {}",
156+
vectorBlobLength,
157+
totalLiveDocs * Integer.BYTES,
158+
blobName + DOC_ID_FILE_EXTENSION
159+
);
160+
blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true);
161+
}
162+
163+
/**
164+
* Returns a {@link org.opensearch.common.StreamContext}. Intended to be invoked as a {@link org.opensearch.common.blobstore.stream.write.StreamContextSupplier},
165+
* which takes the partSize determined by the repository implementation and calculates the number of parts as well as handles the last part of the stream.
166+
*
167+
* @see DefaultVectorRepositoryAccessor#getTransferPartStreamSupplier
168+
*
169+
* @param partSize Size of each InputStream to be uploaded in parallel. Provided by repository implementation
170+
* @param vectorBlobLength Total size of the vectors across all InputStreams
171+
* @param knnVectorValuesSupplier Supplier for {@link KNNVectorValues}
172+
* @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY)
173+
* @return a {@link org.opensearch.common.StreamContext} with a function that will create {@link InputStream}s of {@param partSize}
174+
*/
175+
private StreamContext getStreamContext(
176+
long partSize,
177+
long vectorBlobLength,
178+
Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
179+
VectorDataType vectorDataType
180+
) {
181+
long lastPartSize = (vectorBlobLength % partSize) != 0 ? vectorBlobLength % partSize : partSize;
182+
int numberOfParts = (int) ((vectorBlobLength % partSize) == 0 ? vectorBlobLength / partSize : (vectorBlobLength / partSize) + 1);
183+
return new StreamContext(
184+
getTransferPartStreamSupplier(knnVectorValuesSupplier, vectorDataType),
185+
partSize,
186+
lastPartSize,
187+
numberOfParts
188+
);
189+
}
190+
191+
/**
192+
* This method handles creating {@link VectorValuesInputStream}s based on the part number, the requested size of the stream part, and the position that the stream starts at within the underlying {@link KNNVectorValues}
193+
*
194+
* @param knnVectorValuesSupplier Supplier for {@link KNNVectorValues}
195+
* @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY)
196+
* @return a function with which the repository implementation will use to create {@link VectorValuesInputStream}s of specific sizes and start positions.
197+
*/
198+
private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOException> getTransferPartStreamSupplier(
199+
Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
200+
VectorDataType vectorDataType
201+
) {
202+
return ((partNo, size, position) -> {
203+
log.info("Creating InputStream for partNo: {}, size: {}, position: {}", partNo, size, position);
204+
VectorValuesInputStream vectorValuesInputStream = new VectorValuesInputStream(
205+
knnVectorValuesSupplier.get(),
206+
vectorDataType,
207+
position,
208+
size
209+
);
210+
return new InputStreamContainer(vectorValuesInputStream, size, position);
211+
});
212+
}
213+
}

0 commit comments

Comments
 (0)