Commit 4a71a7a

Author: Jay Deng
Commit message: WIP
Parent: 7511f21

3 files changed (+94 -7 lines)

src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java (+36 -6)

@@ -5,8 +5,10 @@
 
 package org.opensearch.knn.index.codec.nativeindex.remote;
 
+import com.google.common.annotations.VisibleForTesting;
 import lombok.extern.log4j.Log4j2;
 import org.apache.commons.lang.NotImplementedException;
+import org.apache.lucene.store.IndexOutput;
 import org.opensearch.action.LatchedActionListener;
 import org.opensearch.common.CheckedTriFunction;
 import org.opensearch.common.StopWatch;
@@ -25,6 +27,7 @@
 import org.opensearch.knn.index.VectorDataType;
 import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
 import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
+import org.opensearch.knn.index.engine.KNNEngine;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
 import org.opensearch.repositories.RepositoriesService;
 import org.opensearch.repositories.Repository;
@@ -143,7 +146,8 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
             log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
 
             stopWatch = new StopWatch().start();
-            readFromRepository();
+            // TODO: This blob will be retrieved from the remote vector build service status response
+            readFromRepository(blobName + KNNEngine.FAISS.getExtension(), indexInfo.getIndexOutputWithBuffer().getIndexOutput());
             time_in_millis = stopWatch.stop().totalTime().millis();
             log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
         } catch (Exception e) {
@@ -171,6 +175,14 @@ private BlobStoreRepository getRepository() throws RepositoryMissingException {
         return (BlobStoreRepository) repository;
     }
 
+    /**
+     * @return The blob container to read from and write to, determined from the repository base path and the index settings. This container is where all blobs will be written to.
+     */
+    private BlobContainer getBlobContainer() {
+        BlobPath path = getRepository().basePath().add(indexSettings.getUUID() + VECTORS_PATH);
+        return getRepository().blobStore().blobContainer(path);
+    }
+
     /**
      * This method is responsible for writing both the vector blobs and doc ids provided by {@param knnVectorValuesSupplier} to the vector repository configured by {@link KNN_REMOTE_VECTOR_REPO_SETTING}.
      * If the repository implements {@link AsyncMultiStreamBlobContainer}, then parallel uploads will be used. Parallel uploads are backed by a {@link WriteContext}, for which we have a custom
@@ -192,9 +204,7 @@ private void writeToRepository(
         VectorDataType vectorDataType,
         Supplier<KNNVectorValues<?>> knnVectorValuesSupplier
     ) throws IOException, InterruptedException {
-        // Get the blob container based on blobName and the repo base path. This is where the blobs will be written to.
-        BlobPath path = getRepository().basePath().add(indexSettings.getUUID() + VECTORS_PATH);
-        BlobContainer blobContainer = getRepository().blobStore().blobContainer(path);
+        BlobContainer blobContainer = getBlobContainer();
 
         KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
         initializeVectorValues(knnVectorValues);
@@ -343,7 +353,27 @@ private void awaitVectorBuild() {
     /**
      * Read the constructed vector file from the remote repository and write it to the IndexOutput
      */
-    private void readFromRepository() {
-        throw new NotImplementedException();
+    @VisibleForTesting
+    void readFromRepository(String blobName, IndexOutput indexOutput) throws IOException {
+        BlobContainer blobContainer = getBlobContainer();
+        // TODO: We are using the sequential download API as multi-part parallel download is difficult for us to implement today and
+        // requires some changes in core. For more details, see: https://github.com/opensearch-project/k-NN/issues/2464
+        InputStream graphStream = blobContainer.readBlob(blobName);
+
+        // Allocate a 64KB buffer, the same size used for CPU builds; see IndexOutputWithBuffer
+        int CHUNK_SIZE = 64 * 1024;
+        byte[] buffer = new byte[CHUNK_SIZE];
+
+        int bytesRead = 0;
+        // InputStream#read returns -1 when there are no more bytes to be read
+        while (bytesRead != -1) {
+            // Try to read CHUNK_SIZE bytes into the buffer. The actual amount read may be less.
+            bytesRead = graphStream.read(buffer, 0, CHUNK_SIZE);
+            assert bytesRead <= CHUNK_SIZE;
+            // Write however many bytes were actually read to the IndexOutput, unless the stream is exhausted (-1)
+            if (bytesRead != -1) {
+                indexOutput.writeBytes(buffer, 0, bytesRead);
+            }
+        }
     }
 }
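
A note for readers skimming the diff: readFromRepository above is essentially a chunked stream copy. The sketch below (illustration only, not code from this commit) shows the same loop shape against plain JDK streams, with ByteArrayInputStream standing in for blobContainer.readBlob(blobName) and ByteArrayOutputStream standing in for the Lucene IndexOutput. The 64 KB chunk size mirrors the buffer above, and the payload length is deliberately not a multiple of the chunk size so the final read returns a partial chunk.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Random;

public class ChunkedCopyDemo {
    public static void main(String[] args) throws IOException {
        // Payload that is not a multiple of 64 KB, so the last read is a partial chunk.
        byte[] source = new byte[64 * 1024 * 3 + 123];
        new Random(42).nextBytes(source);

        InputStream in = new ByteArrayInputStream(source);       // stand-in for blobContainer.readBlob(blobName)
        ByteArrayOutputStream out = new ByteArrayOutputStream(); // stand-in for the Lucene IndexOutput

        final int CHUNK_SIZE = 64 * 1024; // same buffer size as IndexOutputWithBuffer
        byte[] buffer = new byte[CHUNK_SIZE];

        int bytesRead = 0;
        // InputStream#read returns -1 once the stream is exhausted.
        while (bytesRead != -1) {
            bytesRead = in.read(buffer, 0, CHUNK_SIZE); // may return fewer than CHUNK_SIZE bytes
            if (bytesRead != -1) {
                out.write(buffer, 0, bytesRead);        // write only the bytes actually read
            }
        }

        // The copy is byte-for-byte identical to the source.
        System.out.println(Arrays.equals(source, out.toByteArray())); // prints: true
    }
}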

src/main/java/org/opensearch/knn/index/store/IndexOutputWithBuffer.java (+4 -1)

@@ -5,13 +5,16 @@
 
 package org.opensearch.knn.index.store;
 
+import lombok.Getter;
 import org.apache.lucene.store.IndexOutput;
 
 import java.io.IOException;
 
 public class IndexOutputWithBuffer {
+    // The getter is exposed so that RemoteIndexBuildStrategy can write to the IndexOutput.
+    @Getter
     // Underlying `IndexOutput` obtained from Lucene's Directory.
-    private IndexOutput indexOutput;
+    private final IndexOutput indexOutput;
     // Write buffer. Native engine will copy bytes into this buffer.
     // Allocating 64KB here since it shows better performance in NMSLIB at this size. (We saw a slight improvement in FAISS compared to 4KB.)
    // NMSLIB writes an adjacent list size first, then followed by serializing the list. Since we usually have more adjacent lists, having
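
A note on the @Getter change above: Lombok generates the accessor at compile time, so the annotation is equivalent to hand-writing a getter on the existing field. Below is a minimal sketch of that expansion (the class name is changed to make clear this is an illustration, and the constructor is included only so the sketch compiles on its own); the generated getIndexOutput() is the call used by RemoteIndexBuildStrategy in the first file of this diff.

import org.apache.lucene.store.IndexOutput;

// Hand-written equivalent of what lombok.Getter generates for the annotated field (illustrative sketch).
public class IndexOutputWithBufferSketch {

    // Underlying IndexOutput, as in the real class.
    private final IndexOutput indexOutput;

    // Constructor added only to make the sketch self-contained.
    public IndexOutputWithBufferSketch(IndexOutput indexOutput) {
        this.indexOutput = indexOutput;
    }

    // Generated accessor; this is the getIndexOutput() used in buildAndWriteIndex() above.
    public IndexOutput getIndexOutput() {
        return indexOutput;
    }
}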

src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java (+54)

@@ -12,6 +12,8 @@
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.InfoStream;
 import org.apache.lucene.util.Version;
 import org.junit.Before;
@@ -42,11 +44,13 @@
 import org.opensearch.repositories.RepositoryMissingException;
 import org.opensearch.repositories.blobstore.BlobStoreRepository;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.file.Path;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
@@ -208,4 +212,54 @@ public void testRepositoryInteraction() throws IOException {
         verify(mockBlobStore).blobContainer(any());
         verify(mockRepository).basePath();
     }
+
+    /**
+     * Verify the buffered read method in {@link RemoteIndexBuildStrategy#readFromRepository} produces the correct result
+     */
+    public void testRepositoryRead() throws IOException {
+        // Create an InputStream with random values
+        int TEST_ARRAY_SIZE = 64 * 1024 * 10;
+        byte[] byteArray = new byte[TEST_ARRAY_SIZE];
+        Random random = new Random();
+        random.nextBytes(byteArray);
+        InputStream randomStream = new ByteArrayInputStream(byteArray);
+
+        // Create a test segment that we will read/write from
+        Directory directory;
+        directory = newFSDirectory(createTempDir());
+        String TEST_SEGMENT_NAME = "test-segment-name";
+        IndexOutput testIndexOutput = directory.createOutput(TEST_SEGMENT_NAME, IOContext.DEFAULT);
+
+        // Set up RemoteIndexBuildStrategy and write to IndexOutput
+        RepositoriesService repositoriesService = mock(RepositoriesService.class);
+        BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
+        BlobPath testBasePath = new BlobPath().add("testBasePath");
+        BlobStore mockBlobStore = mock(BlobStore.class);
+        AsyncMultiStreamBlobContainer mockBlobContainer = mock(AsyncMultiStreamBlobContainer.class);
+
+        when(repositoriesService.repository(any())).thenReturn(mockRepository);
+        when(mockRepository.basePath()).thenReturn(testBasePath);
+        when(mockRepository.blobStore()).thenReturn(mockBlobStore);
+        when(mockBlobStore.blobContainer(any())).thenReturn(mockBlobContainer);
+        when(mockBlobContainer.readBlob("test-blob")).thenReturn(randomStream);
+
+        RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(
+            () -> repositoriesService,
+            mock(NativeIndexBuildStrategy.class),
+            mock(IndexSettings.class)
+        );
+        // This should read from randomStream into testIndexOutput
+        objectUnderTest.readFromRepository("test-blob", testIndexOutput);
+        testIndexOutput.close();
+
+        // Now read back what was written to the IndexOutput
+        IndexInput testIndexInput = directory.openInput(TEST_SEGMENT_NAME, IOContext.DEFAULT);
+        byte[] resultByteArray = new byte[TEST_ARRAY_SIZE];
+        testIndexInput.readBytes(resultByteArray, 0, TEST_ARRAY_SIZE);
+        assertArrayEquals(byteArray, resultByteArray);
+
+        // Test cleanup
+        testIndexInput.close();
+        directory.close();
+    }
 }
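
A possible follow-up for testRepositoryRead() (not part of this commit): since mockBlobContainer is a Mockito mock, the test could additionally assert that readFromRepository requests exactly the blob name it was given, in the same style as the verify calls in testRepositoryInteraction above.

// Hypothetical extra assertion at the end of testRepositoryRead(), reusing mockBlobContainer from the test above
verify(mockBlobContainer).readBlob("test-blob");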
