Skip to content

Commit 2e7e464

Browse files
author
Dooyong Kim
committed
Added FaissHNSW and bridge to Lucene HNSW graph.
Signed-off-by: Dooyong Kim <kdooyong@amazon.com> Co-authored-by: Dooyong Kim <kdooyong@amazon.com>
1 parent 4e80b65 commit 2e7e464

File tree

10 files changed

+547
-13
lines changed

10 files changed

+547
-13
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.memoryoptsearch.faiss;
7+
8+
import lombok.Getter;
9+
import org.apache.lucene.store.IndexInput;
10+
11+
import java.io.IOException;
12+
13+
/**
14+
* Ported implementation of the FAISS HNSW graph search algorithm.
15+
* While it follows the same steps as the original FAISS implementation, differences in how the JVM and C++ handle floating-point
16+
* calculations can lead to slight variations in results. However, such cases are very rare, and in most instances, the results are
17+
* identical to FAISS. Even when there are ranking differences, they do not impact the precision or recall of the search.
18+
* For more details, refer to the [FAISS HNSW implementation](
19+
* <a href="https://github.com/facebookresearch/faiss/blob/main/faiss/impl/HNSW.h">...</a>).
20+
*/
21+
@Getter
22+
public class FaissHNSW {
23+
// Cumulative number of neighbors per each level.
24+
private int[] cumNumberNeighborPerLevel;
25+
// Offset to be added to cumNumberNeighborPerLevel[level] to get the actual start offset of neighbor list.
26+
private long[] offsets = null;
27+
// Neighbor list storage.
28+
private final FaissSection neighbors = new FaissSection();
29+
// Entry point in HNSW graph
30+
private int entryPoint;
31+
// Maximum level of HNSW graph
32+
private int maxLevel = -1;
33+
// Default efSearch parameter. This determines the navigation queue size.
34+
// More value, algorithm will more navigate candidates.
35+
private int efSearch = 16;
36+
// Total number of vectors stored in graph.
37+
private long totalNumberOfVectors;
38+
39+
/**
40+
* Partially loads the FAISS HNSW graph from the provided index input stream.
41+
* The graph is divided into multiple sections, and this method marks the starting offset of each section then skip to the next
42+
* section instead of loading the entire graph into memory. During the search, bytes will be accessed via {@link IndexInput}.
43+
*
44+
* @param input An input stream for a FAISS HNSW graph file, allowing access to the neighbor list and vector locations.
45+
* @param totalNumberOfVectors The total number of vectors stored in the graph.
46+
* @return {@link FaissHNSW}, a graph search structure that represents the FAISS HNSW graph
47+
*
48+
* FYI <a href="https://github.com/facebookresearch/faiss/blob/main/faiss/impl/index_read.cpp#L363">FAISS Deserialization</a>
49+
*
50+
* @throws IOException
51+
*/
52+
public static FaissHNSW load(IndexInput input, long totalNumberOfVectors) throws IOException {
53+
// Total number of vectors
54+
FaissHNSW faissHNSW = new FaissHNSW();
55+
faissHNSW.totalNumberOfVectors = totalNumberOfVectors;
56+
57+
// We don't use `double[] assignProbas` for search. It is for index construction.
58+
long size = input.readLong();
59+
input.skipBytes(Double.BYTES * size);
60+
61+
// Accumulate number of neighbor per each level.
62+
size = input.readLong();
63+
faissHNSW.cumNumberNeighborPerLevel = new int[(int) size];
64+
if (size > 0) {
65+
input.readInts(faissHNSW.cumNumberNeighborPerLevel, 0, (int) size);
66+
}
67+
68+
// We don't use `level`.
69+
final FaissSection levels = new FaissSection();
70+
levels.markSection(input, Integer.BYTES);
71+
72+
// Load `offsets` into memory.
73+
size = input.readLong();
74+
faissHNSW.offsets = new long[(int) size];
75+
input.readLongs(faissHNSW.offsets, 0, faissHNSW.offsets.length);
76+
77+
// Mark neighbor list section.
78+
faissHNSW.neighbors.markSection(input, Integer.BYTES);
79+
80+
// HNSW graph parameters
81+
faissHNSW.entryPoint = input.readInt();
82+
83+
faissHNSW.maxLevel = input.readInt();
84+
85+
// We don't use this field. It's for index building.
86+
final int efConstruction = input.readInt();
87+
88+
faissHNSW.efSearch = input.readInt();
89+
90+
// dummy read a deprecated field.
91+
input.readInt();
92+
93+
return faissHNSW;
94+
}
95+
}

