import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
-import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.utils.TestHelper;

import com.google.common.collect.ImmutableList;
import com.jayway.jsonpath.JsonPath;

public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
    private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
-    private String modelId;
+    private String openAIChatModelId;
+    private String bedrockEmbeddingModelId;

    private final String completionModelConnectorEntity = "{\n"
        + "  \"name\": \"OpenAI text embedding model Connector\",\n"
        + "  \"description\": \"The connector to public OpenAI text embedding model service\",\n"
@@ -52,26 +52,58 @@ public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
        + "  ]\n"
        + "}";

+    private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
+    private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
+    private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
+    private static final String GITHUB_CI_AWS_REGION = "us-west-2";
+
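+    // Connector payload for the Amazon Bedrock Titan text embedding model; the region and AWS credentials above are spliced into the JSON at runtime.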
+    private final String bedrockEmbeddingModelConnectorEntity = "{\n"
+        + "  \"name\": \"Amazon Bedrock Connector: embedding\",\n"
+        + "  \"description\": \"The connector to bedrock Titan embedding model\",\n"
+        + "  \"version\": 1,\n"
+        + "  \"protocol\": \"aws_sigv4\",\n"
+        + "  \"parameters\": {\n"
+        + "    \"region\": \""
+        + GITHUB_CI_AWS_REGION
+        + "\",\n"
+        + "    \"service_name\": \"bedrock\",\n"
+        + "    \"model_name\": \"amazon.titan-embed-text-v1\"\n"
+        + "  },\n"
+        + "  \"credential\": {\n"
+        + "    \"access_key\": \""
+        + AWS_ACCESS_KEY_ID
+        + "\",\n"
+        + "    \"secret_key\": \""
+        + AWS_SECRET_ACCESS_KEY
+        + "\",\n"
+        + "    \"session_token\": \""
+        + AWS_SESSION_TOKEN
+        + "\"\n"
+        + "  },\n"
+        + "  \"actions\": [\n"
+        + "    {\n"
+        + "      \"action_type\": \"predict\",\n"
+        + "      \"method\": \"POST\",\n"
+        + "      \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n"
+        + "      \"headers\": {\n"
+        + "        \"content-type\": \"application/json\",\n"
+        + "        \"x-amz-content-sha256\": \"required\"\n"
+        + "      },\n"
+        + "      \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n"
+        + "      \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n"
+        + "      \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n"
+        + "    }\n"
+        + "  ]\n"
+        + "}";
+
    @Before
    public void setup() throws IOException, InterruptedException {
        RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
        Thread.sleep(20000);
-
-        // create connectors for OPEN AI and register model
-        Response response = RestMLRemoteInferenceIT.createConnector(completionModelConnectorEntity);
-        Map responseMap = parseResponseToMap(response);
-        String openAIConnectorId = (String) responseMap.get("connector_id");
-        response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 chat model", openAIConnectorId);
-        responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = RestMLRemoteInferenceIT.getTask(taskId);
-        responseMap = parseResponseToMap(response);
-        this.modelId = (String) responseMap.get("model_id");
-        response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
-        responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
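+        // register the OpenAI chat model and the Bedrock Titan embedding model through the shared registerRemoteModel helper; names are randomized to avoid clashes between test runs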
+        String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5);
+        this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true);
+        String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
+        this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true);
    }

    public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
@@ -82,7 +114,7 @@ public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
            + "      {\n"
            + "        \"ml_inference\": {\n"
            + "          \"model_id\": \""
-            + this.modelId
+            + this.openAIChatModelId
            + "\",\n"
            + "          \"input_map\": [\n"
            + "            {\n"
@@ -141,7 +173,7 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
            + "      {\n"
            + "        \"ml_inference\": {\n"
            + "          \"model_id\": \""
-            + this.modelId
+            + this.openAIChatModelId
            + "\",\n"
            + "          \"input_map\": [\n"
            + "            {\n"
@@ -228,6 +260,96 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
        Assert.assertEquals(0.014352738, (Double) embedding4.get(0), 0.005);
    }

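+    // Verifies the ml_inference ingest processor wrapped in a foreach processor: every element of the nested "books" field is embedded individually with the Bedrock Titan model.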
+    public void testMLInferenceProcessorWithForEachProcessor() throws Exception {
+        String indexName = "my_books";
+        String pipelineName = "my_books_bedrock_embedding_pipeline";
+        String createIndexRequestBody = "{\n"
+            + "  \"settings\": {\n"
+            + "    \"index\": {\n"
+            + "      \"default_pipeline\": \""
+            + pipelineName
+            + "\"\n"
+            + "    }\n"
+            + "  },\n"
+            + "  \"mappings\": {\n"
+            + "    \"properties\": {\n"
+            + "      \"books\": {\n"
+            + "        \"type\": \"nested\",\n"
+            + "        \"properties\": {\n"
+            + "          \"title_embedding\": {\n"
+            + "            \"type\": \"float\"\n"
+            + "          },\n"
+            + "          \"title\": {\n"
+            + "            \"type\": \"text\"\n"
+            + "          },\n"
+            + "          \"description\": {\n"
+            + "            \"type\": \"text\"\n"
+            + "          }\n"
+            + "        }\n"
+            + "      }\n"
+            + "    }\n"
+            + "  }\n"
+            + "}";
+        createIndex(indexName, createIndexRequestBody);
+
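+        // ingest pipeline: a foreach processor iterates over "books" and runs ml_inference on each element, reading _ingest._value.title and writing the model output at $.embedding back to _ingest._value.title_embedding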
+        String createPipelineRequestBody = "{\n"
+            + "  \"description\": \"Test bedrock embeddings\",\n"
+            + "  \"processors\": [\n"
+            + "    {\n"
+            + "      \"foreach\": {\n"
+            + "        \"field\": \"books\",\n"
+            + "        \"processor\": {\n"
+            + "          \"ml_inference\": {\n"
+            + "            \"model_id\": \""
+            + this.bedrockEmbeddingModelId
+            + "\",\n"
+            + "            \"input_map\": [\n"
+            + "              {\n"
+            + "                \"input\": \"_ingest._value.title\"\n"
+            + "              }\n"
+            + "            ],\n"
+            + "            \"output_map\": [\n"
+            + "              {\n"
+            + "                \"_ingest._value.title_embedding\": \"$.embedding\"\n"
+            + "              }\n"
+            + "            ],\n"
+            + "            \"ignore_missing\": false,\n"
+            + "            \"ignore_failure\": false\n"
+            + "          }\n"
+            + "        }\n"
+            + "      }\n"
+            + "    }\n"
+            + "  ]\n"
+            + "}";
+        createPipelineProcessor(createPipelineRequestBody, pipelineName);
+
+        // Skip test if key is null
+        if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
+            return;
+        }
+        String uploadDocumentRequestBody = "{\n"
+            + "  \"books\": [{\n"
+            + "    \"title\": \"first book\",\n"
+            + "    \"description\": \"This is first book\"\n"
+            + "  },\n"
+            + "  {\n"
+            + "    \"title\": \"second book\",\n"
+            + "    \"description\": \"This is second book\"\n"
+            + "  }\n"
+            + "  ]\n"
+            + "}";
+        uploadDocument(indexName, "1", uploadDocumentRequestBody);
+        Map document = getDocument(indexName, "1");
+
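+        // both nested book entries should now carry a title_embedding; the assertions below expect the 1536-dimensional vectors returned by the Titan model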
+        List embeddingList = JsonPath.parse(document).read("_source.books[*].title_embedding");
+        Assert.assertEquals(2, embeddingList.size());
+
+        List embedding1 = JsonPath.parse(document).read("_source.books[0].title_embedding");
+        Assert.assertEquals(1536, embedding1.size());
+        List embedding2 = JsonPath.parse(document).read("_source.books[1].title_embedding");
+        Assert.assertEquals(1536, embedding2.size());
+    }
+
    protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception {
        Response pipelineCreateResponse = TestHelper
            .makeRequest(