Skip to content

Commit 7c65643

Browse files
authored
Throw proper exception to invalid k-NN query (#1380) (#1381)
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent 722bc63 commit 7c65643

File tree

4 files changed

+123
-0
lines changed

4 files changed

+123
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
* Increase Lucene max dimension limit to 16,000 [#1346](https://github.com/opensearch-project/k-NN/pull/1346)
2020
* Tuned default values for ef_search and ef_construction for better indexing and search performance for vector search [#1353](https://github.com/opensearch-project/k-NN/pull/1353)
2121
* Enabled Filtering on Nested Vector fields with top level filters [#1372](https://github.com/opensearch-project/k-NN/pull/1372)
22+
* Throw proper exception to invalid k-NN query [#1380](https://github.com/opensearch-project/k-NN/pull/1380)
2223
### Bug Fixes
2324
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
2425
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

+6
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,14 @@ public static void initialize(ModelDao modelDao) {
100100
}
101101

102102
private static float[] ObjectsToFloats(List<Object> objs) {
103+
if (Objects.isNull(objs) || objs.isEmpty()) {
104+
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
105+
}
103106
float[] vec = new float[objs.size()];
104107
for (int i = 0; i < objs.size(); i++) {
108+
if ((objs.get(i) instanceof Number) == false) {
109+
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
110+
}
105111
vec[i] = ((Number) objs.get(i)).floatValue();
106112
}
107113
return vec;

src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java

+51
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.opensearch.core.rest.RestStatus;
2525
import org.opensearch.script.Script;
2626

27+
import java.util.ArrayList;
2728
import java.util.Collections;
2829
import java.util.HashMap;
2930
import java.util.List;
@@ -425,6 +426,56 @@ public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
425426
);
426427
}
427428

429+
@SneakyThrows
430+
public void testSearchWithInvalidSearchVectorType() {
431+
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
432+
ingestL2FloatTestData();
433+
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
434+
List<Object> invalidTypeQueryVector = new ArrayList<>();
435+
invalidTypeQueryVector.add(1.5);
436+
invalidTypeQueryVector.add(2.5);
437+
invalidTypeQueryVector.add("a");
438+
invalidTypeQueryVector.add(null);
439+
XContentBuilder builder = XContentFactory.jsonBuilder()
440+
.startObject()
441+
.startObject("query")
442+
.startObject("knn")
443+
.startObject(FIELD_NAME)
444+
.field("vector", invalidTypeQueryVector)
445+
.field("k", 4)
446+
.endObject()
447+
.endObject()
448+
.endObject()
449+
.endObject();
450+
request.setJsonEntity(builder.toString());
451+
452+
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
453+
assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
454+
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
455+
}
456+
457+
@SneakyThrows
458+
public void testSearchWithMissingQueryVector() {
459+
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
460+
ingestL2FloatTestData();
461+
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
462+
XContentBuilder builder = XContentFactory.jsonBuilder()
463+
.startObject()
464+
.startObject("query")
465+
.startObject("knn")
466+
.startObject(FIELD_NAME)
467+
.field("k", 4)
468+
.endObject()
469+
.endObject()
470+
.endObject()
471+
.endObject();
472+
request.setJsonEntity(builder.toString());
473+
474+
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
475+
assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
476+
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
477+
}
478+
428479
@SneakyThrows
429480
private void ingestL2ByteTestData() {
430481
Byte[] b1 = { 6, 6 };

src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java

+65
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.opensearch.plugins.SearchPlugin;
4040

4141
import java.io.IOException;
42+
import java.util.ArrayList;
4243
import java.util.List;
4344
import java.util.Optional;
4445

@@ -149,6 +150,70 @@ public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Excep
149150
expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser));
150151
}
151152

153+
public void testFromXContent_invalidQueryVectorType() throws Exception {
154+
final ClusterService clusterService = mockClusterService(Version.CURRENT);
155+
156+
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
157+
knnClusterUtil.initialize(clusterService);
158+
159+
List<Object> invalidTypeQueryVector = new ArrayList<>();
160+
invalidTypeQueryVector.add(1.5);
161+
invalidTypeQueryVector.add(2.5);
162+
invalidTypeQueryVector.add("a");
163+
invalidTypeQueryVector.add(null);
164+
165+
XContentBuilder builder = XContentFactory.jsonBuilder();
166+
builder.startObject();
167+
builder.startObject(FIELD_NAME);
168+
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector);
169+
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
170+
builder.endObject();
171+
builder.endObject();
172+
XContentParser contentParser = createParser(builder);
173+
contentParser.nextToken();
174+
IllegalArgumentException exception = expectThrows(
175+
IllegalArgumentException.class,
176+
() -> KNNQueryBuilder.fromXContent(contentParser)
177+
);
178+
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
179+
}
180+
181+
public void testFromXContent_missingQueryVector() throws Exception {
182+
final ClusterService clusterService = mockClusterService(Version.CURRENT);
183+
184+
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
185+
knnClusterUtil.initialize(clusterService);
186+
187+
// Test without vector field
188+
XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder();
189+
builderWithoutVectorField.startObject();
190+
builderWithoutVectorField.startObject(FIELD_NAME);
191+
builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
192+
builderWithoutVectorField.endObject();
193+
builderWithoutVectorField.endObject();
194+
XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField);
195+
contentParserWithoutVectorField.nextToken();
196+
IllegalArgumentException exception = expectThrows(
197+
IllegalArgumentException.class,
198+
() -> KNNQueryBuilder.fromXContent(contentParserWithoutVectorField)
199+
);
200+
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
201+
202+
// Test empty vector field
203+
List<Object> emptyQueryVector = new ArrayList<>();
204+
XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder();
205+
builderWithEmptyVector.startObject();
206+
builderWithEmptyVector.startObject(FIELD_NAME);
207+
builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector);
208+
builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
209+
builderWithEmptyVector.endObject();
210+
builderWithEmptyVector.endObject();
211+
XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector);
212+
contentParserWithEmptyVector.nextToken();
213+
exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParserWithEmptyVector));
214+
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
215+
}
216+
152217
@Override
153218
protected NamedXContentRegistry xContentRegistry() {
154219
List<NamedXContentRegistry.Entry> list = ClusterModule.getNamedXWriteables();

0 commit comments

Comments
 (0)