src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissHNSWFlatIndex.java

+29-7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.memoryoptsearch.faiss;
77

8+
import lombok.Getter;
89
import org.apache.lucene.index.ByteVectorValues;
910
import org.apache.lucene.index.FloatVectorValues;
1011
import org.apache.lucene.index.VectorEncoding;
@@ -18,30 +19,51 @@
1819
* For more details, please refer to <a href="https://github.com/facebookresearch/faiss/blob/main/faiss/IndexHNSW.h">...</a>
1920
*/
2021
public class FaissHNSWFlatIndex extends FaissIndex {
22+
public static final String IHNF = "IHNf";
23+
public static final String IHNS = "IHNs";
24+
25+
@Getter
26+
private FaissHNSW hnsw = new FaissHNSW();
27+
private FaissIndex storage;
28+
private VectorEncoding vectorEncoding;
29+
2130
public FaissHNSWFlatIndex(final String indexType) {
2231
super(indexType);
32+
33+
// Set encoding
34+
if (indexType.equals(IHNF)) {
35+
vectorEncoding = VectorEncoding.FLOAT32;
36+
} else if (indexType.equals(IHNS)) {
37+
vectorEncoding = VectorEncoding.BYTE;
38+
} else {
39+
throw new IllegalStateException("Unsupported index type: " + indexType + " in " + FaissHNSWFlatIndex.class.getSimpleName());
40+
}
2341
}
2442

2543
@Override
2644
protected void doLoad(IndexInput input) throws IOException {
27-
// TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
45+
// Read common header
46+
readCommonHeader(input);
47+
48+
// Partial load HNSW graph
49+
hnsw = FaissHNSW.load(input, getTotalNumberOfVectors());
50+
51+
// Partial load flat vector storage
52+
storage = FaissIndex.load(input);
2853
}
2954

3055
@Override
3156
public VectorEncoding getVectorEncoding() {
32-
// TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
33-
return null;
57+
return vectorEncoding;
3458
}
3559

3660
@Override
3761
public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException {
38-
// TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
39-
return null;
62+
return storage.getFloatValues(indexInput);
4063
}
4164

4265
@Override
4366
public ByteVectorValues getByteValues(IndexInput indexInput) throws IOException {
44-
// TODO(KDY) : This will be covered in part-3 (FAISS HNSW).
45-
return null;
67+
return storage.getByteValues(indexInput);
4668
}
4769
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.memoryoptsearch.faiss;
7+
8+
import lombok.Getter;
9+
import org.apache.lucene.store.IndexInput;
10+
11+
import java.io.IOException;
12+
13+
/**
14+
* This section maps to a section in FAISS index with a starting offset and the section size.
15+
* A FAISS index file consists of multiple logical sections, each beginning with four bytes indicating an index type. A section may contain
16+
* a nested section or vector storage, forming a tree structure with a top-level index as the starting point.
17+
*
18+
* Ex: FAISS index file
19+
* +------------+ -> 0
20+
* + +
21+
* + IxMp + -> FaissSection(offset=0, section_size=120)
22+
* +------------+ -> 120
23+
* + +
24+
* + IHNf + -> FaissSection(offset=120, section_size=380)
25+
* +------------+ -> 500
26+
* + +
27+
* + IxF2 + -> FaissSection(offset=500, section_size=700)
28+
* +------------+ -> 1200
29+
*
30+
*/
31+
public class FaissSection {
32+
@Getter
33+
protected long baseOffset;
34+
@Getter
35+
protected long sectionSize;
36+
37+
/**
38+
* Mark the starting offset and the size of section then skip to the next section.
39+
*
40+
* @param input Input read stream.
41+
* @param singleElementSize Size of atomic element. In file, it only stores the number of elements and the size of element will be
42+
* used to calculate the actual size of section. Ex: size=100, element=int, then the actual section size=400.
43+
* @throws IOException
44+
*/
45+
public void markSection(IndexInput input, int singleElementSize) throws IOException {
46+
this.sectionSize = input.readLong() * singleElementSize;
47+
this.baseOffset = input.getFilePointer();
48+
// Skip the whole section and jump to the next section in the file.
49+
try {
50+
input.seek(baseOffset + sectionSize);
51+
} catch (IOException e) {
52+
throw new IOException("Failed to partial load where baseOffset=" + baseOffset + ", sectionSize=" + sectionSize, e);
53+
}
54+
}
55+
}

src/main/java/org/opensearch/knn/memoryoptsearch/faiss/IndexTypeToFaissIndexMapping.java

+8-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import java.util.Collections;
1111
import java.util.HashMap;
1212
import java.util.Map;
13-
import java.util.function.Supplier;
13+
import java.util.function.Function;
1414

1515
/**
1616
* This table maintains a mapping between FAISS index types and their corresponding index implementations.
@@ -20,12 +20,14 @@
2020
*/
2121
@UtilityClass
2222
public class IndexTypeToFaissIndexMapping {
23-
private static final Map<String, Supplier<FaissIndex>> INDEX_TYPE_TO_FAISS_INDEX;
23+
private static final Map<String, Function<String, FaissIndex>> INDEX_TYPE_TO_FAISS_INDEX;
2424

2525
static {
26-
final Map<String, Supplier<FaissIndex>> mapping = new HashMap<>();
26+
final Map<String, Function<String, FaissIndex>> mapping = new HashMap<>();
2727

28-
mapping.put(FaissIdMapIndex.IXMP, FaissIdMapIndex::new);
28+
mapping.put(FaissIdMapIndex.IXMP, (indexType) -> new FaissIdMapIndex());
29+
mapping.put(FaissHNSWFlatIndex.IHNF, FaissHNSWFlatIndex::new);
30+
mapping.put(FaissHNSWFlatIndex.IHNS, FaissHNSWFlatIndex::new);
2931

3032
INDEX_TYPE_TO_FAISS_INDEX = Collections.unmodifiableMap(mapping);
3133
}
@@ -37,9 +39,9 @@ public class IndexTypeToFaissIndexMapping {
3739
* @return Actual implementation that is corresponding to the given index type.
3840
*/
3941
public FaissIndex getFaissIndex(final String indexType) {
40-
final Supplier<FaissIndex> faissIndexSupplier = INDEX_TYPE_TO_FAISS_INDEX.get(indexType);
42+
final Function<String, FaissIndex> faissIndexSupplier = INDEX_TYPE_TO_FAISS_INDEX.get(indexType);
4143
if (faissIndexSupplier != null) {
42-
return faissIndexSupplier.get();
44+
return faissIndexSupplier.apply(indexType);
4345
}
4446
throw new UnsupportedFaissIndexException("Index type [" + indexType + "] is not supported.");
4547
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.memoryoptsearch.faiss;
7+
8+
import org.apache.lucene.store.IndexInput;
9+
import org.apache.lucene.util.hnsw.HnswGraph;
10+
11+
import java.io.IOException;
12+
13+
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
14+
15+
/**
16+
* This graph implements Lucene's HNSW graph interface using the FAISS HNSW graph. Conceptually, both libraries represent the graph
17+
* similarly, maintaining a list of neighbor IDs. This implementation acts as a bridge, enabling Lucene's HNSW graph searcher to perform
18+
* vector searches on a FAISS index.
19+
*/
20+
public class LuceneFaissHnswGraph extends HnswGraph {
21+
private final FaissHNSW faissHnsw;
22+
private final IndexInput indexInput;
23+
private final int numVectors;
24+
private int[] neighborIdList;
25+
private int numNeighbors;
26+
private int nextNeighborIndex;
27+
28+
public LuceneFaissHnswGraph(final FaissHNSW faissHNSW, final int numVectors, final IndexInput indexInput) {
29+
this.faissHnsw = faissHNSW;
30+
this.indexInput = indexInput;
31+
this.numVectors = numVectors;
32+
}
33+
34+
/**
35+
* Seek to the starting offset of neighbor list of `internalVectorId` at the given `level`.
36+
* In which, it will load all ids into a buffer array.
37+
* @param level
38+
* @param internalVectorId
39+
*/
40+
@Override
41+
public void seek(int level, int internalVectorId) {
42+
// Get a relative starting offset of neighbor list at `level`.
43+
long o = faissHnsw.getOffsets()[internalVectorId];
44+
45+
// `begin` and `end` represent for a pair of staring offset and end offset.
46+
// But, what `end` represents is the maximum offset a neighbor list at a level can have.
47+
// Therefore, it is required to traverse a list until getting a terminal `-1`.
48+
// Ex: [1, 5, 20, 100, -1, -1, ..., -1]
49+
final long begin = o + faissHnsw.getCumNumberNeighborPerLevel()[level];
50+
final long end = o + faissHnsw.getCumNumberNeighborPerLevel()[level + 1];
51+
loadNeighborIdList(begin, end);
52+
}
53+
54+
private void loadNeighborIdList(final long begin, final long end) {
55+
// Make sure we have sufficient space for neighbor list
56+
final long maxLength = end - begin;
57+
if (neighborIdList == null || neighborIdList.length < maxLength) {
58+
neighborIdList = new int[(int) (maxLength * 1.5)];
59+
}
60+
61+
// Seek to the first offset of neighbor list
62+
try {
63+
indexInput.seek(faissHnsw.getNeighbors().getBaseOffset() + Integer.BYTES * begin);
64+
} catch (IOException e) {
65+
throw new RuntimeException(e);
66+
}
67+
68+
// Fill the array with neighbor ids
69+
int index = 0;
70+
try {
71+
for (long i = begin; i < end; i++) {
72+
final int neighborId = indexInput.readInt();
73+
if (neighborId >= 0) {
74+
neighborIdList[index++] = neighborId;
75+
} else {
76+
break;
77+
}
78+
}
79+
80+
// Set variables for navigation
81+
numNeighbors = index;
82+
nextNeighborIndex = 0;
83+
} catch (IOException e) {
84+
throw new RuntimeException(e);
85+
}
86+
}
87+
88+
@Override
89+
public int size() {
90+
return numVectors;
91+
}
92+
93+
@Override
94+
public int nextNeighbor() {
95+
if (nextNeighborIndex < numNeighbors) {
96+
return neighborIdList[nextNeighborIndex++];
97+
}
98+
99+
// Neighbor list has been exhausted.
100+
return NO_MORE_DOCS;
101+
}
102+
103+
@Override
104+
public int numLevels() {
105+
return faissHnsw.getMaxLevel();
106+
}
107+
108+
@Override
109+
public int entryNode() {
110+
return faissHnsw.getEntryPoint();
111+
}
112+
113+
@Override
114+
public NodesIterator getNodesOnLevel(int i) {
115+
throw new UnsupportedOperationException();
116+
}
117+
}

0 commit comments

Comments
 (0)