Skip to content

Commit 9a52b2b

Browse files
authored
Add stats for radial search (#1684)
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent c315862 commit 9a52b2b

File tree

8 files changed

+230
-24
lines changed

8 files changed

+230
-24
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1616
### Features
1717
### Enhancements
1818
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
19+
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
1920
### Bug Fixes
2021
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
2122
### Infrastructure
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index;
7+
8+
import lombok.Getter;
9+
import org.opensearch.knn.common.KNNConstants;
10+
import org.opensearch.knn.plugin.stats.KNNCounter;
11+
12+
@Getter
13+
public enum VectorQueryType {
14+
K(KNNConstants.K) {
15+
@Override
16+
public KNNCounter getQueryStatCounter() {
17+
return KNNCounter.KNN_QUERY_REQUESTS;
18+
}
19+
20+
@Override
21+
public KNNCounter getQueryWithFilterStatCounter() {
22+
return KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS;
23+
}
24+
},
25+
MIN_SCORE(KNNConstants.MIN_SCORE) {
26+
@Override
27+
public KNNCounter getQueryStatCounter() {
28+
return KNNCounter.MIN_SCORE_QUERY_REQUESTS;
29+
}
30+
31+
@Override
32+
public KNNCounter getQueryWithFilterStatCounter() {
33+
return KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS;
34+
}
35+
},
36+
MAX_DISTANCE(KNNConstants.MAX_DISTANCE) {
37+
@Override
38+
public KNNCounter getQueryStatCounter() {
39+
return KNNCounter.MAX_DISTANCE_QUERY_REQUESTS;
40+
}
41+
42+
@Override
43+
public KNNCounter getQueryWithFilterStatCounter() {
44+
return KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS;
45+
}
46+
};
47+
48+
private final String queryTypeName;
49+
50+
VectorQueryType(String queryTypeName) {
51+
this.queryTypeName = queryTypeName;
52+
}
53+
54+
public abstract KNNCounter getQueryStatCounter();
55+
56+
public abstract KNNCounter getQueryWithFilterStatCounter();
57+
}

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

+13-5
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
import org.opensearch.knn.index.KNNMethodContext;
2121
import org.opensearch.knn.index.SpaceType;
2222
import org.opensearch.knn.index.VectorDataType;
23+
import org.opensearch.knn.index.VectorQueryType;
2324
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
2425
import org.opensearch.knn.index.util.KNNEngine;
2526
import org.opensearch.knn.indices.ModelDao;
2627
import org.opensearch.knn.indices.ModelMetadata;
2728
import org.opensearch.knn.indices.ModelUtil;
28-
import org.opensearch.knn.plugin.stats.KNNCounter;
2929
import org.apache.lucene.search.Query;
3030
import org.opensearch.core.ParseField;
3131
import org.opensearch.core.common.ParsingException;
@@ -242,7 +242,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
242242
String currentFieldName = null;
243243
boolean ignoreUnmapped = false;
244244
XContentParser.Token token;
245-
KNNCounter.KNN_QUERY_REQUESTS.increment();
246245
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
247246
if (token == XContentParser.Token.FIELD_NAME) {
248247
currentFieldName = parser.currentName();
@@ -279,7 +278,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
279278
String tokenName = parser.currentName();
280279
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
281280
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
282-
KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment();
283281
filter = parseInnerQueryBuilder(parser);
284282
} else {
285283
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
@@ -298,7 +296,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
298296
}
299297
}
300298

301-
validateSingleQueryType(k, maxDistance, minScore);
299+
VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore);
300+
vectorQueryType.getQueryStatCounter().increment();
301+
if (filter != null) {
302+
vectorQueryType.getQueryWithFilterStatCounter().increment();
303+
}
302304

303305
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
304306
.ignoreUnmapped(ignoreUnmapped)
@@ -549,21 +551,27 @@ public String getWriteableName() {
549551
return NAME;
550552
}
551553

552-
private static void validateSingleQueryType(Integer k, Float distance, Float score) {
554+
private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) {
553555
int countSetFields = 0;
556+
VectorQueryType vectorQueryType = null;
554557

555558
if (k != null && k != 0) {
556559
countSetFields++;
560+
vectorQueryType = VectorQueryType.K;
557561
}
558562
if (distance != null) {
559563
countSetFields++;
564+
vectorQueryType = VectorQueryType.MAX_DISTANCE;
560565
}
561566
if (score != null) {
562567
countSetFields++;
568+
vectorQueryType = VectorQueryType.MIN_SCORE;
563569
}
564570

565571
if (countSetFields != 1) {
566572
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
567573
}
574+
575+
return vectorQueryType;
568576
}
569577
}

