Skip to content

Commit 4b5c3ec

Browse files
Minor performance improvments in KNNQueryBuilder (#2528) (#2531)
Signed-off-by: Tejas Shah <shatejas@amazon.com> (cherry picked from commit 45ecb5b) Co-authored-by: Tejas Shah <shatejas@amazon.com>
1 parent 3977536 commit 4b5c3ec

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

+39-39
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import lombok.AllArgsConstructor;
1010
import lombok.Getter;
1111
import lombok.extern.log4j.Log4j2;
12-
import org.apache.commons.lang.StringUtils;
1312
import org.apache.lucene.search.MatchNoDocsQuery;
1413
import org.apache.lucene.search.Query;
1514
import org.opensearch.common.ValidationException;
@@ -24,6 +23,7 @@
2423
import org.opensearch.index.query.QueryRewriteContext;
2524
import org.opensearch.index.query.QueryShardContext;
2625
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
26+
import org.opensearch.knn.index.engine.KNNMethodContext;
2727
import org.opensearch.knn.index.engine.model.QueryContext;
2828
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
2929
import org.opensearch.knn.index.mapper.KNNMappingConfig;
@@ -47,7 +47,6 @@
4747
import java.util.Locale;
4848
import java.util.Map;
4949
import java.util.Objects;
50-
import java.util.concurrent.atomic.AtomicReference;
5150

5251
import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
5352
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
@@ -393,40 +392,12 @@ protected Query doToQuery(QueryShardContext context) {
393392
}
394393
KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType;
395394
KNNMappingConfig knnMappingConfig = knnVectorFieldType.getKnnMappingConfig();
396-
final AtomicReference<QueryConfigFromMapping> queryConfigFromMapping = new AtomicReference<>();
397-
int fieldDimension = knnMappingConfig.getDimension();
398-
knnMappingConfig.getKnnMethodContext()
399-
.ifPresentOrElse(
400-
knnMethodContext -> queryConfigFromMapping.set(
401-
new QueryConfigFromMapping(
402-
knnMethodContext.getKnnEngine(),
403-
knnMethodContext.getMethodComponentContext(),
404-
knnMethodContext.getSpaceType(),
405-
knnVectorFieldType.getVectorDataType()
406-
)
407-
),
408-
() -> knnMappingConfig.getModelId().ifPresentOrElse(modelId -> {
409-
ModelMetadata modelMetadata = getModelMetadataForField(modelId);
410-
queryConfigFromMapping.set(
411-
new QueryConfigFromMapping(
412-
modelMetadata.getKnnEngine(),
413-
modelMetadata.getMethodComponentContext(),
414-
modelMetadata.getSpaceType(),
415-
modelMetadata.getVectorDataType()
416-
)
417-
);
418-
},
419-
() -> {
420-
throw new IllegalArgumentException(
421-
String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName)
422-
);
423-
}
424-
)
425-
);
426-
KNNEngine knnEngine = queryConfigFromMapping.get().getKnnEngine();
427-
MethodComponentContext methodComponentContext = queryConfigFromMapping.get().getMethodComponentContext();
428-
SpaceType spaceType = queryConfigFromMapping.get().getSpaceType();
429-
VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType();
395+
QueryConfigFromMapping queryConfigFromMapping = getQueryConfig(knnMappingConfig, knnVectorFieldType);
396+
397+
KNNEngine knnEngine = queryConfigFromMapping.getKnnEngine();
398+
MethodComponentContext methodComponentContext = queryConfigFromMapping.getMethodComponentContext();
399+
SpaceType spaceType = queryConfigFromMapping.getSpaceType();
400+
VectorDataType vectorDataType = queryConfigFromMapping.getVectorDataType();
430401
RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext);
431402
knnVectorFieldType.transformQueryVector(vector);
432403

@@ -435,7 +406,7 @@ protected Query doToQuery(QueryShardContext context) {
435406

436407
// This could be null in the case of when a model did not have serialized methodComponent information
437408
final String method = methodComponentContext != null ? methodComponentContext.getName() : null;
438-
if (StringUtils.isNotBlank(method)) {
409+
if (method != null && !method.isBlank()) {
439410
final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method);
440411
QueryContext queryContext = new QueryContext(vectorQueryType);
441412
ValidationException validationException = validateParameters(
@@ -494,9 +465,13 @@ protected Query doToQuery(QueryShardContext context) {
494465
}
495466

496467
int vectorLength = VectorDataType.BINARY == vectorDataType ? vector.length * Byte.SIZE : vector.length;
497-
if (fieldDimension != vectorLength) {
468+
if (knnMappingConfig.getDimension() != vectorLength) {
498469
throw new IllegalArgumentException(
499-
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vectorLength, fieldDimension)
470+
String.format(
471+
"Query vector has invalid dimension: %d. Dimension should be: %d",
472+
vectorLength,
473+
knnMappingConfig.getDimension()
474+
)
500475
);
501476
}
502477

@@ -572,6 +547,31 @@ protected Query doToQuery(QueryShardContext context) {
572547
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME));
573548
}
574549

550+
private QueryConfigFromMapping getQueryConfig(final KNNMappingConfig knnMappingConfig, final KNNVectorFieldType knnVectorFieldType) {
551+
552+
if (knnMappingConfig.getKnnMethodContext().isPresent()) {
553+
KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext().get();
554+
return new QueryConfigFromMapping(
555+
knnMethodContext.getKnnEngine(),
556+
knnMethodContext.getMethodComponentContext(),
557+
knnMethodContext.getSpaceType(),
558+
knnVectorFieldType.getVectorDataType()
559+
);
560+
}
561+
562+
if (knnMappingConfig.getModelId().isPresent()) {
563+
ModelMetadata modelMetadata = getModelMetadataForField(knnMappingConfig.getModelId().get());
564+
return new QueryConfigFromMapping(
565+
modelMetadata.getKnnEngine(),
566+
modelMetadata.getMethodComponentContext(),
567+
modelMetadata.getSpaceType(),
568+
modelMetadata.getVectorDataType()
569+
);
570+
}
571+
572+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName));
573+
}
574+
575575
private ModelMetadata getModelMetadataForField(String modelId) {
576576
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
577577
if (!ModelUtil.isModelCreated(modelMetadata)) {

src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
121121
requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH);
122122
}
123123
int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch);
124-
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
124+
log.debug("Creating Lucene k-NN query for index: {}, field:{}, k: {}", indexName, fieldName, k);
125125
switch (vectorDataType) {
126126
case BYTE:
127127
case BINARY:

0 commit comments

Comments
 (0)