10
10
import static org .opensearch .index .query .AbstractQueryBuilder .BOOST_FIELD ;
11
11
import static org .opensearch .index .query .AbstractQueryBuilder .NAME_FIELD ;
12
12
import static org .opensearch .neuralsearch .TestUtils .xContentBuilderToMap ;
13
+ import static org .opensearch .neuralsearch .query .NeuralSparseQueryBuilder .MAX_TOKEN_SCORE_FIELD ;
13
14
import static org .opensearch .neuralsearch .query .NeuralSparseQueryBuilder .MODEL_ID_FIELD ;
14
15
import static org .opensearch .neuralsearch .query .NeuralSparseQueryBuilder .NAME ;
15
16
import static org .opensearch .neuralsearch .query .NeuralSparseQueryBuilder .QUERY_TEXT_FIELD ;
22
23
import java .util .function .BiConsumer ;
23
24
import java .util .function .Supplier ;
24
25
26
+ import org .apache .lucene .document .FeatureField ;
27
+ import org .apache .lucene .search .BooleanClause ;
28
+ import org .apache .lucene .search .BooleanQuery ;
25
29
import org .junit .Before ;
26
30
import org .opensearch .Version ;
27
31
import org .opensearch .client .Client ;
37
41
import org .opensearch .core .xcontent .ToXContent ;
38
42
import org .opensearch .core .xcontent .XContentBuilder ;
39
43
import org .opensearch .core .xcontent .XContentParser ;
44
+ import org .opensearch .index .mapper .MappedFieldType ;
40
45
import org .opensearch .index .query .MatchAllQueryBuilder ;
41
46
import org .opensearch .index .query .QueryBuilder ;
42
47
import org .opensearch .index .query .QueryRewriteContext ;
48
+ import org .opensearch .index .query .QueryShardContext ;
43
49
import org .opensearch .neuralsearch .ml .MLCommonsClientAccessor ;
44
50
import org .opensearch .neuralsearch .util .NeuralSearchClusterTestUtils ;
45
51
import org .opensearch .neuralsearch .util .NeuralSearchClusterUtil ;
@@ -54,6 +60,7 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {
54
60
private static final String MODEL_ID = "mfgfgdsfgfdgsde" ;
55
61
private static final float BOOST = 1.8f ;
56
62
private static final String QUERY_NAME = "queryName" ;
63
+ private static final Float MAX_TOKEN_SCORE = 123f ;
57
64
private static final Supplier <Map <String , Float >> QUERY_TOKENS_SUPPLIER = () -> Map .of ("hello" , 1.f , "world" , 2.f );
58
65
59
66
@ Before
@@ -121,6 +128,32 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
121
128
assertEquals (QUERY_NAME , sparseEncodingQueryBuilder .queryName ());
122
129
}
123
130
131
+ @ SneakyThrows
132
+ public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning () {
133
+ /*
134
+ {
135
+ "VECTOR_FIELD": {
136
+ "query_text": "string",
137
+ "model_id": "string",
138
+ "max_token_score": 123.0
139
+ }
140
+ }
141
+ */
142
+ XContentBuilder xContentBuilder = XContentFactory .jsonBuilder ()
143
+ .startObject ()
144
+ .startObject (FIELD_NAME )
145
+ .field (QUERY_TEXT_FIELD .getPreferredName (), QUERY_TEXT )
146
+ .field (MODEL_ID_FIELD .getPreferredName (), MODEL_ID )
147
+ .field (MAX_TOKEN_SCORE_FIELD .getPreferredName (), MAX_TOKEN_SCORE )
148
+ .endObject ()
149
+ .endObject ();
150
+
151
+ XContentParser contentParser = createParser (xContentBuilder );
152
+ contentParser .nextToken ();
153
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder .fromXContent (contentParser );
154
+ assertWarnings ("Deprecated field [max_token_score] used, this field is unused and will be removed entirely" );
155
+ }
156
+
124
157
@ SneakyThrows
125
158
public void testFromXContent_whenBuildWithMultipleRootFields_thenFail () {
126
159
/*
@@ -248,7 +281,8 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
248
281
public void testToXContent () {
249
282
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder ().fieldName (FIELD_NAME )
250
283
.modelId (MODEL_ID )
251
- .queryText (QUERY_TEXT );
284
+ .queryText (QUERY_TEXT )
285
+ .maxTokenScore (MAX_TOKEN_SCORE );
252
286
253
287
XContentBuilder builder = XContentFactory .jsonBuilder ();
254
288
builder = sparseEncodingQueryBuilder .toXContent (builder , ToXContent .EMPTY_PARAMS );
@@ -273,6 +307,7 @@ public void testToXContent() {
273
307
274
308
assertEquals (MODEL_ID , secondInnerMap .get (MODEL_ID_FIELD .getPreferredName ()));
275
309
assertEquals (QUERY_TEXT , secondInnerMap .get (QUERY_TEXT_FIELD .getPreferredName ()));
310
+ assertEquals (MAX_TOKEN_SCORE , (Double ) secondInnerMap .get (MAX_TOKEN_SCORE_FIELD .getPreferredName ()), 0.0 );
276
311
}
277
312
278
313
public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess () {
@@ -285,6 +320,7 @@ public void testStreams() {
285
320
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder ();
286
321
original .fieldName (FIELD_NAME );
287
322
original .queryText (QUERY_TEXT );
323
+ original .maxTokenScore (MAX_TOKEN_SCORE );
288
324
original .modelId (MODEL_ID );
289
325
original .boost (BOOST );
290
326
original .queryName (QUERY_NAME );
@@ -306,11 +342,11 @@ public void testStreams() {
306
342
queryTokensSetOnce .set (Map .of ("hello" , 1.0f , "world" , 2.0f ));
307
343
original .queryTokensSupplier (queryTokensSetOnce ::get );
308
344
309
- BytesStreamOutput streamOutput2 = new BytesStreamOutput ();
310
- original .writeTo (streamOutput2 );
345
+ streamOutput = new BytesStreamOutput ();
346
+ original .writeTo (streamOutput );
311
347
312
348
filterStreamInput = new NamedWriteableAwareStreamInput (
313
- streamOutput2 .bytes ().streamInput (),
349
+ streamOutput .bytes ().streamInput (),
314
350
new NamedWriteableRegistry (
315
351
List .of (new NamedWriteableRegistry .Entry (QueryBuilder .class , MatchAllQueryBuilder .NAME , MatchAllQueryBuilder ::new ))
316
352
)
@@ -327,6 +363,8 @@ public void testHashAndEquals() {
327
363
String queryText2 = "query text 2" ;
328
364
String modelId1 = "model-1" ;
329
365
String modelId2 = "model-2" ;
366
+ float maxTokenScore1 = 1.1f ;
367
+ float maxTokenScore2 = 2.2f ;
330
368
float boost1 = 1.8f ;
331
369
float boost2 = 3.8f ;
332
370
String queryName1 = "query-1" ;
@@ -337,60 +375,77 @@ public void testHashAndEquals() {
337
375
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
338
376
.queryText (queryText1 )
339
377
.modelId (modelId1 )
378
+ .maxTokenScore (maxTokenScore1 )
340
379
.boost (boost1 )
341
380
.queryName (queryName1 );
342
381
343
382
// Identical to sparseEncodingQueryBuilder_baseline
344
383
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
345
384
.queryText (queryText1 )
346
385
.modelId (modelId1 )
386
+ .maxTokenScore (maxTokenScore1 )
347
387
.boost (boost1 )
348
388
.queryName (queryName1 );
349
389
350
390
// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
351
391
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
352
392
.queryText (queryText1 )
353
- .modelId (modelId1 );
393
+ .modelId (modelId1 )
394
+ .maxTokenScore (maxTokenScore1 );
354
395
355
396
// Identical to sparseEncodingQueryBuilder_baseline except diff field name
356
397
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder ().fieldName (fieldName2 )
357
398
.queryText (queryText1 )
358
399
.modelId (modelId1 )
400
+ .maxTokenScore (maxTokenScore1 )
359
401
.boost (boost1 )
360
402
.queryName (queryName1 );
361
403
362
404
// Identical to sparseEncodingQueryBuilder_baseline except diff query text
363
405
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
364
406
.queryText (queryText2 )
365
407
.modelId (modelId1 )
408
+ .maxTokenScore (maxTokenScore1 )
366
409
.boost (boost1 )
367
410
.queryName (queryName1 );
368
411
369
412
// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
370
413
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
371
414
.queryText (queryText1 )
372
415
.modelId (modelId2 )
416
+ .maxTokenScore (maxTokenScore1 )
373
417
.boost (boost1 )
374
418
.queryName (queryName1 );
375
419
376
420
// Identical to sparseEncodingQueryBuilder_baseline except diff boost
377
421
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
378
422
.queryText (queryText1 )
379
423
.modelId (modelId1 )
424
+ .maxTokenScore (maxTokenScore1 )
380
425
.boost (boost2 )
381
426
.queryName (queryName1 );
382
427
383
428
// Identical to sparseEncodingQueryBuilder_baseline except diff query name
384
429
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
385
430
.queryText (queryText1 )
386
431
.modelId (modelId1 )
432
+ .maxTokenScore (maxTokenScore1 )
387
433
.boost (boost1 )
388
434
.queryName (queryName2 );
389
435
436
+ // Identical to sparseEncodingQueryBuilder_baseline except diff max token score
437
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
438
+ .queryText (queryText1 )
439
+ .modelId (modelId1 )
440
+ .maxTokenScore (maxTokenScore2 )
441
+ .boost (boost1 )
442
+ .queryName (queryName1 );
443
+
390
444
// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
391
445
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
392
446
.queryText (queryText1 )
393
447
.modelId (modelId1 )
448
+ .maxTokenScore (maxTokenScore1 )
394
449
.boost (boost1 )
395
450
.queryName (queryName1 )
396
451
.queryTokensSupplier (() -> queryTokens1 );
@@ -399,6 +454,7 @@ public void testHashAndEquals() {
399
454
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder ().fieldName (fieldName1 )
400
455
.queryText (queryText1 )
401
456
.modelId (modelId1 )
457
+ .maxTokenScore (maxTokenScore1 )
402
458
.boost (boost1 )
403
459
.queryName (queryName1 )
404
460
.queryTokensSupplier (() -> queryTokens2 );
@@ -427,6 +483,9 @@ public void testHashAndEquals() {
427
483
assertNotEquals (sparseEncodingQueryBuilder_baseline , sparseEncodingQueryBuilder_diffQueryName );
428
484
assertNotEquals (sparseEncodingQueryBuilder_baseline .hashCode (), sparseEncodingQueryBuilder_diffQueryName .hashCode ());
429
485
486
+ assertNotEquals (sparseEncodingQueryBuilder_baseline , sparseEncodingQueryBuilder_diffMaxTokenScore );
487
+ assertNotEquals (sparseEncodingQueryBuilder_baseline .hashCode (), sparseEncodingQueryBuilder_diffMaxTokenScore .hashCode ());
488
+
430
489
assertNotEquals (sparseEncodingQueryBuilder_baseline , sparseEncodingQueryBuilder_nonNullQueryTokens );
431
490
assertNotEquals (sparseEncodingQueryBuilder_baseline .hashCode (), sparseEncodingQueryBuilder_nonNullQueryTokens .hashCode ());
432
491
@@ -486,4 +545,23 @@ private void setUpClusterService(Version version) {
486
545
ClusterService clusterService = NeuralSearchClusterTestUtils .mockClusterService (version );
487
546
NeuralSearchClusterUtil .instance ().initialize (clusterService );
488
547
}
548
+
549
+ @ SneakyThrows
550
+ public void testDoToQuery_successfulDoToQuery () {
551
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder ().fieldName (FIELD_NAME )
552
+ .maxTokenScore (MAX_TOKEN_SCORE )
553
+ .queryText (QUERY_TEXT )
554
+ .modelId (MODEL_ID )
555
+ .queryTokensSupplier (QUERY_TOKENS_SUPPLIER );
556
+ QueryShardContext mockedQueryShardContext = mock (QueryShardContext .class );
557
+ MappedFieldType mockedMappedFieldType = mock (MappedFieldType .class );
558
+ doAnswer (invocation -> "rank_features" ).when (mockedMappedFieldType ).typeName ();
559
+ doAnswer (invocation -> mockedMappedFieldType ).when (mockedQueryShardContext ).fieldMapper (any ());
560
+
561
+ BooleanQuery .Builder targetQueryBuilder = new BooleanQuery .Builder ();
562
+ targetQueryBuilder .add (FeatureField .newLinearQuery (FIELD_NAME , "hello" , 1.f ), BooleanClause .Occur .SHOULD );
563
+ targetQueryBuilder .add (FeatureField .newLinearQuery (FIELD_NAME , "world" , 2.f ), BooleanClause .Occur .SHOULD );
564
+
565
+ assertEquals (sparseEncodingQueryBuilder .doToQuery (mockedQueryShardContext ), targetQueryBuilder .build ());
566
+ }
489
567
}
0 commit comments