Skip to content

Commit 6b6c3b1

Browse files
jed326Jay Deng
authored and
Jay Deng
committed
Add download + indexOuput#write implementation to RemoteIndexBuildStrategy
Signed-off-by: Jay Deng <jayd0104@gmail.com>
1 parent c7ac05c commit 6b6c3b1

File tree

4 files changed

+174
-8
lines changed

4 files changed

+174
-8
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66

77
## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD)
88
### Features
9+
* [Remote Vector Index Build] Implement data download and IndexOutput write functionality [#2554](https://github.com/opensearch-project/k-NN/pull/2554)
910
### Enhancements
1011
### Bug Fixes
1112
### Infrastructure

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java

+37-5
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,31 @@
55

66
package org.opensearch.knn.index.codec.nativeindex.remote;
77

8+
import com.google.common.annotations.VisibleForTesting;
89
import lombok.extern.log4j.Log4j2;
910
import org.apache.commons.lang.NotImplementedException;
1011
import org.apache.lucene.index.SegmentWriteState;
1112
import org.opensearch.common.StopWatch;
1213
import org.opensearch.common.annotation.ExperimentalApi;
14+
import org.opensearch.common.blobstore.BlobContainer;
15+
import org.opensearch.common.blobstore.BlobPath;
1316
import org.opensearch.index.IndexSettings;
1417
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
1518
import org.opensearch.knn.index.KNNSettings;
1619
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
1720
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
21+
import org.opensearch.knn.index.engine.KNNEngine;
22+
import org.opensearch.knn.index.store.IndexOutputWithBuffer;
1823
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1924
import org.opensearch.repositories.RepositoriesService;
2025
import org.opensearch.repositories.Repository;
2126
import org.opensearch.repositories.RepositoryMissingException;
2227
import org.opensearch.repositories.blobstore.BlobStoreRepository;
2328

2429
import java.io.IOException;
30+
import java.io.InputStream;
31+
import java.nio.file.Path;
32+
import java.nio.file.Paths;
2533
import java.util.function.Supplier;
2634

2735
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING;
@@ -37,6 +45,7 @@ public class RemoteIndexBuildStrategy implements NativeIndexBuildStrategy {
3745

3846
private final Supplier<RepositoriesService> repositoriesServiceSupplier;
3947
private final NativeIndexBuildStrategy fallbackStrategy;
48+
4049
private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec";
4150
private static final String DOC_ID_FILE_EXTENSION = ".knndid";
4251

@@ -93,12 +102,12 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
93102
log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
94103

95104
stopWatch = new StopWatch().start();
96-
awaitVectorBuild();
105+
String downloadPath = awaitVectorBuild();
97106
time_in_millis = stopWatch.stop().totalTime().millis();
98107
log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
99108

100109
stopWatch = new StopWatch().start();
101-
readFromRepository();
110+
readFromRepository(downloadPath, indexInfo.getIndexOutputWithBuffer());
102111
time_in_millis = stopWatch.stop().totalTime().millis();
103112
log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
104113
} catch (Exception e) {
@@ -155,15 +164,38 @@ private void submitVectorBuild() {
155164

156165
/**
157166
* Wait on remote vector build to complete
167+
* @return String The path from which we should perform download, delimited by "/"
158168
*/
159-
private void awaitVectorBuild() {
169+
private String awaitVectorBuild() throws NotImplementedException {
160170
// Placeholder: waiting on the remote build is not implemented yet, so this method unconditionally throws.
throw new NotImplementedException();
161171
}
162172

163173
/**
164174
* Read constructed vector file from remote repository and write to IndexOutput
165175
*/
166-
private void readFromRepository() {
167-
throw new NotImplementedException();
176+
@VisibleForTesting
177+
void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException {
178+
if (path == null || path.isEmpty()) {
179+
throw new IllegalArgumentException("download path is null or empty");
180+
}
181+
Path downloadPath = Paths.get(path);
182+
String fileName = downloadPath.getFileName().toString();
183+
if (!fileName.endsWith(KNNEngine.FAISS.getExtension())) {
184+
log.error("download path [{}] does not end with extension [{}}", downloadPath, KNNEngine.FAISS.getExtension());
185+
throw new IllegalArgumentException("download path has incorrect file extension");
186+
}
187+
188+
BlobPath blobContainerPath = new BlobPath();
189+
if (downloadPath.getParent() != null) {
190+
for (Path p : downloadPath.getParent()) {
191+
blobContainerPath = blobContainerPath.add(p.getFileName().toString());
192+
}
193+
}
194+
195+
BlobContainer blobContainer = getRepository().blobStore().blobContainer(blobContainerPath);
196+
// TODO: We are using the sequential download API as multi-part parallel download is difficult for us to implement today and
197+
// requires some changes in core. For more details, see: https://github.com/opensearch-project/k-NN/issues/2464
198+
InputStream graphStream = blobContainer.readBlob(fileName);
199+
indexOutputWithBuffer.writeFromStreamWithBuffer(graphStream);
168200
}
169201
}

src/main/java/org/opensearch/knn/index/store/IndexOutputWithBuffer.java

+45-1
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,25 @@
88
import org.apache.lucene.store.IndexOutput;
99

1010
import java.io.IOException;
11+
import java.io.InputStream;
1112

13+
/**
14+
* Wrapper around {@link IndexOutput} to perform writes in a buffered manner. This class is created per flush/merge, and may be used twice if
15+
* {@link org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy} needs to fall back to a different build strategy.
16+
*/
1217
public class IndexOutputWithBuffer {
1318
// Underlying `IndexOutput` obtained from Lucene's Directory.
1419
private IndexOutput indexOutput;
1520
// Write buffer. Native engine will copy bytes into this buffer.
1621
// Allocating 64KB here since it shows better performance in NMSLIB with this size. (We saw a slight improvement in FAISS compared to 4KB.)
1722
// NMSLIB writes an adjacent list size first, then followed by serializing the list. Since we usually have more adjacent lists, having
1823
// 64KB to accumulate as many bytes as possible reduces the number of calls to `writeBytes`.
19-
private byte[] buffer = new byte[64 * 1024];
24+
private static final int CHUNK_SIZE = 64 * 1024;
25+
private final byte[] buffer;
2026

2127
// Wraps the given IndexOutput and allocates a CHUNK_SIZE-byte write buffer for this instance.
public IndexOutputWithBuffer(IndexOutput indexOutput) {
2228
this.indexOutput = indexOutput;
29+
this.buffer = new byte[CHUNK_SIZE];
2330
}
2431

2532
// This method will be called in JNI layer which precisely knows
@@ -33,6 +40,43 @@ public void writeBytes(int length) {
3340
}
3441
}
3542

43+
/**
44+
* Writes to the {@link IndexOutput} by buffering bytes into the existing buffer in this class.
* Delegates to the two-argument overload with this instance's own buffer. The stream is not closed by this method.
45+
*
46+
* @param inputStream The stream from which we are reading bytes to write
47+
* @throws IOException if reading from the stream or writing to the {@link IndexOutput} fails
48+
* @see IndexOutputWithBuffer#writeFromStreamWithBuffer(InputStream, byte[])
49+
*/
50+
public void writeFromStreamWithBuffer(InputStream inputStream) throws IOException {
51+
writeFromStreamWithBuffer(inputStream, this.buffer);
52+
}
53+
54+
/**
55+
* Writes to the {@link IndexOutput} by buffering bytes with @param outputBuffer. This method allows
56+
* {@link org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy} to provide a separate, larger buffer as that buffer is for buffering
57+
* bytes downloaded from the repository, so it may be more performant to use a larger buffer.
58+
* We do not change the size of the existing buffer in case a fallback to the existing build strategy is needed.
59+
* TODO: Tune the size of the buffer used by RemoteIndexBuildStrategy based on benchmarking
60+
*
61+
* @param inputStream The stream from which we are reading bytes to write
62+
* @param outputBuffer The buffer used to buffer bytes
63+
* @throws IOException
64+
* @see IndexOutputWithBuffer#writeFromStreamWithBuffer(InputStream)
65+
*/
66+
private void writeFromStreamWithBuffer(InputStream inputStream, byte[] outputBuffer) throws IOException {
67+
int bytesRead = 0;
68+
// InputStream uses -1 indicates there are no more bytes to be read
69+
while (bytesRead != -1) {
70+
// Try to read CHUNK_SIZE into the buffer. The actual amount read may be less.
71+
bytesRead = inputStream.read(outputBuffer, 0, CHUNK_SIZE);
72+
assert bytesRead <= CHUNK_SIZE;
73+
// However many bytes we read, write it to the IndexOutput if != -1
74+
if (bytesRead != -1) {
75+
indexOutput.writeBytes(outputBuffer, 0, bytesRead);
76+
}
77+
}
78+
}
79+
3680
@Override
3781
public String toString() {
3882
return "{indexOutput=" + indexOutput + ", len(buffer)=" + buffer.length + "}";

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java

+91-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55

66
package org.opensearch.knn.index.codec.nativeindex.remote;
77

8+
import org.apache.lucene.store.Directory;
9+
import org.apache.lucene.store.IOContext;
10+
import org.apache.lucene.store.IndexInput;
11+
import org.apache.lucene.store.IndexOutput;
12+
import org.junit.Before;
813
import org.mockito.Mockito;
14+
import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer;
15+
import org.opensearch.common.blobstore.BlobPath;
16+
import org.opensearch.common.blobstore.BlobStore;
17+
import org.opensearch.common.settings.ClusterSettings;
18+
import org.opensearch.knn.KNNTestCase;
19+
import org.opensearch.knn.index.KNNSettings;
920
import org.opensearch.knn.index.VectorDataType;
1021
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
1122
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
@@ -16,17 +27,21 @@
1627
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
1728
import org.opensearch.repositories.RepositoriesService;
1829
import org.opensearch.repositories.RepositoryMissingException;
19-
import org.opensearch.test.OpenSearchTestCase;
30+
import org.opensearch.repositories.blobstore.BlobStoreRepository;
2031

32+
import java.io.ByteArrayInputStream;
2133
import java.io.IOException;
34+
import java.io.InputStream;
2235
import java.util.List;
2336
import java.util.Map;
37+
import java.util.Random;
2438

2539
import static org.mockito.ArgumentMatchers.any;
2640
import static org.mockito.Mockito.mock;
2741
import static org.mockito.Mockito.when;
42+
import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING;
2843

29-
public class RemoteIndexBuildStrategyTests extends OpenSearchTestCase {
44+
public class RemoteIndexBuildStrategyTests extends KNNTestCase {
3045

3146
static int fallbackCounter = 0;
3247

@@ -38,6 +53,16 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
3853
}
3954
}
4055

56+
@Before
57+
@Override
58+
public void setUp() throws Exception {
59+
super.setUp();
60+
ClusterSettings clusterSettings = mock(ClusterSettings.class);
61+
when(clusterSettings.get(KNN_REMOTE_VECTOR_REPO_SETTING)).thenReturn("test-repo-name");
62+
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
63+
KNNSettings.state().setClusterService(clusterService);
64+
}
65+
4166
public void testFallback() throws IOException {
4267
List<float[]> vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 });
4368
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
@@ -64,4 +89,68 @@ public void testFallback() throws IOException {
6489
objectUnderTest.buildAndWriteIndex(buildIndexParams);
6590
assertEquals(1, fallbackCounter);
6691
}
92+
93+
/**
94+
* Verify the buffered read method in {@link RemoteIndexBuildStrategy#readFromRepository} produces the correct result
95+
*/
96+
public void testRepositoryRead() throws IOException {
97+
String TEST_FILE_NAME = randomAlphaOfLength(8) + KNNEngine.FAISS.getExtension();
98+
99+
// Create an InputStream with random values
100+
int TEST_ARRAY_SIZE = 64 * 1024 * 10;
101+
byte[] byteArray = new byte[TEST_ARRAY_SIZE];
102+
Random random = new Random();
103+
random.nextBytes(byteArray);
104+
InputStream randomStream = new ByteArrayInputStream(byteArray);
105+
106+
// Create a test segment that we will read/write from
107+
Directory directory;
108+
directory = newFSDirectory(createTempDir());
109+
String TEST_SEGMENT_NAME = "test-segment-name";
110+
IndexOutput testIndexOutput = directory.createOutput(TEST_SEGMENT_NAME, IOContext.DEFAULT);
111+
IndexOutputWithBuffer testIndexOutputWithBuffer = new IndexOutputWithBuffer(testIndexOutput);
112+
113+
// Set up RemoteIndexBuildStrategy and write to IndexOutput
114+
RepositoriesService repositoriesService = mock(RepositoriesService.class);
115+
BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
116+
BlobPath testBasePath = new BlobPath().add("testBasePath");
117+
BlobStore mockBlobStore = mock(BlobStore.class);
118+
AsyncMultiStreamBlobContainer mockBlobContainer = mock(AsyncMultiStreamBlobContainer.class);
119+
120+
when(repositoriesService.repository(any())).thenReturn(mockRepository);
121+
when(mockRepository.basePath()).thenReturn(testBasePath);
122+
when(mockRepository.blobStore()).thenReturn(mockBlobStore);
123+
when(mockBlobStore.blobContainer(any())).thenReturn(mockBlobContainer);
124+
when(mockBlobContainer.readBlob(TEST_FILE_NAME)).thenReturn(randomStream);
125+
126+
RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(
127+
() -> repositoriesService,
128+
mock(NativeIndexBuildStrategy.class)
129+
);
130+
131+
// Verify file extension check
132+
assertThrows(IllegalArgumentException.class, () -> objectUnderTest.readFromRepository("test_file.txt", testIndexOutputWithBuffer));
133+
134+
// Now test with valid file extensions
135+
String testPath = randomFrom(
136+
List.of(
137+
"testBasePath/testDirectory/" + TEST_FILE_NAME, // Test with subdirectory
138+
"testBasePath/" + TEST_FILE_NAME, // Test with only base path
139+
TEST_FILE_NAME // test with no base path
140+
)
141+
);
142+
// This should read from randomStream into testIndexOutput
143+
objectUnderTest.readFromRepository(testPath, testIndexOutputWithBuffer);
144+
testIndexOutput.close();
145+
146+
// Now try to read from the IndexOutput
147+
IndexInput testIndexInput = directory.openInput(TEST_SEGMENT_NAME, IOContext.DEFAULT);
148+
byte[] resultByteArray = new byte[TEST_ARRAY_SIZE];
149+
testIndexInput.readBytes(resultByteArray, 0, TEST_ARRAY_SIZE);
150+
assertArrayEquals(byteArray, resultByteArray);
151+
152+
// Test Cleanup
153+
testIndexInput.close();
154+
directory.close();
155+
}
67156
}

0 commit comments

Comments
 (0)