diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSW.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSW.java
new file mode 100644
index 0000000000..7f38a1394c
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSW.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.memoryoptsearch.faiss;
+
+import lombok.Getter;
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
+
+/**
+ * While it follows the same steps as the original FAISS deserialization, differences in how the JVM and C++ handle floating-point
+ * calculations can lead to slight variations in results. However, such cases are very rare, and in most instances, the results are
+ * identical to FAISS. Even when there are ranking differences, they do not impact the precision or recall of the search.
+ * For more details, refer to the [FAISS HNSW implementation](
+ * ...).
+ */
+@Getter
+public class FaissHNSW {
+    // Cumulative number of neighbors per each level.
+    private int[] cumNumberNeighborPerLevel;
+    // offsets[i]:offsets[i+1] gives all the neighbors for vector i
+    // Offset to be added to cumNumberNeighborPerLevel[level] to get the actual start offset of neighbor list.
+    private long[] offsets;
+    // Neighbor list storage.
+    private FaissSection neighbors;
+    // levels[i] = the maximum levels of `i`th vector + 1.
+    // Ex: If 544th vector has three levels (e.g. 0-level, 1-level, 2-level), then levels[544] would be 3.
+    // This indicates that 544th vector exists at all levels of (0-level, 1-level, 2-level).
+    private FaissSection levels;
+    // Entry point in HNSW graph
+    private int entryPoint;
+    // Maximum level of HNSW graph
+    private int maxLevel = -1;
+    // Default efSearch parameter. This determines the navigation queue size.
+    // With a larger value, the algorithm navigates more candidates.
+    private int efSearch = 16;
+    // Total number of vectors stored in graph.
+    private long totalNumberOfVectors;
+
+    /**
+     * Partially loads the FAISS HNSW graph from the provided index input stream.
+     * The graph is divided into multiple sections, and this method marks the starting offset of each section then skips to the next
+     * section instead of loading the entire graph into memory. During the search, bytes will be accessed via {@link IndexInput}.
+     *
+     * @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations.
+     * @param totalNumberOfVectors The total number of vectors stored in the graph.
+     *
+     * FYI FAISS Deserialization
+     *
+     * @throws IOException If reading from or skipping within the input fails.
+     */
+    public void load(IndexInput input, long totalNumberOfVectors) throws IOException {
+        // Total number of vectors
+        this.totalNumberOfVectors = totalNumberOfVectors;
+
+        // We don't use `double[] assignProbas` for search. It is for index construction.
+        long size = input.readLong();
+        input.skipBytes(Double.BYTES * size);
+
+        // Accumulate number of neighbor per each level.
+        size = input.readLong();
+        cumNumberNeighborPerLevel = new int[Math.toIntExact(size)];
+        if (size > 0) {
+            input.readInts(cumNumberNeighborPerLevel, 0, cumNumberNeighborPerLevel.length);
+        }
+
+        // Maximum levels per each vector
+        levels = new FaissSection(input, Integer.BYTES);
+
+        // Load `offsets` into memory. Fail fast (ArithmeticException) rather than silently truncating an oversized count.
+        size = input.readLong();
+        offsets = new long[Math.toIntExact(size)];
+        input.readLongs(offsets, 0, offsets.length);
+
+        // Mark neighbor list section.
+        neighbors = new FaissSection(input, Integer.BYTES);
+
+        // HNSW graph parameters
+        entryPoint = input.readInt();
+
+        maxLevel = input.readInt();
+
+        // Gets efConstruction. We don't use this field. It's for index building.
+        input.readInt();
+
+        efSearch = input.readInt();
+
+        // dummy read a deprecated field.
+        input.readInt();
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java
deleted file mode 100644
index 023b877c32..0000000000
--- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.knn.memoryoptsearch.faiss;
-
-import org.apache.lucene.index.ByteVectorValues;
-import org.apache.lucene.index.FloatVectorValues;
-import org.apache.lucene.index.VectorEncoding;
-import org.apache.lucene.store.IndexInput;
-
-import java.io.IOException;
-
-/**
- * A flat HNSW index that contains both an HNSW graph and flat vector storage.
- * This is the ported version of `IndexHNSW` from FAISS.
- * For more details, please refer to ...
- */
-public class FaissHNSWFlatIndex extends FaissIndex {
-    public FaissHNSWFlatIndex(final String indexType) {
-        super(indexType);
-    }
-
-    @Override
-    protected void doLoad(IndexInput input) throws IOException {
-        // TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
-    }
-
-    @Override
-    public VectorEncoding getVectorEncoding() {
-        // TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
-        return null;
-    }
-
-    @Override
-    public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException {
-        // TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
-        return null;
-    }
-
-    @Override
-    public ByteVectorValues getByteValues(IndexInput indexInput) throws IOException {
-        // TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
-        return null;
-    }
-}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWIndex.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWIndex.java
new file mode 100644
index 0000000000..1122e2b8d9
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWIndex.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.memoryoptsearch.faiss;
+
+import lombok.Getter;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
+
+/**
+ * A flat HNSW index that contains both an HNSW graph and flat vector storage.
+ * This is the ported version of `IndexHNSW` from FAISS.
+ * For more details, please refer to ...
+ */
+public class FaissHNSWIndex extends FaissIndex {
+    // Flat float vector format -
+    // https://github.com/facebookresearch/faiss/blob/15491a1e4f5a513a8684e5b7262ef4ec22eda19d/faiss/IndexHNSW.h#L122
+    public static final String IHNF = "IHNf";
+    // Quantized flat format with HNSW -
+    // https://github.com/facebookresearch/faiss/blob/15491a1e4f5a513a8684e5b7262ef4ec22eda19d/faiss/IndexHNSW.h#L144C8-L144C19
+    public static final String IHNS = "IHNs";
+
+    @Getter
+    private final FaissHNSW hnsw = new FaissHNSW();
+    private FaissIndex flatVectors;
+    private final VectorEncoding vectorEncoding;
+
+    public FaissHNSWIndex(final String indexType) {
+        super(indexType);
+
+        // Derive the encoding from the index type. Constant-first equals so that a null type reaches the explicit error below.
+        if (IHNF.equals(indexType)) {
+            vectorEncoding = VectorEncoding.FLOAT32;
+        } else if (IHNS.equals(indexType)) {
+            vectorEncoding = VectorEncoding.BYTE;
+        } else {
+            throw new IllegalStateException("Unsupported index type: " + indexType + " in " + FaissHNSWIndex.class.getSimpleName());
+        }
+    }
+
+    /**
+     * Loads the common header, then partially loads the HNSW graph followed by the nested flat vector storage index.
+     * For more details, please refer to
+     * ...
+     * @param input Input stream to read this index from; sections are consumed in the order header, graph, flat storage.
+     * @throws IOException If reading from the input fails.
+     */
+    @Override
+    protected void doLoad(IndexInput input) throws IOException {
+        // Read common header
+        readCommonHeader(input);
+
+        // Partial load HNSW graph
+        hnsw.load(input, getTotalNumberOfVectors());
+
+        // Partial load flat vector storage
+        flatVectors = FaissIndex.load(input);
+    }
+
+    @Override
+    public VectorEncoding getVectorEncoding() {
+        return vectorEncoding;
+    }
+
+    @Override
+    public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException {
+        return flatVectors.getFloatValues(indexInput);
+    }
+
+    @Override
+    public ByteVectorValues getByteValues(IndexInput indexInput) throws IOException {
+        return flatVectors.getByteValues(indexInput);
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHnswGraph.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHnswGraph.java
new file mode 100644
index 0000000000..682731a1f9
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHnswGraph.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.memoryoptsearch.faiss;
+
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.hnsw.HnswGraph;
+
+import java.io.IOException;
+import java.util.NoSuchElementException;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * This graph implements Lucene's HNSW graph interface using the FAISS HNSW graph. Conceptually, both libraries represent the graph
+ * similarly, maintaining a list of neighbor IDs. This implementation acts as a bridge, enabling Lucene's HNSW graph searcher to perform
+ * vector searches on a FAISS index.
+ *
+ * NOTE: This is not thread safe. It should be created every time in {@link KnnVectorsReader}.search likewise
+ * OffHeapHnswGraph
+ * in Lucene.
+ */
+public class FaissHnswGraph extends HnswGraph {
+    private final FaissHNSW faissHnsw;
+    private final IndexInput indexInput;
+    private final int numVectors;
+    private int[] neighborIdList;
+    private int numNeighbors;
+    private int nextNeighborIndex;
+
+    public FaissHnswGraph(final FaissHNSW faissHNSW, final int numVectors, final IndexInput indexInput) {
+        this.faissHnsw = faissHNSW;
+        this.indexInput = indexInput;
+        this.numVectors = numVectors;
+    }
+
+    /**
+     * Seek to the starting offset of neighbor ids at the given `level`, then load all of those ids into a buffer array.
+     * @param level The level of graph
+     * @param internalVectorId An internal vector id.
+     */
+    @Override
+    public void seek(int level, int internalVectorId) {
+        // Get a relative starting offset of neighbor list at `level`.
+        long o = faissHnsw.getOffsets()[internalVectorId];
+
+        // `begin` and `end` represent a pair of starting offset and end offset.
+        // But, what `end` represents is the maximum offset a neighbor list at a level can have.
+        // Therefore, it is required to traverse a list until getting a terminal `-1`.
+        // Ex: [1, 5, 20, 100, -1, -1, ..., -1]
+        final long begin = o + faissHnsw.getCumNumberNeighborPerLevel()[level];
+        final long end = o + faissHnsw.getCumNumberNeighborPerLevel()[level + 1];
+        loadNeighborIdList(begin, end);
+    }
+
+    private void loadNeighborIdList(final long begin, final long end) {
+        // Make sure we have sufficient space for neighbor list
+        final long maxLength = end - begin;
+        if (neighborIdList == null || neighborIdList.length < maxLength) {
+            neighborIdList = new int[Math.toIntExact(maxLength)];
+        }
+
+        // Seek to the first offset of neighbor list
+        try {
+            indexInput.seek(faissHnsw.getNeighbors().getBaseOffset() + Integer.BYTES * begin);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+
+        // Fill the array with neighbor ids
+        int index = 0;
+        try {
+            for (long i = begin; i < end; i++) {
+                final int neighborId = indexInput.readInt();
+                // The idea is that a vector does not always have a complete list of neighbor vectors.
+                // FAISS assigns a fixed size to the neighbor list and uses -1 to indicate missing entries.
+                // Therefore, we can safely stop once hit -1.
+                // For example, if the neighbor list size is 16 and a vector has only 8 neighbors, the list would appear as:
+                // [1, 4, 6, 8, 13, 17, 60, 88, -1, -1, ..., -1].
+                if (neighborId >= 0) {
+                    neighborIdList[index++] = neighborId;
+                } else {
+                    break;
+                }
+            }
+
+            // Set variables for navigation
+            numNeighbors = index;
+            nextNeighborIndex = 0;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public int size() {
+        return numVectors;
+    }
+
+    @Override
+    public int nextNeighbor() {
+        if (nextNeighborIndex < numNeighbors) {
+            return neighborIdList[nextNeighborIndex++];
+        }
+
+        // Neighbor list has been exhausted.
+        return NO_MORE_DOCS;
+    }
+
+    @Override
+    public int numLevels() {
+        return faissHnsw.getMaxLevel(); // NOTE(review): confirm FAISS max_level is a level count, not the top level index.
+    }
+
+    @Override
+    public int entryNode() {
+        return faissHnsw.getEntryPoint();
+    }
+
+    @Override
+    public NodesIterator getNodesOnLevel(final int level) {
+        try {
+            // Prepare input stream to `level` section.
+            final FaissSection levelsSection = faissHnsw.getLevels();
+            final IndexInput levelIndexInput = indexInput.clone();
+            levelIndexInput.seek(levelsSection.getBaseOffset());
+
+            // Count the number of vectors at the level.
+            int numVectorsAtLevel = 0;
+            for (int i = 0; i < numVectors; ++i) {
+                final int maxLevel = levelIndexInput.readInt();
+                // Note that maxLevel=3 indicates that a vector exists level-0 (bottom), level-1 and level-2.
+                if (maxLevel > level) {
+                    ++numVectorsAtLevel;
+                }
+            }
+
+            // Return iterator
+            levelIndexInput.seek(levelsSection.getBaseOffset());
+            return new NodesIterator(numVectorsAtLevel) {
+                int vectorNo = -1;
+                int numVisitedVectors = 0;
+
+                @Override
+                public boolean hasNext() {
+                    return numVisitedVectors < size;
+                }
+
+                @Override
+                public int nextInt() {
+                    while (hasNext()) {
+                        try {
+                            // Advance
+                            ++vectorNo;
+                            final int maxLevel = levelIndexInput.readInt();
+                            // Check the level
+                            if (maxLevel > level) {
+                                ++numVisitedVectors;
+                                return vectorNo;
+                            }
+                        } catch (IOException e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                    throw new NoSuchElementException(); // Exhausted: never read past the end of the levels section.
+                }
+
+                @Override
+                public int consume(int[] ints) {
+                    if (hasNext() == false) {
+                        throw new NoSuchElementException();
+                    }
+                    final int copySize = Math.min(size - numVisitedVectors, ints.length);
+                    for (int i = 0; i < copySize; ++i) {
+                        ints[i] = nextInt();
+                    }
+                    return copySize;
+                }
+            };
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java
index c5d43a384d..9d0ea13437 100644
---
a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissIdMapIndex.java @@ -24,15 +24,12 @@ * However, these IDs only cover the sparse 30% of Lucene documents, so an ID mapping is needed to convert the internal physical vector ID * into the corresponding Lucene document ID. * If the mapping is an identity mapping, where each `i` is mapped to itself, we omit storing it to save memory. - *
- * FYI :
- * IndexIDMap.h
*/
public class FaissIdMapIndex extends FaissIndex {
public static final String IXMP = "IxMp";
@Getter
- private FaissHNSWFlatIndex nestedIndex;
+ private FaissHNSWIndex nestedIndex;
private long[] vectorIdToDocIdMapping;
public FaissIdMapIndex() {
@@ -41,7 +38,8 @@ public FaissIdMapIndex() {
/**
* Partially load id mapping and its nested index to which vector searching will be delegated.
- *
+ * Faiss deserialization code :
+ * IndexIDMap.h
* @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations.
* @throws IOException
*/
@@ -50,11 +48,11 @@ protected void doLoad(IndexInput input) throws IOException {
readCommonHeader(input);
final FaissIndex nestedIndex = FaissIndex.load(input);
- if (nestedIndex instanceof FaissHNSWFlatIndex) {
- this.nestedIndex = (FaissHNSWFlatIndex) nestedIndex;
+ if (nestedIndex instanceof FaissHNSWIndex) {
+ this.nestedIndex = (FaissHNSWIndex) nestedIndex;
} else {
throw new IllegalStateException(
- "Invalid nested index. Expected " + FaissHNSWFlatIndex.class.getSimpleName() + " , but got " + nestedIndex.getIndexType()
+ "Invalid nested index. Expected " + FaissHNSWIndex.class.getSimpleName() + " , but got " + nestedIndex.getIndexType()
);
}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissSection.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissSection.java
new file mode 100644
index 0000000000..19c71754b4
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissSection.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.memoryoptsearch.faiss;
+
+import lombok.Getter;
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
+
+/**
+ * This section maps to a section in FAISS index with a starting offset and the section size.
+ * A FAISS index file consists of multiple logical sections, each beginning with four bytes indicating an index type. A section may contain
+ * a nested section or vector storage, forming a tree structure with a top-level index as the starting point.
+ *
+ * Ex: FAISS index file
+ * +------------+ -> 0
+ * + +
+ * + IxMp + -> FaissSection(offset=0, section_size=120)
+ * +------------+ -> 120
+ * + +
+ * + IHNf + -> FaissSection(offset=120, section_size=380)
+ * +------------+ -> 500
+ * + +
+ * + IxF2 + -> FaissSection(offset=500, section_size=700)
+ * +------------+ -> 1200
+ *
+ */
+public class FaissSection {
+    @Getter
+    private final long baseOffset;
+    @Getter
+    private final long sectionSize;
+
+    /**
+     * Marks the starting offset and the byte size of a section, then positions {@code input} right after the section.
+     *
+     * @param input Input read stream. An 8-byte element count is read from it first; the section body starts right after.
+     * @param singleElementSize Size of atomic element. In file, it only stores the number of elements and the size of element will be
+     *                          used to calculate the actual size of section. Ex: size=100, element=int, then the actual section size=400.
+     * @throws IOException If reading the count or seeking past the section fails. An overflowing size raises ArithmeticException.
+     */
+    public FaissSection(IndexInput input, int singleElementSize) throws IOException {
+        this.sectionSize = Math.multiplyExact(input.readLong(), (long) singleElementSize);
+        this.baseOffset = input.getFilePointer();
+        // Skip the whole section and jump to the next section in the file.
+        try {
+            input.seek(baseOffset + sectionSize);
+        } catch (IOException e) {
+            throw new IOException("Failed to partial load where baseOffset=" + baseOffset + ", sectionSize=" + sectionSize, e);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java
index d2c5b1fd77..45b9fe498e 100644
--- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java
+++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java
@@ -10,7 +10,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
-import java.util.function.Supplier;
+import java.util.function.Function;
/**
* This table maintains a mapping between FAISS index types and their corresponding index implementations.
@@ -20,12 +20,14 @@
*/
@UtilityClass
public class IndexTypeToFaissIndexMapping {
- private static final Map