4
4
*/
5
5
package org .opensearch .neuralsearch .processor ;
6
6
7
+ import static org .mockito .Mockito .mock ;
7
8
import static org .mockito .Mockito .spy ;
9
+ import static org .mockito .Mockito .when ;
8
10
import static org .opensearch .neuralsearch .search .util .HybridSearchResultFormatUtil .createDelimiterElementForHybridSearchResults ;
9
11
import static org .opensearch .neuralsearch .search .util .HybridSearchResultFormatUtil .createStartStopElementForHybridSearchResults ;
10
12
29
31
import org .opensearch .search .SearchHits ;
30
32
import org .opensearch .search .SearchShardTarget ;
31
33
import org .opensearch .search .fetch .FetchSearchResult ;
34
+ import org .opensearch .search .internal .ShardSearchRequest ;
32
35
import org .opensearch .search .query .QuerySearchResult ;
33
36
import org .opensearch .test .OpenSearchTestCase ;
34
37
@@ -156,6 +159,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
156
159
);
157
160
querySearchResult .setSearchShardTarget (searchShardTarget );
158
161
querySearchResult .setShardIndex (shardId );
162
+ ShardSearchRequest shardSearchRequest = mock (ShardSearchRequest .class );
163
+ when (shardSearchRequest .requestCache ()).thenReturn (Boolean .TRUE );
164
+ querySearchResult .setShardSearchRequest (shardSearchRequest );
159
165
querySearchResults .add (querySearchResult );
160
166
SearchHit [] searchHitArray = new SearchHit [] {
161
167
new SearchHit (0 , "10" , Map .of (), Map .of ()),
@@ -213,6 +219,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
213
219
);
214
220
querySearchResult .setSearchShardTarget (searchShardTarget );
215
221
querySearchResult .setShardIndex (shardId );
222
+ ShardSearchRequest shardSearchRequest = mock (ShardSearchRequest .class );
223
+ when (shardSearchRequest .requestCache ()).thenReturn (Boolean .TRUE );
224
+ querySearchResult .setShardSearchRequest (shardSearchRequest );
216
225
querySearchResults .add (querySearchResult );
217
226
SearchHit [] searchHitArray = new SearchHit [] {
218
227
new SearchHit (-1 , "10" , Map .of (), Map .of ()),
@@ -236,7 +245,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
236
245
TestUtils .assertFetchResultScores (fetchSearchResult , 4 );
237
246
}
238
247
239
- public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail () {
248
+ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchResults_thenFail () {
240
249
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy (
241
250
new NormalizationProcessorWorkflow (new ScoreNormalizer (), new ScoreCombiner ())
242
251
);
@@ -270,15 +279,11 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
270
279
);
271
280
querySearchResult .setSearchShardTarget (searchShardTarget );
272
281
querySearchResult .setShardIndex (shardId );
282
+ ShardSearchRequest shardSearchRequest = mock (ShardSearchRequest .class );
283
+ when (shardSearchRequest .requestCache ()).thenReturn (Boolean .FALSE );
284
+ querySearchResult .setShardSearchRequest (shardSearchRequest );
273
285
querySearchResults .add (querySearchResult );
274
- SearchHit [] searchHitArray = new SearchHit [] {
275
- new SearchHit (-1 , "10" , Map .of (), Map .of ()),
276
- new SearchHit (-1 , "10" , Map .of (), Map .of ()),
277
- new SearchHit (-1 , "10" , Map .of (), Map .of ()),
278
- new SearchHit (-1 , "1" , Map .of (), Map .of ()),
279
- new SearchHit (-1 , "2" , Map .of (), Map .of ()),
280
- new SearchHit (-1 , "3" , Map .of (), Map .of ()) };
281
- SearchHits searchHits = new SearchHits (searchHitArray , new TotalHits (7 , TotalHits .Relation .EQUAL_TO ), 10 );
286
+ SearchHits searchHits = getSearchHits ();
282
287
fetchSearchResult .hits (searchHits );
283
288
284
289
expectThrows (
@@ -291,4 +296,68 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
291
296
)
292
297
);
293
298
}
299
+
300
+ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful () {
301
+ NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy (
302
+ new NormalizationProcessorWorkflow (new ScoreNormalizer (), new ScoreCombiner ())
303
+ );
304
+
305
+ List <QuerySearchResult > querySearchResults = new ArrayList <>();
306
+ FetchSearchResult fetchSearchResult = new FetchSearchResult ();
307
+ int shardId = 0 ;
308
+ SearchShardTarget searchShardTarget = new SearchShardTarget (
309
+ "node" ,
310
+ new ShardId ("index" , "uuid" , shardId ),
311
+ null ,
312
+ OriginalIndices .NONE
313
+ );
314
+ QuerySearchResult querySearchResult = new QuerySearchResult ();
315
+ querySearchResult .topDocs (
316
+ new TopDocsAndMaxScore (
317
+ new TopDocs (
318
+ new TotalHits (4 , TotalHits .Relation .EQUAL_TO ),
319
+ new ScoreDoc [] {
320
+ createStartStopElementForHybridSearchResults (0 ),
321
+ createDelimiterElementForHybridSearchResults (0 ),
322
+ new ScoreDoc (0 , 0.5f ),
323
+ new ScoreDoc (2 , 0.3f ),
324
+ new ScoreDoc (4 , 0.25f ),
325
+ new ScoreDoc (10 , 0.2f ),
326
+ createStartStopElementForHybridSearchResults (0 ) }
327
+ ),
328
+ 0.5f
329
+ ),
330
+ new DocValueFormat [0 ]
331
+ );
332
+ querySearchResult .setSearchShardTarget (searchShardTarget );
333
+ querySearchResult .setShardIndex (shardId );
334
+ ShardSearchRequest shardSearchRequest = mock (ShardSearchRequest .class );
335
+ when (shardSearchRequest .requestCache ()).thenReturn (Boolean .TRUE );
336
+ querySearchResult .setShardSearchRequest (shardSearchRequest );
337
+ querySearchResults .add (querySearchResult );
338
+ SearchHits searchHits = getSearchHits ();
339
+ fetchSearchResult .hits (searchHits );
340
+
341
+ normalizationProcessorWorkflow .execute (
342
+ querySearchResults ,
343
+ Optional .of (fetchSearchResult ),
344
+ ScoreNormalizationFactory .DEFAULT_METHOD ,
345
+ ScoreCombinationFactory .DEFAULT_METHOD
346
+ );
347
+
348
+ TestUtils .assertQueryResultScores (querySearchResults );
349
+ TestUtils .assertFetchResultScores (fetchSearchResult , 4 );
350
+ }
351
+
352
+ private static SearchHits getSearchHits () {
353
+ SearchHit [] searchHitArray = new SearchHit [] {
354
+ new SearchHit (-1 , "10" , Map .of (), Map .of ()),
355
+ new SearchHit (-1 , "10" , Map .of (), Map .of ()),
356
+ new SearchHit (-1 , "10" , Map .of (), Map .of ()),
357
+ new SearchHit (-1 , "1" , Map .of (), Map .of ()),
358
+ new SearchHit (-1 , "2" , Map .of (), Map .of ()),
359
+ new SearchHit (-1 , "3" , Map .of (), Map .of ()) };
360
+ SearchHits searchHits = new SearchHits (searchHitArray , new TotalHits (7 , TotalHits .Relation .EQUAL_TO ), 10 );
361
+ return searchHits ;
362
+ }
294
363
}
0 commit comments