Commit 2c11e7f

fix error of ML inference processor in foreach processor (opensearch-project#2474)
* fix error of ML inference processor in foreach processor

  Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add IT

  Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 0722df1 commit 2c11e7f

File tree (5 files changed, +214 −23 lines changed):

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java
plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java
plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java
plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java
plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java

+5 −2
@@ -322,12 +322,15 @@ private void appendFieldValue(
 
             modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing);
 
-            List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName);
+            Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
+            ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
+            ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
+            List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
 
             if (dotPathsInArray.size() == 1) {
                 ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
                 TemplateScript.Factory ingestField = ConfigurationUtils
-                    .compileTemplate(TYPE, tag, newDocumentFieldName, newDocumentFieldName, scriptService);
+                    .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
                 ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
             } else {
                 if (!(modelOutputValue instanceof List)) {
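Context for the change above: the foreach processor exposes the array element it is currently iterating over under `_ingest._value`, which is stored in the document's ingest metadata rather than in the source-and-metadata map, so output-path resolution has to see both. The following is a minimal sketch, not the plugin's code, built only on the IngestDocument calls already used in the diff.

// Illustrative sketch only, not MLInferenceIngestProcessor itself.
// It shows why the lookup map must include the ingest metadata: the foreach
// processor publishes the current array element as "_ingest._value", which lives
// in IngestDocument.getIngestMetadata() rather than in getSourceAndMetadata().
import java.util.HashMap;
import java.util.Map;

import org.opensearch.ingest.IngestDocument;

final class DotPathLookupSketch {

    static Map<String, Object> buildLookupMap(IngestDocument ingestDocument) {
        Map<String, Object> lookup = new HashMap<>();
        // _source fields plus document metadata (_index, _id, ...)
        lookup.putAll(ingestDocument.getSourceAndMetadata());
        // adds "_ingest" -> { "_value": <current foreach element>, ... }
        lookup.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
        // A path such as "_ingest._value.title" can now be resolved against this map;
        // the source-and-metadata map alone has no "_ingest" key.
        return lookup;
    }
}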

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java

+39
@@ -10,13 +10,15 @@
 import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME;
 
 import java.nio.ByteBuffer;
+import java.time.ZonedDateTime;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiConsumer;
 
+import org.junit.Assert;
 import org.junit.Before;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
@@ -1043,6 +1045,43 @@ public void testParseGetDataInTensor_BooleanDataType() {
         assertEquals(List.of(true, false, true), result);
     }
 
+    public void testWriteNewDotPathForNestedObject() {
+        Map<String, Object> docSourceAndMetaData = new HashMap<>();
+        docSourceAndMetaData.put("_id", randomAlphaOfLength(5));
+        docSourceAndMetaData.put("_index", "my_books");
+
+        List<Map<String, String>> books = new ArrayList<>();
+        Map<String, String> book1 = new HashMap<>();
+        book1.put("title", "first book");
+        book1.put("description", "this is first book");
+        Map<String, String> book2 = new HashMap<>();
+        book2.put("title", "second book");
+        book2.put("description", "this is second book");
+        books.add(book1);
+        books.add(book2);
+        docSourceAndMetaData.put("books", books);
+
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        ingestMetadata.put("pipeline", "test_pipeline");
+        ingestMetadata.put("timeestamp", ZonedDateTime.now());
+        Map<String, String> ingestValue = new HashMap<>();
+        ingestValue.put("title", "first book");
+        ingestValue.put("description", "this is first book");
+        ingestMetadata.put("_value", ingestValue);
+        docSourceAndMetaData.put("_ingest", ingestMetadata);
+
+        String path = "_ingest._value.title";
+        List<String> newPath = modelExecutor.writeNewDotPathForNestedObject(docSourceAndMetaData, path);
+        Assert.assertEquals(1, newPath.size());
+        Assert.assertEquals(path, newPath.get(0));
+
+        String path2 = "books.*.title";
+        List<String> newPath2 = modelExecutor.writeNewDotPathForNestedObject(docSourceAndMetaData, path2);
+        Assert.assertEquals(2, newPath2.size());
+        Assert.assertEquals("books.0.title", newPath2.get(0));
+        Assert.assertEquals("books.1.title", newPath2.get(1));
+    }
+
     private static Map<String, Object> getNestedObjectWithAnotherNestedObjectSource() {
         ArrayList<Object> childDocuments = new ArrayList<>();
 
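To make the expansion the new test asserts concrete, here is a hypothetical, much-simplified stand-in for writeNewDotPathForNestedObject (it is not the plugin's implementation): a `*` segment over a List fans out into one concrete path per element, while a wildcard-free path such as `_ingest._value.title` passes through as a single path.

// Hypothetical sketch of wildcard dot-path expansion; names and behaviour are
// simplified assumptions, not the plugin's writeNewDotPathForNestedObject.
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class WildcardPathSketch {

    static List<String> expand(Map<String, Object> doc, String dotPath) {
        List<String> results = new ArrayList<>();
        expand(doc, dotPath.split("\\."), 0, "", results);
        return results;
    }

    private static void expand(Object node, String[] segments, int index, String prefix, List<String> out) {
        if (index == segments.length) {
            out.add(prefix);
            return;
        }
        String segment = segments[index];
        String sep = prefix.isEmpty() ? "" : ".";
        if ("*".equals(segment) && node instanceof List) {
            // fan out: one concrete path per list element, e.g. books.0, books.1, ...
            List<?> list = (List<?>) node;
            for (int i = 0; i < list.size(); i++) {
                expand(list.get(i), segments, index + 1, prefix + sep + i, out);
            }
        } else {
            // no wildcard: keep the segment as-is and descend if possible
            Object child = (node instanceof Map) ? ((Map<?, ?>) node).get(segment) : null;
            expand(child, segments, index + 1, prefix + sep + segment, out);
        }
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new HashMap<>();
        doc.put("books", List.of(Map.of("title", "first book"), Map.of("title", "second book")));
        System.out.println(expand(doc, "books.*.title"));        // [books.0.title, books.1.title]
        System.out.println(expand(doc, "_ingest._value.title")); // [_ingest._value.title]
    }
}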

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+21
@@ -960,4 +960,25 @@ public void waitForTask(String taskId, MLTaskState targetState) throws InterruptedException {
         }, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS);
         assertTrue(taskDone.get());
     }
+
+    public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException,
+        InterruptedException {
+        Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput);
+        Map responseMap = parseResponseToMap(response);
+        String connectorId = (String) responseMap.get("connector_id");
+        response = RestMLRemoteInferenceIT.registerRemoteModel(modelName, modelName, connectorId);
+        responseMap = parseResponseToMap(response);
+        String taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        response = RestMLRemoteInferenceIT.getTask(taskId);
+        responseMap = parseResponseToMap(response);
+        String modelId = (String) responseMap.get("model_id");
+        if (deploy) {
+            response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
+            responseMap = parseResponseToMap(response);
+            taskId = (String) responseMap.get("task_id");
+            waitForTask(taskId, MLTaskState.COMPLETED);
+        }
+        return modelId;
+    }
 }
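A hypothetical usage sketch of the new helper from a test extending MLCommonsRestTestCase (the connector body and class name are placeholders, not part of the commit): it creates the connector, registers a model group and model under the given name, waits for the register task, and deploys when the last argument is true.

// Hypothetical example, not part of the commit. Shows how a test can set up a
// remote model with the new registerRemoteModel(connectorInput, modelName, deploy) helper.
import java.io.IOException;

import org.junit.Before;

public class ExampleRemoteModelIT extends MLCommonsRestTestCase {

    private String modelId;

    @Before
    public void setupModel() throws IOException, InterruptedException {
        String connectorEntity = "{ ... }"; // connector create-request body (placeholder)
        String modelName = "example remote model " + randomAlphaOfLength(5);
        // creates the connector, registers the model, waits on the register task, then deploys
        this.modelId = registerRemoteModel(connectorEntity, modelName, true);
    }
}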

plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java

+142 −20
@@ -17,15 +17,15 @@
 import org.junit.Before;
 import org.opensearch.client.Request;
 import org.opensearch.client.Response;
-import org.opensearch.ml.common.MLTaskState;
 import org.opensearch.ml.utils.TestHelper;
 
 import com.google.common.collect.ImmutableList;
 import com.jayway.jsonpath.JsonPath;
 
 public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
     private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
-    private String modelId;
+    private String openAIChatModelId;
+    private String bedrockEmbeddingModelId;
     private final String completionModelConnectorEntity = "{\n"
         + "  \"name\": \"OpenAI text embedding model Connector\",\n"
         + "  \"description\": \"The connector to public OpenAI text embedding model service\",\n"
@@ -52,26 +52,58 @@ public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
         + "  ]\n"
         + "}";
 
+    private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
+    private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
+    private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
+    private static final String GITHUB_CI_AWS_REGION = "us-west-2";
+
+    private final String bedrockEmbeddingModelConnectorEntity = "{\n"
+        + "  \"name\": \"Amazon Bedrock Connector: embedding\",\n"
+        + "  \"description\": \"The connector to bedrock Titan embedding model\",\n"
+        + "  \"version\": 1,\n"
+        + "  \"protocol\": \"aws_sigv4\",\n"
+        + "  \"parameters\": {\n"
+        + "    \"region\": \""
+        + GITHUB_CI_AWS_REGION
+        + "\",\n"
+        + "    \"service_name\": \"bedrock\",\n"
+        + "    \"model_name\": \"amazon.titan-embed-text-v1\"\n"
+        + "  },\n"
+        + "  \"credential\": {\n"
+        + "    \"access_key\": \""
+        + AWS_ACCESS_KEY_ID
+        + "\",\n"
+        + "    \"secret_key\": \""
+        + AWS_SECRET_ACCESS_KEY
+        + "\",\n"
+        + "    \"session_token\": \""
+        + AWS_SESSION_TOKEN
+        + "\"\n"
+        + "  },\n"
+        + "  \"actions\": [\n"
+        + "    {\n"
+        + "      \"action_type\": \"predict\",\n"
+        + "      \"method\": \"POST\",\n"
+        + "      \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n"
+        + "      \"headers\": {\n"
+        + "        \"content-type\": \"application/json\",\n"
+        + "        \"x-amz-content-sha256\": \"required\"\n"
+        + "      },\n"
+        + "      \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n"
+        + "      \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n"
+        + "      \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n"
+        + "    }\n"
+        + "  ]\n"
+        + "}";
+
     @Before
     public void setup() throws IOException, InterruptedException {
         RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
         Thread.sleep(20000);
-
-        // create connectors for OPEN AI and register model
-        Response response = RestMLRemoteInferenceIT.createConnector(completionModelConnectorEntity);
-        Map responseMap = parseResponseToMap(response);
-        String openAIConnectorId = (String) responseMap.get("connector_id");
-        response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 chat model", openAIConnectorId);
-        responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = RestMLRemoteInferenceIT.getTask(taskId);
-        responseMap = parseResponseToMap(response);
-        this.modelId = (String) responseMap.get("model_id");
-        response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
-        responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
+        String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5);
+        this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true);
+        String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
+        this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true);
     }
 
     public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
@@ -82,7 +114,7 @@ public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
             + "      {\n"
             + "        \"ml_inference\": {\n"
            + "          \"model_id\": \""
-            + this.modelId
+            + this.openAIChatModelId
             + "\",\n"
             + "          \"input_map\": [\n"
             + "            {\n"
@@ -141,7 +173,7 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
             + "      {\n"
             + "        \"ml_inference\": {\n"
             + "          \"model_id\": \""
-            + this.modelId
+            + this.openAIChatModelId
             + "\",\n"
             + "          \"input_map\": [\n"
             + "            {\n"
@@ -228,6 +260,96 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
         Assert.assertEquals(0.014352738, (Double) embedding4.get(0), 0.005);
     }
 
+    public void testMLInferenceProcessorWithForEachProcessor() throws Exception {
+        String indexName = "my_books";
+        String pipelineName = "my_books_bedrock_embedding_pipeline";
+        String createIndexRequestBody = "{\n"
+            + "  \"settings\": {\n"
+            + "    \"index\": {\n"
+            + "      \"default_pipeline\": \""
+            + pipelineName
+            + "\"\n"
+            + "    }\n"
+            + "  },\n"
+            + "  \"mappings\": {\n"
+            + "    \"properties\": {\n"
+            + "      \"books\": {\n"
+            + "        \"type\": \"nested\",\n"
+            + "        \"properties\": {\n"
+            + "          \"title_embedding\": {\n"
+            + "            \"type\": \"float\"\n"
+            + "          },\n"
+            + "          \"title\": {\n"
+            + "            \"type\": \"text\"\n"
+            + "          },\n"
+            + "          \"description\": {\n"
+            + "            \"type\": \"text\"\n"
+            + "          }\n"
+            + "        }\n"
+            + "      }\n"
+            + "    }\n"
+            + "  }\n"
+            + "}";
+        createIndex(indexName, createIndexRequestBody);
+
+        String createPipelineRequestBody = "{\n"
+            + "  \"description\": \"Test bedrock embeddings\",\n"
+            + "  \"processors\": [\n"
+            + "    {\n"
+            + "      \"foreach\": {\n"
+            + "        \"field\": \"books\",\n"
+            + "        \"processor\": {\n"
+            + "          \"ml_inference\": {\n"
+            + "            \"model_id\": \""
+            + this.bedrockEmbeddingModelId
+            + "\",\n"
+            + "            \"input_map\": [\n"
+            + "              {\n"
+            + "                \"input\": \"_ingest._value.title\"\n"
+            + "              }\n"
+            + "            ],\n"
+            + "            \"output_map\": [\n"
+            + "              {\n"
+            + "                \"_ingest._value.title_embedding\": \"$.embedding\"\n"
+            + "              }\n"
+            + "            ],\n"
+            + "            \"ignore_missing\": false,\n"
+            + "            \"ignore_failure\": false\n"
+            + "          }\n"
+            + "        }\n"
+            + "      }\n"
+            + "    }\n"
+            + "  ]\n"
+            + "}";
+        createPipelineProcessor(createPipelineRequestBody, pipelineName);
+
+        // Skip test if key is null
+        if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
+            return;
+        }
+        String uploadDocumentRequestBody = "{\n"
+            + "  \"books\": [{\n"
+            + "      \"title\": \"first book\",\n"
+            + "      \"description\": \"This is first book\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"title\": \"second book\",\n"
+            + "      \"description\": \"This is second book\"\n"
+            + "    }\n"
+            + "  ]\n"
+            + "}";
+        uploadDocument(indexName, "1", uploadDocumentRequestBody);
+        Map document = getDocument(indexName, "1");
+
+        List embeddingList = JsonPath.parse(document).read("_source.books[*].title_embedding");
+        Assert.assertEquals(2, embeddingList.size());
+
+        List embedding1 = JsonPath.parse(document).read("_source.books[0].title_embedding");
+        Assert.assertEquals(1536, embedding1.size());
+        List embedding2 = JsonPath.parse(document).read("_source.books[1].title_embedding");
+        Assert.assertEquals(1536, embedding2.size());
+    }
+
     protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception {
         Response pipelineCreateResponse = TestHelper
             .makeRequest(

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+7 −1
@@ -779,8 +779,14 @@ public static Response createConnector(String input) throws IOException {
     }
 
     public static Response registerRemoteModel(String name, String connectorId) throws IOException {
+        return registerRemoteModel("remote_model_group", name, connectorId);
+    }
+
+    public static Response registerRemoteModel(String modelGroupName, String name, String connectorId) throws IOException {
         String registerModelGroupEntity = "{\n"
-            + "  \"name\": \"remote_model_group\",\n"
+            + "  \"name\": \""
+            + modelGroupName
+            + "\",\n"
             + "  \"description\": \"This is an example description\"\n"
            + "}";
         Response response = TestHelper
