Skip to content

Commit 7b0229d

Browse files
authored
[BUG FIX] Fix bwc failure in neural sparse search (#696)
1 parent dd3b30c commit 7b0229d

File tree

5 files changed

+107
-18
lines changed

5 files changed

+107
-18
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1010
- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438))
1111
- Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578))
1212
- Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615))
13+
- Add max_token_score field placeholder in NeuralSparseQueryBuilder to fix the rolling-upgrade from 2.x nodes bwc tests. ([#696](https://github.com/opensearch-project/neural-search/pull/696))
1314
### Infrastructure
1415
- Adding integration tests for scenario of hybrid query with aggregations ([#632](https://github.com/opensearch-project/neural-search/pull/632))
1516
### Documentation

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java

-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
*/
55
package org.opensearch.neuralsearch.bwc;
66

7-
import java.net.URL;
87
import java.nio.file.Files;
98
import java.nio.file.Path;
109
import java.util.Locale;
11-
import java.util.Objects;
1210
import java.util.Optional;
1311
import org.junit.Before;
1412
import org.opensearch.common.settings.Settings;

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java

-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
*/
55
package org.opensearch.neuralsearch.bwc;
66

7-
import java.net.URL;
87
import java.nio.file.Files;
98
import java.nio.file.Path;
109
import java.util.Locale;
11-
import java.util.Objects;
1210
import java.util.Optional;
1311
import org.junit.Before;
1412
import org.opensearch.common.settings.Settings;

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java

+23-9
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
import lombok.extern.log4j.Log4j2;
4747

4848
/**
49-
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML SPARSE_ENCODING model
49+
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
5050
* or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed
5151
* to Lucene FeatureQuery wrapped by Lucene BooleanQuery.
5252
*/
@@ -63,6 +63,11 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
6363
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
6464
@VisibleForTesting
6565
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
66+
// We use max_token_score field to help WAND scorer prune query clause in lucene 9.7. But in lucene 9.8 the inner
67+
// logics change, this field is not needed any more.
68+
@VisibleForTesting
69+
@Deprecated
70+
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();
6671

6772
private static MLCommonsClientAccessor ML_CLIENT;
6873

@@ -73,6 +78,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
7378
private String fieldName;
7479
private String queryText;
7580
private String modelId;
81+
private Float maxTokenScore;
7682
private Supplier<Map<String, Float>> queryTokensSupplier;
7783
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;
7884

@@ -91,6 +97,7 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
9197
} else {
9298
this.modelId = in.readString();
9399
}
100+
this.maxTokenScore = in.readOptionalFloat();
94101
if (in.readBoolean()) {
95102
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
96103
this.queryTokensSupplier = () -> queryTokens;
@@ -106,6 +113,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
106113
} else {
107114
out.writeString(this.modelId);
108115
}
116+
out.writeOptionalFloat(maxTokenScore);
109117
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
110118
out.writeBoolean(true);
111119
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
@@ -122,6 +130,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
122130
if (Objects.nonNull(modelId)) {
123131
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
124132
}
133+
if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
125134
printBoostAndQueryName(xContentBuilder);
126135
xContentBuilder.endObject();
127136
xContentBuilder.endObject();
@@ -131,7 +140,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
131140
* The expected parsing form looks like:
132141
* "SAMPLE_FIELD": {
133142
* "query_text": "string",
134-
* "model_id": "string"
143+
* "model_id": "string",
144+
* "max_token_score": float (optional)
135145
* }
136146
*
137147
* @param parser XContentParser
@@ -189,6 +199,8 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
189199
sparseEncodingQueryBuilder.queryText(parser.text());
190200
} else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
191201
sparseEncodingQueryBuilder.modelId(parser.text());
202+
} else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
203+
sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue());
192204
} else {
193205
throw new ParsingException(
194206
parser.getTokenLocation(),
@@ -227,6 +239,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
227239
return new NeuralSparseQueryBuilder().fieldName(fieldName)
228240
.queryText(queryText)
229241
.modelId(modelId)
242+
.maxTokenScore(maxTokenScore)
230243
.queryTokensSupplier(queryTokensSetOnce::get);
231244
}
232245

@@ -280,22 +293,23 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
280293
@Override
281294
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
282295
if (this == obj) return true;
283-
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
284-
if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false;
285-
if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
296+
if (obj == null || getClass() != obj.getClass()) return false;
297+
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
298+
if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
286299
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
287300
.append(queryText, obj.queryText)
288-
.append(modelId, obj.modelId);
289-
if (!Objects.isNull(queryTokensSupplier)) {
301+
.append(modelId, obj.modelId)
302+
.append(maxTokenScore, obj.maxTokenScore);
303+
if (queryTokensSupplier != null) {
290304
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
291305
}
292306
return equalsBuilder.isEquals();
293307
}
294308

295309
@Override
296310
protected int doHashCode() {
297-
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId);
298-
if (!Objects.isNull(queryTokensSupplier)) {
311+
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore);
312+
if (queryTokensSupplier != null) {
299313
builder.append(queryTokensSupplier.get());
300314
}
301315
return builder.toHashCode();

src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java

+83-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
1111
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
1212
import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
13+
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MAX_TOKEN_SCORE_FIELD;
1314
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
1415
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
1516
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;
@@ -22,6 +23,9 @@
2223
import java.util.function.BiConsumer;
2324
import java.util.function.Supplier;
2425

26+
import org.apache.lucene.document.FeatureField;
27+
import org.apache.lucene.search.BooleanClause;
28+
import org.apache.lucene.search.BooleanQuery;
2529
import org.junit.Before;
2630
import org.opensearch.Version;
2731
import org.opensearch.client.Client;
@@ -37,9 +41,11 @@
3741
import org.opensearch.core.xcontent.ToXContent;
3842
import org.opensearch.core.xcontent.XContentBuilder;
3943
import org.opensearch.core.xcontent.XContentParser;
44+
import org.opensearch.index.mapper.MappedFieldType;
4045
import org.opensearch.index.query.MatchAllQueryBuilder;
4146
import org.opensearch.index.query.QueryBuilder;
4247
import org.opensearch.index.query.QueryRewriteContext;
48+
import org.opensearch.index.query.QueryShardContext;
4349
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
4450
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
4551
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
@@ -54,6 +60,7 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {
5460
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
5561
private static final float BOOST = 1.8f;
5662
private static final String QUERY_NAME = "queryName";
63+
private static final Float MAX_TOKEN_SCORE = 123f;
5764
private static final Supplier<Map<String, Float>> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f);
5865

5966
@Before
@@ -121,6 +128,32 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
121128
assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName());
122129
}
123130