src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ public enum KNNCounter {
2222
SCRIPT_QUERY_ERRORS("script_query_errors"),
2323
TRAINING_REQUESTS("training_requests"),
2424
TRAINING_ERRORS("training_errors"),
25-
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests");
25+
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"),
26+
MIN_SCORE_QUERY_REQUESTS("min_score_query_requests"),
27+
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS("min_score_query_with_filter_requests"),
28+
MAX_DISTANCE_QUERY_REQUESTS("max_distance_query_requests"),
29+
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS("max_distance_query_with_filter_requests");
2630

2731
private String name;
2832
private AtomicLong count;

src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java

+20
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,32 @@ private Map<String, KNNStat<?>> buildStatsMap() {
9090
}
9191

9292
private void addQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
93+
// KNN Query Stats
9394
builder.put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS)))
9495
.put(
9596
StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(),
9697
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS))
9798
);
9899

100+
// Min Score Query Stats
101+
builder.put(
102+
StatNames.MIN_SCORE_QUERY_REQUESTS.getName(),
103+
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_REQUESTS))
104+
)
105+
.put(
106+
StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName(),
107+
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS))
108+
);
109+
110+
// Max Distance Query Stats
111+
builder.put(
112+
StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName(),
113+
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS))
114+
)
115+
.put(
116+
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName(),
117+
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS))
118+
);
99119
}
100120

101121
private void addNativeMemoryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {

src/main/java/org/opensearch/knn/plugin/stats/StatNames.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ public enum StatNames {
4444
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()),
4545
GRAPH_STATS("graph_stats"),
4646
REFRESH("refresh"),
47-
MERGE("merge");
47+
MERGE("merge"),
48+
MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()),
49+
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()),
50+
MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()),
51+
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName());
4852

4953
private String name;
5054

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*/
11+
12+
package org.opensearch.knn.index;
13+
14+
import org.opensearch.knn.KNNTestCase;
15+
import org.opensearch.knn.plugin.stats.KNNCounter;
16+
17+
public class VectorQueryTypeTests extends KNNTestCase {
18+
19+
public void testGetQueryStatCounter() {
20+
assertEquals(KNNCounter.KNN_QUERY_REQUESTS, VectorQueryType.K.getQueryStatCounter());
21+
assertEquals(KNNCounter.MIN_SCORE_QUERY_REQUESTS, VectorQueryType.MIN_SCORE.getQueryStatCounter());
22+
assertEquals(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS, VectorQueryType.MAX_DISTANCE.getQueryStatCounter());
23+
}
24+
25+
public void testGetQueryWithFilterStatCounter() {
26+
assertEquals(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.K.getQueryWithFilterStatCounter());
27+
assertEquals(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.MIN_SCORE.getQueryWithFilterStatCounter());
28+
assertEquals(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.MAX_DISTANCE.getQueryWithFilterStatCounter());
29+
}
30+
}

src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java

+99-17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.plugin.action;
77

8+
import lombok.SneakyThrows;
89
import org.apache.hc.core5.http.io.entity.EntityUtils;
910
import org.apache.logging.log4j.LogManager;
1011
import org.apache.logging.log4j.Logger;
@@ -30,27 +31,12 @@
3031
import org.opensearch.core.rest.RestStatus;
3132

3233
import java.io.IOException;
33-
import java.util.Arrays;
34-
import java.util.Collections;
35-
import java.util.HashMap;
36-
import java.util.List;
37-
import java.util.Map;
34+
import java.util.*;
3835

3936
import static org.opensearch.knn.TestUtils.KNN_VECTOR;
4037
import static org.opensearch.knn.TestUtils.PROPERTIES;
4138
import static org.opensearch.knn.TestUtils.VECTOR_TYPE;
42-
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
43-
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
44-
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
45-
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
46-
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
47-
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
48-
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
49-
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
50-
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
51-
import static org.opensearch.knn.common.KNNConstants.NAME;
52-
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
53-
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
39+
import static org.opensearch.knn.common.KNNConstants.*;
5440

