diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSW.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSW.java new file mode 100644 index 0000000000..7f38a1394c --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSW.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch.faiss; + +import lombok.Getter; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** + * While it follows the same steps as the original FAISS deserialization, differences in how the JVM and C++ handle floating-point + * calculations can lead to slight variations in results. However, such cases are very rare, and in most instances, the results are + * identical to FAISS. Even when there are ranking differences, they do not impact the precision or recall of the search. + * For more details, refer to the [FAISS HNSW implementation]( + * ...). + */ +@Getter +public class FaissHNSW { + // Cumulative number of neighbors per each level. + private int[] cumNumberNeighborPerLevel; + // offsets[i]:offsets[i+1] gives all the neighbors for vector i + // Offset to be added to cumNumberNeighborPerLevel[level] to get the actual start offset of neighbor list. + private long[] offsets = null; + // Neighbor list storage. + private FaissSection neighbors; + // levels[i] = the maximum level of `i`th vector + 1. + // Ex: If 544th vector has three levels (e.g. 0-level, 1-level, 2-level), then levels[544] would be 3. + // This indicates that 544th vector exists at all levels of (0-level, 1-level, 2-level). + private FaissSection levels; + // Entry point in HNSW graph + private int entryPoint; + // Maximum level of HNSW graph + private int maxLevel = -1; + // Default efSearch parameter. This determines the navigation queue size. + // With a larger value, the algorithm will navigate more candidates. 
+ private int efSearch = 16; + // Total number of vectors stored in graph. + private long totalNumberOfVectors; + + /** + * Partially loads the FAISS HNSW graph from the provided index input stream. + * The graph is divided into multiple sections, and this method marks the starting offset of each section then skip to the next + * section instead of loading the entire graph into memory. During the search, bytes will be accessed via {@link IndexInput}. + * + * @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations. + * @param totalNumberOfVectors The total number of vectors stored in the graph. + * + * FYI FAISS Deserialization + * + * @throws IOException + */ + public void load(IndexInput input, long totalNumberOfVectors) throws IOException { + // Total number of vectors + this.totalNumberOfVectors = totalNumberOfVectors; + + // We don't use `double[] assignProbas` for search. It is for index construction. + long size = input.readLong(); + input.skipBytes(Double.BYTES * size); + + // Accumulate number of neighbor per each level. + size = input.readLong(); + cumNumberNeighborPerLevel = new int[Math.toIntExact(size)]; + if (size > 0) { + input.readInts(cumNumberNeighborPerLevel, 0, (int) size); + } + + // Maximum levels per each vector + levels = new FaissSection(input, Integer.BYTES); + + // Load `offsets` into memory. + size = input.readLong(); + offsets = new long[(int) size]; + input.readLongs(offsets, 0, offsets.length); + + // Mark neighbor list section. + neighbors = new FaissSection(input, Integer.BYTES); + + // HNSW graph parameters + entryPoint = input.readInt(); + + maxLevel = input.readInt(); + + // Gets efConstruction. We don't use this field. It's for index building. + input.readInt(); + + efSearch = input.readInt(); + + // dummy read a deprecated field. 
+ input.readInt(); + } +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java deleted file mode 100644 index 023b877c32..0000000000 --- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.memoryoptsearch.faiss; - -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.store.IndexInput; - -import java.io.IOException; - -/** - * A flat HNSW index that contains both an HNSW graph and flat vector storage. - * This is the ported version of `IndexHNSW` from FAISS. - * For more details, please refer to ... - */ -public class FaissHNSWFlatIndex extends FaissIndex { - public FaissHNSWFlatIndex(final String indexType) { - super(indexType); - } - - @Override - protected void doLoad(IndexInput input) throws IOException { - // TODO(KDY) : This will be covered in part-3 (FAISS HNSW). - } - - @Override - public VectorEncoding getVectorEncoding() { - // TODO(KDY) : This will be covered in part-3 (FAISS HNSW). - return null; - } - - @Override - public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException { - // TODO(KDY) : This will be covered in part-3 (FAISS HNSW). - return null; - } - - @Override - public ByteVectorValues getByteValues(IndexInput indexInput) throws IOException { - // TODO(KDY) : This will be covered in part-3 (FAISS HNSW). 
- return null; - } -} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWIndex.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWIndex.java new file mode 100644 index 0000000000..1122e2b8d9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWIndex.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch.faiss; + +import lombok.Getter; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** + * A flat HNSW index that contains both an HNSW graph and flat vector storage. + * This is the ported version of `IndexHNSW` from FAISS. + * For more details, please refer to ... + */ +public class FaissHNSWIndex extends FaissIndex { + // Flat float vector format - + // https://github.com/facebookresearch/faiss/blob/15491a1e4f5a513a8684e5b7262ef4ec22eda19d/faiss/IndexHNSW.h#L122 + public static final String IHNF = "IHNf"; + // Quantized flat format with HNSW - + // https://github.com/facebookresearch/faiss/blob/15491a1e4f5a513a8684e5b7262ef4ec22eda19d/faiss/IndexHNSW.h#L144C8-L144C19 + public static final String IHNS = "IHNs"; + + @Getter + private FaissHNSW hnsw = new FaissHNSW(); + private FaissIndex flatVectors; + private VectorEncoding vectorEncoding; + + public FaissHNSWIndex(final String indexType) { + super(indexType); + + // Set encoding + if (indexType.equals(IHNF)) { + vectorEncoding = VectorEncoding.FLOAT32; + } else if (indexType.equals(IHNS)) { + vectorEncoding = VectorEncoding.BYTE; + } else { + throw new IllegalStateException("Unsupported index type: " + indexType + " in " + FaissHNSWIndex.class.getSimpleName()); + } + } + + /** + * Loading HNSW graph and nested storage index. 
+ * For more details, please refer to + * ... + * @param input + * @throws IOException + */ + @Override + protected void doLoad(IndexInput input) throws IOException { + // Read common header + readCommonHeader(input); + + // Partial load HNSW graph + hnsw.load(input, getTotalNumberOfVectors()); + + // Partial load flat vector storage + flatVectors = FaissIndex.load(input); + } + + @Override + public VectorEncoding getVectorEncoding() { + return vectorEncoding; + } + + @Override + public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException { + return flatVectors.getFloatValues(indexInput); + } + + @Override + public ByteVectorValues getByteValues(IndexInput indexInput) throws IOException { + return flatVectors.getByteValues(indexInput); + } +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHnswGraph.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHnswGraph.java new file mode 100644 index 0000000000..682731a1f9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHnswGraph.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch.faiss; + +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; +import java.util.NoSuchElementException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * This graph implements Lucene's HNSW graph interface using the FAISS HNSW graph. Conceptually, both libraries represent the graph + * similarly, maintaining a list of neighbor IDs. This implementation acts as a bridge, enabling Lucene's HNSW graph searcher to perform + * vector searches on a FAISS index. + * + * NOTE: This is not thread safe. It should be created every time in {@link KnnVectorsReader}.search likewise + * OffHeapHnswGraph + * in Lucene. 
+ */ +public class FaissHnswGraph extends HnswGraph { + private final FaissHNSW faissHnsw; + private final IndexInput indexInput; + private final int numVectors; + private int[] neighborIdList; + private int numNeighbors; + private int nextNeighborIndex; + + public FaissHnswGraph(final FaissHNSW faissHNSW, final int numVectors, final IndexInput indexInput) { + this.faissHnsw = faissHNSW; + this.indexInput = indexInput; + this.numVectors = numVectors; + } + + /** + * Seek to the starting offset of neighbor ids at the given `level`, then load the neighbor ids into a buffer array. + * @param level The level of graph + * @param internalVectorId An internal vector id. + */ + @Override + public void seek(int level, int internalVectorId) { + // Get a relative starting offset of neighbor list at `level`. + long o = faissHnsw.getOffsets()[internalVectorId]; + + // `begin` and `end` represent a pair of starting offset and end offset. + // But, what `end` represents is the maximum offset a neighbor list at a level can have. + // Therefore, it is required to traverse a list until getting a terminal `-1`. 
+ // Ex: [1, 5, 20, 100, -1, -1, ..., -1] + final long begin = o + faissHnsw.getCumNumberNeighborPerLevel()[level]; + final long end = o + faissHnsw.getCumNumberNeighborPerLevel()[level + 1]; + loadNeighborIdList(begin, end); + } + + private void loadNeighborIdList(final long begin, final long end) { + // Make sure we have sufficient space for neighbor list + final long maxLength = end - begin; + if (neighborIdList == null || neighborIdList.length < maxLength) { + neighborIdList = new int[(int) (maxLength)]; + } + + // Seek to the first offset of neighbor list + try { + indexInput.seek(faissHnsw.getNeighbors().getBaseOffset() + Integer.BYTES * begin); + } catch (IOException e) { + throw new RuntimeException(e); + } + + // Fill the array with neighbor ids + int index = 0; + try { + for (long i = begin; i < end; i++) { + final int neighborId = indexInput.readInt(); + // The idea is that a vector does not always have a complete list of neighbor vectors. + // FAISS assigns a fixed size to the neighbor list and uses -1 to indicate missing entries. + // Therefore, we can safely stop once hit -1. + // For example, if the neighbor list size is 16 and a vector has only 8 neighbors, the list would appear as: + // [1, 4, 6, 8, 13, 17, 60, 88, -1, -1, ..., -1]. + if (neighborId >= 0) { + neighborIdList[index++] = neighborId; + } else { + break; + } + } + + // Set variables for navigation + numNeighbors = index; + nextNeighborIndex = 0; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public int size() { + return numVectors; + } + + @Override + public int nextNeighbor() { + if (nextNeighborIndex < numNeighbors) { + return neighborIdList[nextNeighborIndex++]; + } + + // Neighbor list has been exhausted. 
+ return NO_MORE_DOCS; + } + + @Override + public int numLevels() { + return faissHnsw.getMaxLevel(); + } + + @Override + public int entryNode() { + return faissHnsw.getEntryPoint(); + } + + @Override + public NodesIterator getNodesOnLevel(final int level) { + try { + // Prepare input stream to `level` section. + final FaissSection levelsSection = faissHnsw.getLevels(); + final IndexInput levelIndexInput = indexInput.clone(); + levelIndexInput.seek(levelsSection.getBaseOffset()); + + // Count the number of vectors at the level. + int numVectorsAtLevel = 0; + for (int i = 0; i < numVectors; ++i) { + final int maxLevel = levelIndexInput.readInt(); + // Note that maxLevel=3 indicates that a vector exists level-0 (bottom), level-1 and level-2. + if (maxLevel > level) { + ++numVectorsAtLevel; + } + } + + // Return iterator + levelIndexInput.seek(levelsSection.getBaseOffset()); + return new NodesIterator(numVectorsAtLevel) { + int vectorNo = -1; + int numVisitedVectors = 0; + + @Override + public boolean hasNext() { + return numVisitedVectors < size; + } + + @Override + public int nextInt() { + while (true) { + try { + // Advance + ++vectorNo; + final int maxLevel = levelIndexInput.readInt(); + + // Check the level + if (maxLevel > level) { + ++numVisitedVectors; + return vectorNo; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public int consume(int[] ints) { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + final int copySize = Math.min(size - numVisitedVectors, ints.length); + for (int i = 0; i < copySize; ++i) { + ints[i] = nextInt(); + } + return copySize; + } + }; + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java index c5d43a384d..9d0ea13437 100644 --- 
a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java @@ -24,15 +24,12 @@ * However, these IDs only cover the sparse 30% of Lucene documents, so an ID mapping is needed to convert the internal physical vector ID * into the corresponding Lucene document ID. * If the mapping is an identity mapping, where each `i` is mapped to itself, we omit storing it to save memory. - *

- * FYI : - * IndexIDMap.h */ public class FaissIdMapIndex extends FaissIndex { public static final String IXMP = "IxMp"; @Getter - private FaissHNSWFlatIndex nestedIndex; + private FaissHNSWIndex nestedIndex; private long[] vectorIdToDocIdMapping; public FaissIdMapIndex() { @@ -41,7 +38,8 @@ public FaissIdMapIndex() { /** * Partially load id mapping and its nested index to which vector searching will be delegated. - * + * Faiss deserialization code : + * IndexIDMap.h * @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations. * @throws IOException */ @@ -50,11 +48,11 @@ protected void doLoad(IndexInput input) throws IOException { readCommonHeader(input); final FaissIndex nestedIndex = FaissIndex.load(input); - if (nestedIndex instanceof FaissHNSWFlatIndex) { - this.nestedIndex = (FaissHNSWFlatIndex) nestedIndex; + if (nestedIndex instanceof FaissHNSWIndex) { + this.nestedIndex = (FaissHNSWIndex) nestedIndex; } else { throw new IllegalStateException( - "Invalid nested index. Expected " + FaissHNSWFlatIndex.class.getSimpleName() + " , but got " + nestedIndex.getIndexType() + "Invalid nested index. Expected " + FaissHNSWIndex.class.getSimpleName() + " , but got " + nestedIndex.getIndexType() ); } diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissSection.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissSection.java new file mode 100644 index 0000000000..19c71754b4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissSection.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch.faiss; + +import lombok.Getter; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** + * This section maps to a section in FAISS index with a starting offset and the section size. 
+ * A FAISS index file consists of multiple logical sections, each beginning with four bytes indicating an index type. A section may contain + * a nested section or vector storage, forming a tree structure with a top-level index as the starting point. + * + * Ex: FAISS index file + * +------------+ -> 0 + * + + + * + IxMp + -> FaissSection(offset=0, section_size=120) + * +------------+ -> 120 + * + + + * + IHNf + -> FaissSection(offset=120, section_size=380) + * +------------+ -> 500 + * + + + * + IxF2 + -> FaissSection(offset=500, section_size=700) + * +------------+ -> 1200 + * + */ +public class FaissSection { + @Getter + private long baseOffset; + @Getter + private long sectionSize; + + /** + * Mark the starting offset and the size of section then skip to the next section. + * + * @param input Input read stream. + * @param singleElementSize Size of atomic element. In file, it only stores the number of elements and the size of element will be + * used to calculate the actual size of section. Ex: size=100, element=int, then the actual section size=400. + * @throws IOException + */ + public FaissSection(IndexInput input, int singleElementSize) throws IOException { + this.sectionSize = input.readLong() * singleElementSize; + this.baseOffset = input.getFilePointer(); + // Skip the whole section and jump to the next section in the file. 
+ try { + input.seek(baseOffset + sectionSize); + } catch (IOException e) { + throw new IOException("Failed to partial load where baseOffset=" + baseOffset + ", sectionSize=" + sectionSize, e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java index d2c5b1fd77..45b9fe498e 100644 --- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java @@ -10,7 +10,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.function.Supplier; +import java.util.function.Function; /** * This table maintains a mapping between FAISS index types and their corresponding index implementations. @@ -20,12 +20,14 @@ */ @UtilityClass public class IndexTypeToFaissIndexMapping { - private static final Map> INDEX_TYPE_TO_FAISS_INDEX; + private static final Map> INDEX_TYPE_TO_FAISS_INDEX; static { - final Map> mapping = new HashMap<>(); + final Map> mapping = new HashMap<>(); - mapping.put(FaissIdMapIndex.IXMP, FaissIdMapIndex::new); + mapping.put(FaissIdMapIndex.IXMP, (indexType) -> new FaissIdMapIndex()); + mapping.put(FaissHNSWIndex.IHNF, FaissHNSWIndex::new); + mapping.put(FaissHNSWIndex.IHNS, FaissHNSWIndex::new); INDEX_TYPE_TO_FAISS_INDEX = Collections.unmodifiableMap(mapping); } @@ -37,9 +39,9 @@ public class IndexTypeToFaissIndexMapping { * @return Actual implementation that is corresponding to the given index type. 
*/ public FaissIndex getFaissIndex(final String indexType) { - final Supplier faissIndexSupplier = INDEX_TYPE_TO_FAISS_INDEX.get(indexType); + final Function faissIndexSupplier = INDEX_TYPE_TO_FAISS_INDEX.get(indexType); if (faissIndexSupplier != null) { - return faissIndexSupplier.get(); + return faissIndexSupplier.apply(indexType); } throw new UnsupportedFaissIndexException("Index type [" + indexType + "] is not supported."); } diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissHNSWTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissHNSWTests.java new file mode 100644 index 0000000000..139227e494 --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissHNSWTests.java @@ -0,0 +1,176 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import lombok.SneakyThrows; +import org.apache.lucene.store.IndexInput; +import org.opensearch.common.lucene.store.ByteArrayIndexInput; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSW; + +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; + +public class FaissHNSWTests extends KNNTestCase { + @SneakyThrows + public void testLoadGraphWithSingleVector() { + final IndexInput indexInput = loadHnswBinary("data/memoryoptsearch/faiss_hnsw_one_vector.bin"); + final FaissHNSW faissHNSW = new FaissHNSW(); + faissHNSW.load(indexInput, 1); + doTest(faissHNSW, new int[] { 0, 32, 48, 64, 80, 96, 112, 128, 144 }, new long[] { 0, 32 }, 160, 128, 0, 0, 100); + } + + @SneakyThrows + public void testLoadGraphWithNVectors() { + final IndexInput indexInput = loadHnswBinary("data/memoryoptsearch/faiss_hnsw_100_vectors.bin"); + final FaissHNSW faissHNSW = new FaissHNSW(); + faissHNSW.load(indexInput, 100); + final int[] cumulativeNumNeighbors = new int[] { 0, 32, 48, 64, 80, 96, 112, 128, 144 }; + doTest(faissHNSW, cumulativeNumNeighbors, 
ANSWER_OFFSETS, 1348, 13184, 12, 1, 100); + } + + @SneakyThrows + public static IndexInput loadHnswBinary(final String relativePath) { + final URL hnswWithOneVector = FaissHNSWTests.class.getClassLoader().getResource(relativePath); + final byte[] bytes = Files.readAllBytes(Path.of(hnswWithOneVector.toURI())); + final IndexInput indexInput = new ByteArrayIndexInput("FaissHNSWTests", bytes); + return indexInput; + } + + private void doTest( + final FaissHNSW faissHNSW, + final int[] cumulativeNumNeighbors, + final long[] offsets, + final long neighborsBaseOffset, + final long neighborsSectionSize, + final int entryPoint, + final int maxLevel, + final int efSearch + ) { + // Cumulative number of neighbor per level + assertArrayEquals(cumulativeNumNeighbors, faissHNSW.getCumNumberNeighborPerLevel()); + + // offsets + assertArrayEquals(offsets, faissHNSW.getOffsets()); + + // neighbors + assertEquals(neighborsBaseOffset, faissHNSW.getNeighbors().getBaseOffset()); + assertEquals(neighborsSectionSize, faissHNSW.getNeighbors().getSectionSize()); + + // entry point + assertEquals(entryPoint, faissHNSW.getEntryPoint()); + + // max level + assertEquals(maxLevel, faissHNSW.getMaxLevel()); + + // efSearch + assertEquals(efSearch, faissHNSW.getEfSearch()); + } + + private static final long[] ANSWER_OFFSETS = new long[] { + 0, + 32, + 64, + 96, + 128, + 160, + 192, + 224, + 256, + 288, + 320, + 352, + 400, + 448, + 480, + 512, + 544, + 576, + 608, + 640, + 672, + 704, + 736, + 784, + 816, + 848, + 880, + 912, + 944, + 976, + 1008, + 1040, + 1072, + 1104, + 1136, + 1168, + 1200, + 1248, + 1280, + 1312, + 1344, + 1376, + 1408, + 1440, + 1472, + 1504, + 1536, + 1568, + 1600, + 1632, + 1664, + 1696, + 1728, + 1776, + 1808, + 1840, + 1872, + 1904, + 1936, + 1968, + 2000, + 2032, + 2064, + 2096, + 2128, + 2160, + 2192, + 2224, + 2256, + 2288, + 2320, + 2352, + 2384, + 2416, + 2464, + 2496, + 2528, + 2560, + 2592, + 2624, + 2656, + 2688, + 2720, + 2752, + 2784, + 2816, + 2848, + 2880, + 
2912, + 2944, + 2976, + 3008, + 3040, + 3072, + 3104, + 3136, + 3168, + 3200, + 3232, + 3264, + 3296 }; +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissHnswGraphTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissHnswGraphTests.java new file mode 100644 index 0000000000..e69598fa6f --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissHnswGraphTests.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import lombok.SneakyThrows; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSW; +import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSWIndex; +import org.opensearch.knn.memoryoptsearch.faiss.FaissHnswGraph; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; + +import static org.mockito.Mockito.when; +import static org.opensearch.knn.memoryoptsearch.FaissHNSWTests.loadHnswBinary; + +public class FaissHnswGraphTests extends KNNTestCase { + private static final int NUM_VECTORS = 100; + + @SneakyThrows + public void testTraverseHnswGraph() { + final FaissHnswGraph graph = prepareFaissHnswGraph(); + + // Validate graph + graph.seek(0, 0); + assertArrayEquals(FIRST_NEIGHBOR_LIST_AT_0_LEVEL, getNeighborIdList(graph)); + + graph.seek(0, 99); + assertArrayEquals(NINETY_NINETH_NEIGHBOR_LIST_AT_0_LEVEL, getNeighborIdList(graph)); + + graph.seek(1, 0); + assertArrayEquals(FIRST_NEIGHBOR_LIST_AT_1_LEVEL, getNeighborIdList(graph)); + } + + @SneakyThrows + public void testNodesIterator() { + final FaissHnswGraph graph = prepareFaissHnswGraph(); + // Iterate all vectors at level-0 + HnswGraph.NodesIterator 
iterator = graph.getNodesOnLevel(0); + Set vectorIds = new HashSet<>(); + while (iterator.hasNext()) { + vectorIds.add(iterator.next()); + } + assertEquals(NUM_VECTORS, vectorIds.size()); + for (int i = 0; i < NUM_VECTORS; ++i) { + assertTrue(vectorIds.contains(i)); + } + + // Test bulk + int[] buffer = new int[37]; + iterator = graph.getNodesOnLevel(0); + + // Copied 37/100 + int copied = iterator.consume(buffer); + assertEquals(buffer.length, copied); + + // Copied 74/100 + copied = iterator.consume(buffer); + assertEquals(buffer.length, copied); + + // Copied 26 more, 100/100. + copied = iterator.consume(buffer); + assertEquals(26, copied); + + try { + iterator.consume(buffer); + fail(); + } catch (NoSuchElementException e) { + // exhausted + } + } + + @SneakyThrows + private static int[] getNeighborIdList(final FaissHnswGraph graph) { + final List neighborIds = new ArrayList<>(); + while (true) { + final int vectorId = graph.nextNeighbor(); + if (vectorId != DocIdSetIterator.NO_MORE_DOCS) { + neighborIds.add(vectorId); + } else { + break; + } + } + + return neighborIds.stream().mapToInt(i -> i).toArray(); + } + + @SneakyThrows + private static FaissHnswGraph prepareFaissHnswGraph() { + // Prepare parent index + final FaissHNSWIndex parentIndex = Mockito.mock(FaissHNSWIndex.class); + IndexInput indexInput = loadHnswBinary("data/memoryoptsearch/faiss_hnsw_100_vectors.bin"); + + // Prepare FaissHNSW + final int totalNumberOfVectors = 100; + final FaissHNSW faissHNSW = new FaissHNSW(); + faissHNSW.load(indexInput, totalNumberOfVectors); + when(parentIndex.getHnsw()).thenReturn(faissHNSW); + when(parentIndex.getTotalNumberOfVectors()).thenReturn(totalNumberOfVectors); + + // Create LuceneFaissHnswGraph + indexInput = loadHnswBinary("data/memoryoptsearch/faiss_hnsw_100_vectors.bin"); + final FaissHnswGraph graph = new FaissHnswGraph(faissHNSW, totalNumberOfVectors, indexInput); + return graph; + } + + private static final int[] FIRST_NEIGHBOR_LIST_AT_0_LEVEL = new 
int[] { 25, 10, 11, 16, 82 }; + private static final int[] NINETY_NINETH_NEIGHBOR_LIST_AT_0_LEVEL = new int[] { 79, 14, 51, 42, 87, 11, 34, 60, 77, 46, 37, 62 }; + private static final int[] FIRST_NEIGHBOR_LIST_AT_1_LEVEL = new int[] { 51, 31, 10, 33, 11, 23, 97, 16, 65, 32, 24, 98 }; +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissIdMapIndexTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissIdMapIndexTests.java index 38190ab143..29cbdfba42 100644 --- a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissIdMapIndexTests.java +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissIdMapIndexTests.java @@ -16,7 +16,7 @@ import org.opensearch.common.lucene.store.ByteArrayIndexInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSWFlatIndex; +import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSWIndex; import org.opensearch.knn.memoryoptsearch.faiss.FaissIdMapIndex; import org.opensearch.knn.memoryoptsearch.faiss.FaissIndex; @@ -173,7 +173,7 @@ private static FaissIdMapIndex triggerLoadAndGetIndex( // Mock static `load` to return a dummy mock try (MockedStatic mockStaticFaissIndex = mockStatic(FaissIndex.class)) { // Nested index - final FaissHNSWFlatIndex nestedIndex = mock(FaissHNSWFlatIndex.class); + final FaissHNSWIndex nestedIndex = mock(FaissHNSWIndex.class); mockStaticFaissIndex.when(() -> FaissIndex.load(any())).thenReturn(nestedIndex); // Byte vectors diff --git a/src/test/resources/data/memoryoptsearch/faiss_flat_float_50_vectors_128_dim.bin b/src/test/resources/data/memoryoptsearch/faiss_flat_float_50_vectors_128_dim.bin new file mode 100644 index 0000000000..cf92a4c9f0 Binary files /dev/null and b/src/test/resources/data/memoryoptsearch/faiss_flat_float_50_vectors_128_dim.bin differ diff --git a/src/test/resources/data/memoryoptsearch/faiss_hnsw_100_vectors.bin 
b/src/test/resources/data/memoryoptsearch/faiss_hnsw_100_vectors.bin new file mode 100644 index 0000000000..f09cc6de6f Binary files /dev/null and b/src/test/resources/data/memoryoptsearch/faiss_hnsw_100_vectors.bin differ diff --git a/src/test/resources/data/memoryoptsearch/faiss_hnsw_one_vector.bin b/src/test/resources/data/memoryoptsearch/faiss_hnsw_one_vector.bin new file mode 100644 index 0000000000..f5d3483cb9 Binary files /dev/null and b/src/test/resources/data/memoryoptsearch/faiss_hnsw_one_vector.bin differ