131+
@SneakyThrows
132+
public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning() {
133+
/*
134+
{
135+
"VECTOR_FIELD": {
136+
"query_text": "string",
137+
"model_id": "string",
138+
"max_token_score": 123.0
139+
}
140+
}
141+
*/
142+
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
143+
.startObject()
144+
.startObject(FIELD_NAME)
145+
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
146+
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
147+
.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
148+
.endObject()
149+
.endObject();
150+
151+
XContentParser contentParser = createParser(xContentBuilder);
152+
contentParser.nextToken();
153+
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);
154+
assertWarnings("Deprecated field [max_token_score] used, this field is unused and will be removed entirely");
155+
}
156+
124157
@SneakyThrows
125158
public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
126159
/*
@@ -248,7 +281,8 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
248281
public void testToXContent() {
249282
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
250283
.modelId(MODEL_ID)
251-
.queryText(QUERY_TEXT);
284+
.queryText(QUERY_TEXT)
285+
.maxTokenScore(MAX_TOKEN_SCORE);
252286

253287
XContentBuilder builder = XContentFactory.jsonBuilder();
254288
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -273,6 +307,7 @@ public void testToXContent() {
273307

274308
assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
275309
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
310+
assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0);
276311
}
277312

278313
public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() {
@@ -285,6 +320,7 @@ public void testStreams() {
285320
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
286321
original.fieldName(FIELD_NAME);
287322
original.queryText(QUERY_TEXT);
323+
original.maxTokenScore(MAX_TOKEN_SCORE);
288324
original.modelId(MODEL_ID);
289325
original.boost(BOOST);
290326
original.queryName(QUERY_NAME);
@@ -306,11 +342,11 @@ public void testStreams() {
306342
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
307343
original.queryTokensSupplier(queryTokensSetOnce::get);
308344

309-
BytesStreamOutput streamOutput2 = new BytesStreamOutput();
310-
original.writeTo(streamOutput2);
345+
streamOutput = new BytesStreamOutput();
346+
original.writeTo(streamOutput);
311347

312348
filterStreamInput = new NamedWriteableAwareStreamInput(
313-
streamOutput2.bytes().streamInput(),
349+
streamOutput.bytes().streamInput(),
314350
new NamedWriteableRegistry(
315351
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
316352
)
@@ -327,6 +363,8 @@ public void testHashAndEquals() {
327363
String queryText2 = "query text 2";
328364
String modelId1 = "model-1";
329365
String modelId2 = "model-2";
366+
float maxTokenScore1 = 1.1f;
367+
float maxTokenScore2 = 2.2f;
330368
float boost1 = 1.8f;
331369
float boost2 = 3.8f;
332370
String queryName1 = "query-1";
@@ -337,60 +375,77 @@ public void testHashAndEquals() {
337375
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
338376
.queryText(queryText1)
339377
.modelId(modelId1)
378+
.maxTokenScore(maxTokenScore1)
340379
.boost(boost1)
341380
.queryName(queryName1);
342381

343382
// Identical to sparseEncodingQueryBuilder_baseline
344383
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1)
345384
.queryText(queryText1)
346385
.modelId(modelId1)
386+
.maxTokenScore(maxTokenScore1)
347387
.boost(boost1)
348388
.queryName(queryName1);
349389

350390
// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
351391
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
352392
.queryText(queryText1)
353-
.modelId(modelId1);
393+
.modelId(modelId1)
394+
.maxTokenScore(maxTokenScore1);
354395

355396
// Identical to sparseEncodingQueryBuilder_baseline except diff field name
356397
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2)
357398
.queryText(queryText1)
358399
.modelId(modelId1)
400+
.maxTokenScore(maxTokenScore1)
359401
.boost(boost1)
360402
.queryName(queryName1);
361403

362404
// Identical to sparseEncodingQueryBuilder_baseline except diff query text
363405
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
364406
.queryText(queryText2)
365407
.modelId(modelId1)
408+
.maxTokenScore(maxTokenScore1)
366409
.boost(boost1)
367410
.queryName(queryName1);
368411

369412
// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
370413
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
371414
.queryText(queryText1)
372415
.modelId(modelId2)
416+
.maxTokenScore(maxTokenScore1)
373417
.boost(boost1)
374418
.queryName(queryName1);
375419

376420
// Identical to sparseEncodingQueryBuilder_baseline except diff boost
377421
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1)
378422
.queryText(queryText1)
379423
.modelId(modelId1)
424+
.maxTokenScore(maxTokenScore1)
380425
.boost(boost2)
381426
.queryName(queryName1);
382427

383428
// Identical to sparseEncodingQueryBuilder_baseline except diff query name
384429
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
385430
.queryText(queryText1)
386431
.modelId(modelId1)
432+
.maxTokenScore(maxTokenScore1)
387433
.boost(boost1)
388434
.queryName(queryName2);
389435

436+
// Identical to sparseEncodingQueryBuilder_baseline except diff max token score
437+
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder().fieldName(fieldName1)
438+
.queryText(queryText1)
439+
.modelId(modelId1)
440+
.maxTokenScore(maxTokenScore2)
441+
.boost(boost1)
442+
.queryName(queryName1);
443+
390444
// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
391445
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
392446
.queryText(queryText1)
393447
.modelId(modelId1)
448+
.maxTokenScore(maxTokenScore1)
394449
.boost(boost1)
395450
.queryName(queryName1)
396451
.queryTokensSupplier(() -> queryTokens1);
@@ -399,6 +454,7 @@ public void testHashAndEquals() {
399454
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
400455
.queryText(queryText1)
401456
.modelId(modelId1)
457+
.maxTokenScore(maxTokenScore1)
402458
.boost(boost1)
403459
.queryName(queryName1)
404460
.queryTokensSupplier(() -> queryTokens2);
@@ -427,6 +483,9 @@ public void testHashAndEquals() {
427483
assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
428484
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());
429485

486+
assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore);
487+
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode());
488+
430489
assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
431490
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());
432491

@@ -486,4 +545,23 @@ private void setUpClusterService(Version version) {
486545
ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version);
487546
NeuralSearchClusterUtil.instance().initialize(clusterService);
488547
}
548+
549+
@SneakyThrows
550+
public void testDoToQuery_successfulDoToQuery() {
551+
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
552+
.maxTokenScore(MAX_TOKEN_SCORE)
553+
.queryText(QUERY_TEXT)
554+
.modelId(MODEL_ID)
555+
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
556+
QueryShardContext mockedQueryShardContext = mock(QueryShardContext.class);
557+
MappedFieldType mockedMappedFieldType = mock(MappedFieldType.class);
558+
doAnswer(invocation -> "rank_features").when(mockedMappedFieldType).typeName();
559+
doAnswer(invocation -> mockedMappedFieldType).when(mockedQueryShardContext).fieldMapper(any());
560+
561+
BooleanQuery.Builder targetQueryBuilder = new BooleanQuery.Builder();
562+
targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f), BooleanClause.Occur.SHOULD);
563+
targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f), BooleanClause.Occur.SHOULD);
564+
565+
assertEquals(sparseEncodingQueryBuilder.doToQuery(mockedQueryShardContext), targetQueryBuilder.build());
566+
}
489567
}

0 commit comments

Comments
 (0)