Skip to content

Commit ae7ef1b

Browse files
committed
fix flaky test
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 26fc493 commit ae7ef1b

File tree

2 files changed

+68
-64
lines changed

2 files changed

+68
-64
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java

+37-32
Original file line numberDiff line numberDiff line change
@@ -170,21 +170,27 @@ public void setup() throws Exception {
170170
+ " \"properties\": {\n"
171171
+ " \"diary_embedding_size\": {\n"
172172
+ " \"type\": \"keyword\"\n"
173+
+ " },\n"
174+
+ " \"diary_embedding_size_int\": {\n"
175+
+ " \"type\": \"integer\"\n"
173176
+ " }\n"
174177
+ " }\n"
175178
+ " }\n"
176179
+ "}";
180+
177181
String uploadDocumentRequestBodyDoc1 = "{\n"
178182
+ " \"id\": 1,\n"
179183
+ " \"diary\": [\"happy\",\"first day at school\"],\n"
180184
+ " \"diary_embedding_size\": \"1536\",\n" // embedding size for ada model
185+
+ " \"diary_embedding_size_int\": 1536,\n"
181186
+ " \"weather\": \"rainy\"\n"
182187
+ " }";
183188

184189
String uploadDocumentRequestBodyDoc2 = "{\n"
185190
+ " \"id\": 2,\n"
186191
+ " \"diary\": [\"bored\",\"at home\"],\n"
187192
+ " \"diary_embedding_size\": \"768\",\n" // embedding size for local text embedding model
193+
+ " \"diary_embedding_size_int\": 768,\n"
188194
+ " \"weather\": \"sunny\"\n"
189195
+ " }";
190196

@@ -389,7 +395,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
389395
+ " \"model_id\": \""
390396
+ this.bedrockMultiModalEmbeddingModelId
391397
+ "\",\n"
392-
+ " \"query_template\": \"{\\\"size\\\": 2,\\\"query\\\": {\\\"range\\\": {\\\"diary_embedding_size\\\": {\\\"gte\\\": ${modelPrediction}}}}}\",\n"
398+
+ " \"query_template\": \"{\\\"size\\\": 2,\\\"query\\\": {\\\"range\\\": {\\\"diary_embedding_size_int\\\": {\\\"gte\\\": ${modelPrediction}}}}}\",\n"
393399
+ " \"optional_input_map\": [\n"
394400
+ " {\n"
395401
+ " \"inputText\": \"query.term.diary.value\",\n"
@@ -415,9 +421,8 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
415421
createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);
416422

417423
Map response = searchWithPipeline(client(), index_name, pipelineName, query);
418-
419424
assertEquals((int) JsonPath.parse(response).read("$.hits.hits.length()"), 1);
420-
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536");
425+
assertEquals((double) JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size_int"), 1536.0, 0.0001);
421426
}
422427

423428
/**
@@ -447,35 +452,35 @@ public void testMLInferenceProcessorLocalModel() throws Exception {
447452
});
448453

449454
String createPipelineRequestBody = "{\n"
450-
+ " \"request_processors\": [\n"
451-
+ " {\n"
452-
+ " \"ml_inference\": {\n"
453-
+ " \"tag\": \"ml_inference\",\n"
454-
+ " \"description\": \"This processor is going to run ml inference during search request\",\n"
455-
+ " \"model_id\": \""
456-
+ this.localModelId
457-
+ "\",\n"
458-
+ " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n"
459-
+ " \"function_name\": \"text_embedding\",\n"
460-
+ " \"full_response_path\": true,\n"
461-
+ " \"input_map\": [\n"
462-
+ " {\n"
463-
+ " \"text_docs\": \"query.term.diary_embedding_size.value\"\n"
464-
+ " }\n"
465-
+ " ],\n"
466-
+ " \"output_map\": [\n"
467-
+ " {\n"
468-
+ " \"query.term.diary_embedding_size.value\": \"$.inference_results[0].output[0].data.length()\"\n"
469-
+ " }\n"
470-
+ " ],\n"
471-
+ " \n"
472-
+ " \"ignore_missing\":false,\n"
473-
+ " \"ignore_failure\": false\n"
474-
+ " \n"
475-
+ " }\n"
476-
+ " }\n"
477-
+ " ]\n"
478-
+ "}";
455+
+ " \"request_processors\": [\n"
456+
+ " {\n"
457+
+ " \"ml_inference\": {\n"
458+
+ " \"tag\": \"ml_inference\",\n"
459+
+ " \"description\": \"This processor is going to run ml inference during search request\",\n"
460+
+ " \"model_id\": \""
461+
+ this.localModelId
462+
+ "\",\n"
463+
+ " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n"
464+
+ " \"function_name\": \"text_embedding\",\n"
465+
+ " \"full_response_path\": true,\n"
466+
+ " \"input_map\": [\n"
467+
+ " {\n"
468+
+ " \"text_docs\": \"query.term.diary_embedding_size.value\"\n"
469+
+ " }\n"
470+
+ " ],\n"
471+
+ " \"output_map\": [\n"
472+
+ " {\n"
473+
+ " \"query.term.diary_embedding_size.value\": \"$.inference_results[0].output[0].data.length()\"\n"
474+
+ " }\n"
475+
+ " ],\n"
476+
+ " \n"
477+
+ " \"ignore_missing\":false,\n"
478+
+ " \"ignore_failure\": false\n"
479+
+ " \n"
480+
+ " }\n"
481+
+ " }\n"
482+
+ " ]\n"
483+
+ "}";
479484

480485
String index_name = "daily_index";
481486
String pipelineName = "diary_embedding_pipeline_local";

plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java

+31-32
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
*/
55
package org.opensearch.ml.rest;
66

