forked from opensearch-project/neural-search
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNeuralKNNQuery.java
67 lines (57 loc) · 1.85 KB
/
NeuralKNNQuery.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;
import lombok.Getter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import java.io.IOException;
import java.util.Objects;
/**
 * Wraps a KNN Lucene query to support neural search extensions.
 *
 * <p>All core query operations — string rendering, visiting, weight creation and rewriting —
 * are delegated to the wrapped query. The wrapped query reference is fixed at construction
 * time and exposed through the Lombok-generated {@code getKnnQuery()} accessor.
 */
@Getter
public class NeuralKNNQuery extends Query {

    /** The underlying KNN query to which all operations are delegated; never null. */
    private final Query knnQuery;

    /**
     * Creates a wrapper around the given KNN query.
     *
     * @param knnQuery the KNN query to wrap; must not be null
     * @throws NullPointerException if {@code knnQuery} is null
     */
    public NeuralKNNQuery(Query knnQuery) {
        // Fail fast: every other method dereferences knnQuery, so accepting null here would
        // only defer the NullPointerException to a point far from its cause.
        this.knnQuery = Objects.requireNonNull(knnQuery, "knnQuery must not be null");
    }

    /**
     * Renders this query using the wrapped query's representation.
     *
     * @param field the default field name passed through to the wrapped query
     * @return the wrapped query's string form for {@code field}
     */
    @Override
    public String toString(String field) {
        return knnQuery.toString(field);
    }

    /**
     * Passes the visitor straight to the wrapped query.
     *
     * @param visitor the visitor to apply
     */
    @Override
    public void visit(QueryVisitor visitor) {
        // This wrapper contributes no terms or clauses of its own, so the visitor is handed
        // to the underlying KNN query unchanged.
        // NOTE(review): Lucene wrapper queries often use visitor.getSubVisitor(MUST, this);
        // kept as direct delegation to preserve existing behavior for custom visitors.
        knnQuery.visit(visitor);
    }

    /**
     * Creates the weight by delegating to the wrapped query.
     *
     * @param searcher the searcher executing the query
     * @param scoreMode how scores will be consumed
     * @param boost the boost to apply to scores
     * @return the weight produced by the wrapped query
     * @throws IOException if the wrapped query fails to build its weight
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        return knnQuery.createWeight(searcher, scoreMode, boost);
    }

    /**
     * Rewrites the wrapped query, keeping the result wrapped in a {@code NeuralKNNQuery}.
     *
     * @param indexSearcher the searcher used for rewriting
     * @return {@code this} if the wrapped query did not change, otherwise a new wrapper
     *     around the rewritten query
     * @throws IOException if the wrapped query's rewrite fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        Query rewritten = knnQuery.rewrite(indexSearcher);
        if (rewritten == knnQuery) {
            // Returning the same instance tells the caller that rewriting has converged.
            return this;
        }
        // Re-wrap so the rewritten query keeps the NeuralKNNQuery type.
        return new NeuralKNNQuery(rewritten);
    }

    /**
     * Two wrappers are equal when they are of exactly the same class and wrap equal queries.
     *
     * @param other the object to compare against
     * @return {@code true} if {@code other} is an equal {@code NeuralKNNQuery}
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        // sameClassAs performs the null check and exact-class comparison required by Query.
        return sameClassAs(other) && knnQuery.equals(((NeuralKNNQuery) other).knnQuery);
    }

    /**
     * Hash code combining this class's identity with the wrapped query's hash.
     *
     * @return the hash code, consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        // classHash() keeps this wrapper from colliding with the bare inner query's hash.
        return 31 * classHash() + knnQuery.hashCode();
    }
}