9
9
import lombok .AllArgsConstructor ;
10
10
import lombok .Getter ;
11
11
import lombok .extern .log4j .Log4j2 ;
12
- import org .apache .commons .lang .StringUtils ;
13
12
import org .apache .lucene .search .MatchNoDocsQuery ;
14
13
import org .apache .lucene .search .Query ;
15
14
import org .opensearch .common .ValidationException ;
24
23
import org .opensearch .index .query .QueryRewriteContext ;
25
24
import org .opensearch .index .query .QueryShardContext ;
26
25
import org .opensearch .knn .index .engine .KNNMethodConfigContext ;
26
+ import org .opensearch .knn .index .engine .KNNMethodContext ;
27
27
import org .opensearch .knn .index .engine .model .QueryContext ;
28
28
import org .opensearch .knn .index .engine .qframe .QuantizationConfig ;
29
29
import org .opensearch .knn .index .mapper .KNNMappingConfig ;
47
47
import java .util .Locale ;
48
48
import java .util .Map ;
49
49
import java .util .Objects ;
50
- import java .util .concurrent .atomic .AtomicReference ;
51
50
52
51
import static org .opensearch .knn .common .KNNConstants .EXPAND_NESTED ;
53
52
import static org .opensearch .knn .common .KNNConstants .MAX_DISTANCE ;
@@ -393,40 +392,12 @@ protected Query doToQuery(QueryShardContext context) {
393
392
}
394
393
KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType ) mappedFieldType ;
395
394
KNNMappingConfig knnMappingConfig = knnVectorFieldType .getKnnMappingConfig ();
396
- final AtomicReference <QueryConfigFromMapping > queryConfigFromMapping = new AtomicReference <>();
397
- int fieldDimension = knnMappingConfig .getDimension ();
398
- knnMappingConfig .getKnnMethodContext ()
399
- .ifPresentOrElse (
400
- knnMethodContext -> queryConfigFromMapping .set (
401
- new QueryConfigFromMapping (
402
- knnMethodContext .getKnnEngine (),
403
- knnMethodContext .getMethodComponentContext (),
404
- knnMethodContext .getSpaceType (),
405
- knnVectorFieldType .getVectorDataType ()
406
- )
407
- ),
408
- () -> knnMappingConfig .getModelId ().ifPresentOrElse (modelId -> {
409
- ModelMetadata modelMetadata = getModelMetadataForField (modelId );
410
- queryConfigFromMapping .set (
411
- new QueryConfigFromMapping (
412
- modelMetadata .getKnnEngine (),
413
- modelMetadata .getMethodComponentContext (),
414
- modelMetadata .getSpaceType (),
415
- modelMetadata .getVectorDataType ()
416
- )
417
- );
418
- },
419
- () -> {
420
- throw new IllegalArgumentException (
421
- String .format (Locale .ROOT , "Field '%s' is not built for ANN search." , this .fieldName )
422
- );
423
- }
424
- )
425
- );
426
- KNNEngine knnEngine = queryConfigFromMapping .get ().getKnnEngine ();
427
- MethodComponentContext methodComponentContext = queryConfigFromMapping .get ().getMethodComponentContext ();
428
- SpaceType spaceType = queryConfigFromMapping .get ().getSpaceType ();
429
- VectorDataType vectorDataType = queryConfigFromMapping .get ().getVectorDataType ();
395
+ QueryConfigFromMapping queryConfigFromMapping = getQueryConfig (knnMappingConfig , knnVectorFieldType );
396
+
397
+ KNNEngine knnEngine = queryConfigFromMapping .getKnnEngine ();
398
+ MethodComponentContext methodComponentContext = queryConfigFromMapping .getMethodComponentContext ();
399
+ SpaceType spaceType = queryConfigFromMapping .getSpaceType ();
400
+ VectorDataType vectorDataType = queryConfigFromMapping .getVectorDataType ();
430
401
RescoreContext processedRescoreContext = knnVectorFieldType .resolveRescoreContext (rescoreContext );
431
402
knnVectorFieldType .transformQueryVector (vector );
432
403
@@ -435,7 +406,7 @@ protected Query doToQuery(QueryShardContext context) {
435
406
436
407
// This could be null in the case of when a model did not have serialized methodComponent information
437
408
final String method = methodComponentContext != null ? methodComponentContext .getName () : null ;
438
- if (StringUtils . isNotBlank ( method )) {
409
+ if (method != null && ! method . isBlank ( )) {
439
410
final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine .getKNNLibrarySearchContext (method );
440
411
QueryContext queryContext = new QueryContext (vectorQueryType );
441
412
ValidationException validationException = validateParameters (
@@ -494,9 +465,13 @@ protected Query doToQuery(QueryShardContext context) {
494
465
}
495
466
496
467
int vectorLength = VectorDataType .BINARY == vectorDataType ? vector .length * Byte .SIZE : vector .length ;
497
- if (fieldDimension != vectorLength ) {
468
+ if (knnMappingConfig . getDimension () != vectorLength ) {
498
469
throw new IllegalArgumentException (
499
- String .format ("Query vector has invalid dimension: %d. Dimension should be: %d" , vectorLength , fieldDimension )
470
+ String .format (
471
+ "Query vector has invalid dimension: %d. Dimension should be: %d" ,
472
+ vectorLength ,
473
+ knnMappingConfig .getDimension ()
474
+ )
500
475
);
501
476
}
502
477
@@ -572,6 +547,31 @@ protected Query doToQuery(QueryShardContext context) {
572
547
throw new IllegalArgumentException (String .format (Locale .ROOT , "[%s] requires k or distance or score to be set" , NAME ));
573
548
}
574
549
550
+ private QueryConfigFromMapping getQueryConfig (final KNNMappingConfig knnMappingConfig , final KNNVectorFieldType knnVectorFieldType ) {
551
+
552
+ if (knnMappingConfig .getKnnMethodContext ().isPresent ()) {
553
+ KNNMethodContext knnMethodContext = knnMappingConfig .getKnnMethodContext ().get ();
554
+ return new QueryConfigFromMapping (
555
+ knnMethodContext .getKnnEngine (),
556
+ knnMethodContext .getMethodComponentContext (),
557
+ knnMethodContext .getSpaceType (),
558
+ knnVectorFieldType .getVectorDataType ()
559
+ );
560
+ }
561
+
562
+ if (knnMappingConfig .getModelId ().isPresent ()) {
563
+ ModelMetadata modelMetadata = getModelMetadataForField (knnMappingConfig .getModelId ().get ());
564
+ return new QueryConfigFromMapping (
565
+ modelMetadata .getKnnEngine (),
566
+ modelMetadata .getMethodComponentContext (),
567
+ modelMetadata .getSpaceType (),
568
+ modelMetadata .getVectorDataType ()
569
+ );
570
+ }
571
+
572
+ throw new IllegalArgumentException (String .format (Locale .ROOT , "Field '%s' is not built for ANN search." , this .fieldName ));
573
+ }
574
+
575
575
private ModelMetadata getModelMetadataForField (String modelId ) {
576
576
ModelMetadata modelMetadata = modelDao .getMetadata (modelId );
577
577
if (!ModelUtil .isModelCreated (modelMetadata )) {
0 commit comments