7-
import static org.junit.Assert.assertEquals;
87
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
98
import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL;
109
import static org.opensearch.ml.utils.TestHelper.makeRequest;
@@ -439,7 +438,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
439438
return;
440439
}
441440
String createPipelineRequestBody = "{\n"
442-
+ " \"response\": [\n"
441+
+ " \"response_processors\": [\n"
443442
+ " {\n"
444443
+ " \"ml_inference\": {\n"
445444
+ " \"tag\": \"ml_inference\",\n"
@@ -449,7 +448,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
449448
+ "\",\n"
450449
+ " \"optional_input_map\": [\n"
451450
+ " {\n"
452-
+ " \"inputText\": \"diary\",\n"
451+
+ " \"inputText\": \"diary[0]\",\n"
453452
+ " \"inputImage\": \"diary_image\"\n"
454453
+ " }\n"
455454
+ " ],\n"
@@ -474,7 +473,7 @@ public void testMLInferenceProcessorRemoteModelOptionalInputs() throws Exception
474473
Map response = searchWithPipeline(client(), index_name, pipelineName, query);
475474

476475
assertEquals((int) JsonPath.parse(response).read("$.hits.hits.length()"), 1);
477-
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.multi_modal_embedding.length()"), "1024");
476+
assertEquals((int) JsonPath.parse(response).read("$.hits.hits[0]._source.multi_modal_embedding.length()"), 1024);
478477
}
479478

480479
/**
@@ -614,34 +613,34 @@ public void testMLInferenceProcessorLocalModel() throws Exception {
614613
});
615614

616615
String createPipelineRequestBody = "{\n"
617-
+ " \"response_processors\": [\n"
618-
+ " {\n"
619-
+ " \"ml_inference\": {\n"
620-
+ " \"tag\": \"ml_inference\",\n"
621-
+ " \"description\": \"This processor is going to run ml inference during search request\",\n"
622-
+ " \"model_id\": \""
623-
+ this.localModelId
624-
+ "\",\n"
625-
+ " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n"
626-
+ " \"function_name\": \"text_embedding\",\n"
627-
+ " \"full_response_path\": true,\n"
628-
+ " \"input_map\": [\n"
629-
+ " {\n"
630-
+ " \"input\": \"weather\"\n"
631-
+ " }\n"
632-
+ " ],\n"
633-
+ " \"output_map\": [\n"
634-
+ " {\n"
635-
+ " \"weather_embedding\": \"$.inference_results[0].output[0].data\"\n"
636-
+ " }\n"
637-
+ " ],\n"
638-
+ " \"ignore_missing\": false,\n"
639-
+ " \"ignore_failure\": false\n"
640-
+ " }\n"
641-
+ " }\n"
642-
+ " ]\n"
643-
+ "}";
644-
616+
+ " \"response_processors\": [\n"
617+
+ " {\n"
618+
+ " \"ml_inference\": {\n"
619+
+ " \"tag\": \"ml_inference\",\n"
620+
+ " \"description\": \"This processor is going to run ml inference during search request\",\n"
621+
+ " \"model_id\": \""
622+
+ this.localModelId
623+
+ "\",\n"
624+
+ " \"model_input\": \"{ \\\"text_docs\\\": [\\\"${ml_inference.text_docs}\\\"] ,\\\"return_number\\\": true,\\\"target_response\\\": [\\\"sentence_embedding\\\"]}\",\n"
625+
+ " \"function_name\": \"text_embedding\",\n"
626+
+ " \"full_response_path\": true,\n"
627+
+ " \"input_map\": [\n"
628+
+ " {\n"
629+
+ " \"input\": \"weather\"\n"
630+
+ " }\n"
631+
+ " ],\n"
632+
+ " \"output_map\": [\n"
633+
+ " {\n"
634+
+ " \"weather_embedding\": \"$.inference_results[0].output[0].data\"\n"
635+
+ " }\n"
636+
+ " ],\n"
637+
+ " \"ignore_missing\": false,\n"
638+
+ " \"ignore_failure\": false\n"
639+
+ " }\n"
640+
+ " }\n"
641+
+ " ]\n"
642+
+ "}";
643+
645644
String index_name = "daily_index";
646645
String pipelineName = "weather_embedding_pipeline_local";
647646
createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);

0 commit comments

Comments
 (0)