Skip to content

Commit 0769ad7

Browse files
authored
Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder (#1183)
* Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent ebfb058 commit 0769ad7

File tree

9 files changed

+479
-32
lines changed

9 files changed

+479
-32
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1212
### Documentation
1313
### Maintenance
1414
### Refactoring
15+
- Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder ([#1183](https://github.com/opensearch-project/neural-search/pull/1183))
1516

1617
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.19...2.x)
1718
### Features

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/BatchIngestionIT.java

+47
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.opensearch.neuralsearch.bwc.rolling;
66

77
import org.opensearch.neuralsearch.util.TestUtils;
8+
import org.opensearch.ml.common.model.MLModelState;
89

910
import java.nio.file.Files;
1011
import java.nio.file.Path;
@@ -28,29 +29,59 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
2829
case OLD:
2930
sparseModelId = uploadSparseEncodingModel();
3031
loadModel(sparseModelId);
32+
MLModelState oldModelState = getModelState(sparseModelId);
33+
logger.info("Model state in OLD phase: {}", oldModelState);
34+
if (oldModelState != MLModelState.LOADED) {
35+
logger.error("Model {} is not in LOADED state in OLD phase. Current state: {}", sparseModelId, oldModelState);
36+
waitForModelToLoad(sparseModelId);
37+
}
3138
createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_PIPELINE, 2);
39+
logger.info("Pipeline state in OLD phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
3240
createIndexWithConfiguration(
3341
indexName,
3442
Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())),
3543
SPARSE_PIPELINE
3644
);
3745
List<Map<String, String>> docs = prepareDataForBulkIngestion(0, 5);
3846
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docs);
47+
logger.info("Document count after OLD phase ingestion: {}", getDocCount(indexName));
3948
validateDocCountAndInfo(indexName, 5, () -> getDocById(indexName, "4"), EMBEDDING_FIELD_NAME, Map.class);
4049
break;
4150
case MIXED:
4251
sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_PIPELINE), SPARSE_ENCODING_PROCESSOR);
4352
loadModel(sparseModelId);
53+
MLModelState mixedModelState = getModelState(sparseModelId);
54+
logger.info("Model state in MIXED phase: {}", mixedModelState);
55+
if (mixedModelState != MLModelState.LOADED) {
56+
logger.error("Model {} is not in LOADED state in MIXED phase. Current state: {}", sparseModelId, mixedModelState);
57+
waitForModelToLoad(sparseModelId);
58+
}
59+
logger.info("Pipeline state in MIXED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
4460
List<Map<String, String>> docsForMixed = prepareDataForBulkIngestion(5, 5);
61+
logger.info("Document count before MIXED phase ingestion: {}", getDocCount(indexName));
4562
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docsForMixed);
63+
logger.info("Document count after MIXED phase ingestion: {}", getDocCount(indexName));
4664
validateDocCountAndInfo(indexName, 10, () -> getDocById(indexName, "9"), EMBEDDING_FIELD_NAME, Map.class);
4765
break;
4866
case UPGRADED:
4967
try {
5068
sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_PIPELINE), SPARSE_ENCODING_PROCESSOR);
5169
loadModel(sparseModelId);
70+
MLModelState upgradedModelState = getModelState(sparseModelId);
71+
logger.info("Model state in UPGRADED phase: {}", upgradedModelState);
72+
if (upgradedModelState != MLModelState.LOADED) {
73+
logger.error(
74+
"Model {} is not in LOADED state in UPGRADED phase. Current state: {}",
75+
sparseModelId,
76+
upgradedModelState
77+
);
78+
waitForModelToLoad(sparseModelId);
79+
}
80+
logger.info("Pipeline state in UPGRADED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
5281
List<Map<String, String>> docsForUpgraded = prepareDataForBulkIngestion(10, 5);
82+
logger.info("Document count before UPGRADED phase ingestion: {}", getDocCount(indexName));
5383
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docsForUpgraded);
84+
logger.info("Document count after UPGRADED phase ingestion: {}", getDocCount(indexName));
5485
validateDocCountAndInfo(indexName, 15, () -> getDocById(indexName, "14"), EMBEDDING_FIELD_NAME, Map.class);
5586
} finally {
5687
wipeOfTestResources(indexName, SPARSE_PIPELINE, sparseModelId, null);
@@ -60,4 +91,20 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
6091
throw new IllegalStateException("Unexpected value: " + getClusterType());
6192
}
6293
}
94+
95+
private void waitForModelToLoad(String modelId) throws Exception {
96+
int maxAttempts = 30; // Maximum number of attempts
97+
int waitTimeInSeconds = 2; // Time to wait between attempts
98+
99+
for (int attempt = 0; attempt < maxAttempts; attempt++) {
100+
MLModelState state = getModelState(modelId);
101+
if (state == MLModelState.LOADED) {
102+
logger.info("Model {} is now loaded after {} attempts", modelId, attempt + 1);
103+
return;
104+
}
105+
logger.info("Waiting for model {} to load. Current state: {}. Attempt {}/{}", modelId, state, attempt + 1, maxAttempts);
106+
Thread.sleep(waitTimeInSeconds * 1000);
107+
}
108+
throw new RuntimeException("Model " + modelId + " failed to load after " + maxAttempts + " attempts");
109+
}
63110
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.Getter;
8+
import org.apache.lucene.index.IndexReader;
9+
import org.apache.lucene.search.IndexSearcher;
10+
import org.apache.lucene.search.Query;
11+
import org.apache.lucene.search.QueryVisitor;
12+
import org.apache.lucene.search.ScoreMode;
13+
import org.apache.lucene.search.Weight;
14+
15+
import java.io.IOException;
16+
import java.util.Objects;
17+
18+
/**
19+
* Wraps KNN Lucene query to support neural search extensions.
20+
* Delegates core operations to the underlying KNN query.
21+
*/
22+
@Getter
23+
public class NeuralKNNQuery extends Query {
24+
private final Query knnQuery;
25+
26+
public NeuralKNNQuery(Query knnQuery) {
27+
this.knnQuery = knnQuery;
28+
}
29+
30+
@Override
31+
public String toString(String field) {
32+
return knnQuery.toString(field);
33+
}
34+
35+
@Override
36+
public void visit(QueryVisitor visitor) {
37+
// Delegate the visitor to the underlying KNN query
38+
knnQuery.visit(visitor);
39+
}
40+
41+
@Override
42+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
43+
// Delegate weight creation to the underlying KNN query
44+
return knnQuery.createWeight(searcher, scoreMode, boost);
45+
}
46+
47+
@Override
48+
public Query rewrite(IndexReader reader) throws IOException {
49+
Query rewritten = knnQuery.rewrite(reader);
50+
if (rewritten == knnQuery) {
51+
return this;
52+
}
53+
return new NeuralKNNQuery(rewritten);
54+
}
55+
56+
@Override
57+
public boolean equals(Object other) {
58+
if (this == other) return true;
59+
if (other == null || getClass() != other.getClass()) return false;
60+
NeuralKNNQuery that = (NeuralKNNQuery) other;
61+
return Objects.equals(knnQuery, that.knnQuery);
62+
}
63+
64+
@Override
65+
public int hashCode() {
66+
return Objects.hash(knnQuery);
67+
}
68+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.Getter;
8+
import org.apache.lucene.search.Query;
9+
import org.opensearch.core.common.io.stream.StreamOutput;
10+
import org.opensearch.core.xcontent.XContentBuilder;
11+
import org.opensearch.index.query.AbstractQueryBuilder;
12+
import org.opensearch.index.query.QueryBuilder;
13+
import org.opensearch.index.query.QueryRewriteContext;
14+
import org.opensearch.index.query.QueryShardContext;
15+
import org.opensearch.knn.index.query.KNNQueryBuilder;
16+
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
17+
import org.opensearch.knn.index.query.rescore.RescoreContext;
18+
import org.opensearch.knn.index.util.IndexUtil;
19+
20+
import java.io.IOException;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
/**
25+
* NeuralKNNQueryBuilder wraps KNNQueryBuilder to:
26+
* 1. Isolate KNN plugin API changes to a single location
27+
* 2. Allow extension with neural-search-specific information (e.g., query text)
28+
*/
29+
30+
@Getter
31+
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> {
32+
private final KNNQueryBuilder knnQueryBuilder;
33+
34+
/**
35+
* Creates a new builder instance.
36+
*/
37+
public static Builder builder() {
38+
return new Builder();
39+
}
40+
41+
public String fieldName() {
42+
return knnQueryBuilder.fieldName();
43+
}
44+
45+
public int k() {
46+
return knnQueryBuilder.getK();
47+
}
48+
49+
/**
50+
* Builder for NeuralKNNQueryBuilder.
51+
*/
52+
public static class Builder {
53+
private String fieldName;
54+
private float[] vector;
55+
private Integer k;
56+
private QueryBuilder filter;
57+
private Float maxDistance;
58+
private Float minScore;
59+
private Boolean expandNested;
60+
private Map<String, ?> methodParameters;
61+
private RescoreContext rescoreContext;
62+
63+
private Builder() {}
64+
65+
public Builder fieldName(String fieldName) {
66+
this.fieldName = fieldName;
67+
return this;
68+
}
69+
70+
public Builder vector(float[] vector) {
71+
this.vector = vector;
72+
return this;
73+
}
74+
75+
public Builder k(Integer k) {
76+
this.k = k;
77+
return this;
78+
}
79+
80+
public Builder filter(QueryBuilder filter) {
81+
this.filter = filter;
82+
return this;
83+
}
84+
85+
public Builder maxDistance(Float maxDistance) {
86+
this.maxDistance = maxDistance;
87+
return this;
88+
}
89+
90+
public Builder minScore(Float minScore) {
91+
this.minScore = minScore;
92+
return this;
93+
}
94+
95+
public Builder expandNested(Boolean expandNested) {
96+
this.expandNested = expandNested;
97+
return this;
98+
}
99+
100+
public Builder methodParameters(Map<String, ?> methodParameters) {
101+
this.methodParameters = methodParameters;
102+
return this;
103+
}
104+
105+
public Builder rescoreContext(RescoreContext rescoreContext) {
106+
this.rescoreContext = rescoreContext;
107+
return this;
108+
}
109+
110+
public NeuralKNNQueryBuilder build() {
111+
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
112+
.fieldName(fieldName)
113+
.vector(vector)
114+
.k(k)
115+
.filter(filter)
116+
.maxDistance(maxDistance)
117+
.minScore(minScore)
118+
.expandNested(expandNested)
119+
.methodParameters(methodParameters)
120+
.rescoreContext(rescoreContext)
121+
.build();
122+
return new NeuralKNNQueryBuilder(knnBuilder);
123+
}
124+
}
125+
126+
private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {
127+
this.knnQueryBuilder = knnQueryBuilder;
128+
}
129+
130+
@Override
131+
public void doWriteTo(StreamOutput out) throws IOException {
132+
KNNQueryBuilderParser.streamOutput(out, knnQueryBuilder, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
133+
}
134+
135+
@Override
136+
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
137+
knnQueryBuilder.doXContent(builder, params);
138+
}
139+
140+
@Override
141+
protected QueryBuilder doRewrite(QueryRewriteContext context) throws IOException {
142+
QueryBuilder rewritten = knnQueryBuilder.rewrite(context);
143+
if (rewritten == knnQueryBuilder) {
144+
return this;
145+
}
146+
return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten);
147+
}
148+
149+
@Override
150+
protected Query doToQuery(QueryShardContext context) throws IOException {
151+
Query knnQuery = knnQueryBuilder.toQuery(context);
152+
return new NeuralKNNQuery(knnQuery);
153+
}
154+
155+
@Override
156+
protected boolean doEquals(NeuralKNNQueryBuilder other) {
157+
return Objects.equals(knnQueryBuilder, other.knnQueryBuilder);
158+
}
159+
160+
@Override
161+
protected int doHashCode() {
162+
return Objects.hash(knnQueryBuilder);
163+
}
164+
165+
@Override
166+
public String getWriteableName() {
167+
return knnQueryBuilder.getWriteableName();
168+
}
169+
}

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+10-10
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.opensearch.index.query.QueryBuilder;
4141
import org.opensearch.index.query.QueryRewriteContext;
4242
import org.opensearch.index.query.QueryShardContext;
43-
import org.opensearch.knn.index.query.KNNQueryBuilder;
4443
import org.opensearch.knn.index.query.parser.MethodParametersParser;
4544
import org.opensearch.knn.index.query.parser.RescoreParser;
4645
import org.opensearch.knn.index.query.rescore.RescoreContext;
@@ -463,22 +462,23 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
463462
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117.
464463
// With the asynchronous call, on first rewrite, we create a new
465464
// vector supplier that will get populated once the asynchronous call finishes and pass this supplier in to
466-
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
467-
// return the current unmodified query builder.
465+
// create a new builder. Once the supplier's value gets set, we return a NeuralKNNQueryBuilder
466+
// which wraps KNNQueryBuilder. Otherwise, we just return the current unmodified query builder.
468467
if (vectorSupplier() != null) {
469468
if (vectorSupplier().get() == null) {
470469
return this;
471470
}
472-
return KNNQueryBuilder.builder()
471+
472+
return NeuralKNNQueryBuilder.builder()
473473
.fieldName(fieldName())
474474
.vector(vectorSupplier.get())
475+
.k(k())
475476
.filter(filter())
476-
.maxDistance(maxDistance)
477-
.minScore(minScore)
478-
.expandNested(expandNested)
479-
.k(k)
480-
.methodParameters(methodParameters)
481-
.rescoreContext(rescoreContext)
477+
.maxDistance(maxDistance())
478+
.minScore(minScore())
479+
.expandNested(expandNested())
480+
.methodParameters(methodParameters())
481+
.rescoreContext(rescoreContext())
482482
.build();
483483
}
484484

0 commit comments

Comments
 (0)