Skip to content

Commit 4512e0a

Browse files
authored
fix flaky test (opensearch-project#3598)
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 26fc493 commit 4512e0a

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -170,21 +170,27 @@ public void setup() throws Exception {
170170
+ " \"properties\": {\n"
171171
+ " \"diary_embedding_size\": {\n"
172172
+ " \"type\": \"keyword\"\n"
173+
+ " },\n"
174+
+ " \"diary_embedding_size_int\": {\n"
175+
+ " \"type\": \"integer\"\n"
173176
+ " }\n"
174177
+ " }\n"
175178
+ " }\n"
176179
+ "}";
180+
177181
String uploadDocumentRequestBodyDoc1 = "{\n"
178182
+ " \"id\": 1,\n"
179183
+ " \"diary\": [\"happy\",\"first day at school\"],\n"
180184
+ " \"diary_embedding_size\": \"1536\",\n" // embedding size for ada model
185+
+ " \"diary_embedding_size_int\": 1536,\n"
181186
+ " \"weather\": \"rainy\"\n"
182187
+ " }";
183188

184189
String uploadDocumentRequestBodyDoc2 = "{\n"
185190
+ " \"id\": 2,\n"
186191
+ " \"diary\": [\"bored\",\"at home\"],\n"
187192
+ " \"diary_embedding_size\": \"768\",\n" // embedding size for local text embedding model
193+
+ " \"diary_embedding_size_int\": 768,\n"
188194
+ " \"weather\": \"sunny\"\n"
189195
+ " }";
190196

@@ -389,7 +395,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
389395
+ " \"model_id\": \""
390396
+ this.bedrockMultiModalEmbeddingModelId
391397
+ "\",\n"
392-
+ " \"query_template\": \"{\\\"size\\\": 2,\\\"query\\\": {\\\"range\\\": {\\\"diary_embedding_size\\\": {\\\"gte\\\": ${modelPrediction}}}}}\",\n"
398+
+ " \"query_template\": \"{\\\"size\\\": 2,\\\"query\\\": {\\\"range\\\": {\\\"diary_embedding_size_int\\\": {\\\"gte\\\": ${modelPrediction}}}}}\",\n"
393399
+ " \"optional_input_map\": [\n"
394400
+ " {\n"
395401
+ " \"inputText\": \"query.term.diary.value\",\n"
@@ -415,9 +421,8 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
415421
createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);
416422

417423
Map response = searchWithPipeline(client(), index_name, pipelineName, query);
418-
419424
assertEquals((int) JsonPath.parse(response).read("$.hits.hits.length()"), 1);
420-
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536");
425+
assertEquals((double) JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size_int"), 1536.0, 0.0001);
421426
}
422427

423428
/**

plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
*/
55
package org.opensearch.ml.rest;
66

7-
import static org.junit.Assert.assertEquals;
87
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
98
import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL;
109
import static org.opensearch.ml.utils.TestHelper.makeRequest;
@@ -439,7 +438,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
439438
return;
440439
}
441440
String createPipelineRequestBody = "{\n"
442-
+ " \"response\": [\n"
441+
+ " \"response_processors\": [\n"
443442
+ " {\n"
444443
+ " \"ml_inference\": {\n"
445444
+ " \"tag\": \"ml_inference\",\n"
@@ -449,7 +448,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
449448
+ "\",\n"
450449
+ " \"optional_input_map\": [\n"
451450
+ " {\n"
452-
+ " \"inputText\": \"diary\",\n"
451+
+ " \"inputText\": \"diary[0]\",\n"
453452
+ " \"inputImage\": \"diary_image\"\n"
454453
+ " }\n"
455454
+ " ],\n"
@@ -474,7 +473,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
474473
Map response = searchWithPipeline(client(), index_name, pipelineName, query);
475474

476475
assertEquals((int) JsonPath.parse(response).read("$.hits.hits.length()"), 1);
477-
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.multi_modal_embedding.length()"), "1024");
476+
assertEquals((int) JsonPath.parse(response).read("$.hits.hits[0]._source.multi_modal_embedding.length()"), 1024);
478477
}
479478

480479
/**

0 commit comments

Comments
 (0)