11
11
12
12
package org .opensearch .knn .plugin .action ;
13
13
14
+ import lombok .SneakyThrows ;
14
15
import org .apache .http .util .EntityUtils ;
15
16
import org .opensearch .action .search .SearchResponse ;
16
17
import org .opensearch .client .Request ;
19
20
import org .opensearch .core .xcontent .XContentParser ;
20
21
import org .opensearch .common .xcontent .XContentType ;
21
22
import org .opensearch .knn .KNNRestTestCase ;
22
- import org .opensearch .knn .index .SpaceType ;
23
- import org .opensearch .knn .index .util .KNNEngine ;
24
23
import org .opensearch .knn .indices .Model ;
25
- import org .opensearch .knn .indices .ModelMetadata ;
26
- import org .opensearch .knn .indices .ModelState ;
27
24
import org .opensearch .knn .plugin .KNNPlugin ;
28
25
import org .opensearch .core .rest .RestStatus ;
29
26
import org .opensearch .search .SearchHit ;
30
27
31
- import java .io .IOException ;
32
28
import java .util .Arrays ;
33
29
import java .util .HashMap ;
34
30
import java .util .List ;
35
31
import java .util .Map ;
36
32
37
33
import static org .opensearch .knn .common .KNNConstants .MODELS ;
34
+ import static org .opensearch .knn .common .KNNConstants .MODEL_INDEX_NAME ;
38
35
import static org .opensearch .knn .common .KNNConstants .PARAM_SIZE ;
39
36
import static org .opensearch .knn .common .KNNConstants .SEARCH_MODEL_MAX_SIZE ;
40
37
import static org .opensearch .knn .common .KNNConstants .SEARCH_MODEL_MIN_SIZE ;
47
44
48
45
public class RestSearchModelHandlerIT extends KNNRestTestCase {
49
46
50
- private ModelMetadata getModelMetadata () {
51
- return new ModelMetadata (KNNEngine .DEFAULT , SpaceType .DEFAULT , 4 , ModelState .CREATED , "2021-03-27" , "test model" , "" , "" );
52
- }
53
-
54
- public void testNotSupportedParams () throws IOException {
55
- createModelSystemIndex ();
47
+ public void testSearch_whenUnSupportedParamsPassed_thenFail () {
56
48
String restURI = String .join ("/" , KNNPlugin .KNN_BASE_URI , MODELS , "_search" );
57
49
Map <String , String > invalidParams = new HashMap <>();
58
50
invalidParams .put ("index" , "index-name" );
@@ -61,27 +53,31 @@ public void testNotSupportedParams() throws IOException {
61
53
expectThrows (ResponseException .class , () -> client ().performRequest (request ));
62
54
}
63
55
64
- public void testNoModelExists () throws IOException {
65
- createModelSystemIndex ();
56
+ @ SneakyThrows
57
+ public void testSearch_whenNoModelExists_thenReturnEmptyResults () {
58
+ // Currently, if the model index exists, we will return empty hits. If it does not exist, we will
59
+ // throw an exception. This is somewhat of a bug considering that the model index is supposed to be
60
+ // an implementation detail abstracted away from the user. However, in order to test, we need to handle
61
+ // the 2 different scenarios
66
62
String restURI = String .join ("/" , KNNPlugin .KNN_BASE_URI , MODELS , "_search" );
67
63
Request request = new Request ("GET" , restURI );
68
64
request .setJsonEntity ("{\n " + " \" query\" : {\n " + " \" match_all\" : {}\n " + " }\n " + "}" );
69
-
70
- Response response = client ().performRequest (request );
71
- assertEquals (RestStatus .OK , RestStatus .fromCode (response .getStatusLine ().getStatusCode ()));
72
-
73
- String responseBody = EntityUtils .toString (response .getEntity ());
74
- assertNotNull (responseBody );
75
-
76
- XContentParser parser = createParser (XContentType .JSON .xContent (), responseBody );
77
- SearchResponse searchResponse = SearchResponse .fromXContent (parser );
78
- assertNotNull (searchResponse );
79
- assertEquals (searchResponse .getHits ().getHits ().length , 0 );
80
-
65
+ if (!systemIndexExists (MODEL_INDEX_NAME )) {
66
+ ResponseException ex = expectThrows (ResponseException .class , () -> client ().performRequest (request ));
67
+ assertEquals (RestStatus .NOT_FOUND .getStatus (), ex .getResponse ().getStatusLine ().getStatusCode ());
68
+ } else {
69
+ Response response = client ().performRequest (request );
70
+ assertEquals (RestStatus .OK , RestStatus .fromCode (response .getStatusLine ().getStatusCode ()));
71
+ String responseBody = EntityUtils .toString (response .getEntity ());
72
+ assertNotNull (responseBody );
73
+ XContentParser parser = createParser (XContentType .JSON .xContent (), responseBody );
74
+ SearchResponse searchResponse = SearchResponse .fromXContent (parser );
75
+ assertNotNull (searchResponse );
76
+ assertEquals (searchResponse .getHits ().getHits ().length , 0 );
77
+ }
81
78
}
82
79
83
- public void testSizeValidationFailsInvalidSize () throws IOException {
84
- createModelSystemIndex ();
80
+ public void testSearch_whenInvalidSizePassed_thenFail () {
85
81
for (Integer invalidSize : Arrays .asList (SEARCH_MODEL_MIN_SIZE - 1 , SEARCH_MODEL_MAX_SIZE + 1 )) {
86
82
String restURI = String .join ("/" , KNNPlugin .KNN_BASE_URI , MODELS , "_search?" + PARAM_SIZE + "=" + invalidSize );
87
83
Request request = new Request ("GET" , restURI );
@@ -101,8 +97,8 @@ public void testSizeValidationFailsInvalidSize() throws IOException {
101
97
102
98
}
103
99
104
- public void testSearchModelExists () throws Exception {
105
- createModelSystemIndex ();
100
+ @ SneakyThrows
101
+ public void testSearch_whenModelExists_thenSuccess () {
106
102
String trainingIndex = "irrelevant-index" ;
107
103
String trainingFieldName = "train-field" ;
108
104
int dimension = 8 ;
@@ -151,7 +147,6 @@ public void testSearchModelExists() throws Exception {
151
147
}
152
148
153
149
public void testSearchModelWithoutSource () throws Exception {
154
- createModelSystemIndex ();
155
150
String trainingIndex = "irrelevant-index" ;
156
151
String trainingFieldName = "train-field" ;
157
152
int dimension = 8 ;
@@ -192,7 +187,6 @@ public void testSearchModelWithoutSource() throws Exception {
192
187
}
193
188
194
189
public void testSearchModelWithSourceFilteringIncludes () throws Exception {
195
- createModelSystemIndex ();
196
190
String trainingIndex = "irrelevant-index" ;
197
191
String trainingFieldName = "train-field" ;
198
192
int dimension = 8 ;
@@ -244,7 +238,6 @@ public void testSearchModelWithSourceFilteringIncludes() throws Exception {
244
238
}
245
239
246
240
public void testSearchModelWithSourceFilteringExcludes () throws Exception {
247
- createModelSystemIndex ();
248
241
String trainingIndex = "irrelevant-index" ;
249
242
String trainingFieldName = "train-field" ;
250
243
int dimension = 8 ;
0 commit comments