5541
/**
5642
* Integration tests to check the correctness of RestKNNStatsHandler
@@ -432,6 +418,95 @@ public void testFieldsByEngineModelTraining() throws Exception {
432418
assertTrue(faissField);
433419
}
434420

421+
public void testRadialSearchStats_thenSucceed() throws Exception {
422+
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, LUCENE_NAME));
423+
Float[] vector = { 6.0f, 6.0f };
424+
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);
425+
426+
// First search: radial search by min score
427+
XContentBuilder queryBuilderMinScore = XContentFactory.jsonBuilder().startObject().startObject("query");
428+
queryBuilderMinScore.startObject("knn");
429+
queryBuilderMinScore.startObject(FIELD_NAME);
430+
queryBuilderMinScore.field("vector", vector);
431+
queryBuilderMinScore.field(MIN_SCORE, 0.95f);
432+
queryBuilderMinScore.endObject();
433+
queryBuilderMinScore.endObject();
434+
queryBuilderMinScore.endObject().endObject();
435+
436+
Integer minScoreStatBeforeMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());
437+
searchKNNIndex(INDEX_NAME, queryBuilderMinScore, 1);
438+
Integer minScoreStatAfterMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());
439+
440+
assertEquals(1, minScoreStatAfterMinScoreSearch - minScoreStatBeforeMinScoreSearch);
441+
442+
// Second search: radial search by min score with filter
443+
XContentBuilder queryBuilderMinScoreWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query");
444+
queryBuilderMinScoreWithFilter.startObject("knn");
445+
queryBuilderMinScoreWithFilter.startObject(FIELD_NAME);
446+
queryBuilderMinScoreWithFilter.field("vector", vector);
447+
queryBuilderMinScoreWithFilter.field(MIN_SCORE, 0.95f);
448+
queryBuilderMinScoreWithFilter.field("filter", QueryBuilders.termQuery("_id", "1"));
449+
queryBuilderMinScoreWithFilter.endObject();
450+
queryBuilderMinScoreWithFilter.endObject();
451+
queryBuilderMinScoreWithFilter.endObject().endObject();
452+
453+
Integer minScoreWithFilterStatBeforeMinScoreWithFilterSearch = getStatCount(
454+
StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()
455+
);
456+
Integer minScoreStatBeforeMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());
457+
searchKNNIndex(INDEX_NAME, queryBuilderMinScoreWithFilter, 1);
458+
Integer minScoreWithFilterStatAfterMinScoreWithFilterSearch = getStatCount(
459+
StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()
460+
);
461+
Integer minScoreStatAfterMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());
462+
463+
assertEquals(1, minScoreWithFilterStatAfterMinScoreWithFilterSearch - minScoreWithFilterStatBeforeMinScoreWithFilterSearch);
464+
assertEquals(1, minScoreStatAfterMinScoreWithFilterSearch - minScoreStatBeforeMinScoreWithFilterSearch);
465+
466+
// Third search: radial search by max distance
467+
XContentBuilder queryBuilderMaxDistance = XContentFactory.jsonBuilder().startObject().startObject("query");
468+
queryBuilderMaxDistance.startObject("knn");
469+
queryBuilderMaxDistance.startObject(FIELD_NAME);
470+
queryBuilderMaxDistance.field("vector", vector);
471+
queryBuilderMaxDistance.field(MAX_DISTANCE, 100f);
472+
queryBuilderMaxDistance.endObject();
473+
queryBuilderMaxDistance.endObject();
474+
queryBuilderMaxDistance.endObject().endObject();
475+
476+
Integer maxDistanceStatBeforeMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());
477+
searchKNNIndex(INDEX_NAME, queryBuilderMaxDistance, 0);
478+
Integer maxDistanceStatAfterMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());
479+
480+
assertEquals(1, maxDistanceStatAfterMaxDistanceSearch - maxDistanceStatBeforeMaxDistanceSearch);
481+
482+
// Fourth search: radial search by max distance with filter
483+
XContentBuilder queryBuilderMaxDistanceWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query");
484+
queryBuilderMaxDistanceWithFilter.startObject("knn");
485+
queryBuilderMaxDistanceWithFilter.startObject(FIELD_NAME);
486+
queryBuilderMaxDistanceWithFilter.field("vector", vector);
487+
queryBuilderMaxDistanceWithFilter.field(MAX_DISTANCE, 100f);
488+
queryBuilderMaxDistanceWithFilter.field("filter", QueryBuilders.termQuery("_id", "1"));
489+
queryBuilderMaxDistanceWithFilter.endObject();
490+
queryBuilderMaxDistanceWithFilter.endObject();
491+
queryBuilderMaxDistanceWithFilter.endObject().endObject();
492+
493+
Integer maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch = getStatCount(
494+
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName()
495+
);
496+
Integer maxDistanceStatBeforeMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());
497+
searchKNNIndex(INDEX_NAME, queryBuilderMaxDistanceWithFilter, 0);
498+
Integer maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch = getStatCount(
499+
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName()
500+
);
501+
Integer maxDistanceStatAfterMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());
502+
503+
assertEquals(
504+
1,
505+
maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch - maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch
506+
);
507+
assertEquals(1, maxDistanceStatAfterMaxDistanceWithFilterSearch - maxDistanceStatBeforeMaxDistanceWithFilterSearch);
508+
}
509+
435510
public void trainKnnModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, String description)
436511
throws IOException {
437512
XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -487,4 +562,11 @@ protected Settings restClientSettings() {
487562
return super.restClientSettings();
488563
}
489564
}
565+
566+
@SneakyThrows
567+
private Integer getStatCount(String statName) {
568+
Response response = getKnnStats(Collections.emptyList(), Collections.emptyList());
569+
String responseBody = EntityUtils.toString(response.getEntity());
570+
return (Integer) parseNodeStatsResponse(responseBody).get(0).get(statName);
571+
}
490572
}

0 commit comments

Comments
 (0)