Skip to content

Commit 1c90e38

Browse files
[Backport-2.x] Backport fix bedrock preprocess func (opensearch-project#2537) (opensearch-project#2542)
* Fix bedrock connector embedding generation issue Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * add IT Signed-off-by: zane-neo <zaniu@amazon.com> * add ITs Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * change input to fix number format exception in local Signed-off-by: zane-neo <zaniu@amazon.com> * Add log to identify the failure IT root cause Signed-off-by: zane-neo <zaniu@amazon.com> * Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java Co-authored-by: Yaliang Wu <ylwu@amazon.com> Signed-off-by: zane-neo <zaniu@amazon.com> * address comments Signed-off-by: zane-neo <zaniu@amazon.com> * fix backport incompatibility Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com> Co-authored-by: Yaliang Wu <ylwu@amazon.com> (cherry picked from commit 210903d) Co-authored-by: zane-neo <zaniu@amazon.com>
1 parent f54abcf commit 1c90e38

File tree

5 files changed

+248
-4
lines changed

5 files changed

+248
-4
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+8-4
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask
9494

9595
/**
9696
* Calculate the chunk size.
97-
* @param textDocsInputDataSet
97+
* @param textDocsInputDataSet Input dataset in textDocsInputDataSet format.
9898
* @return Tuple of chunk size and step size.
9999
*/
100100
private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) {
@@ -118,11 +118,15 @@ private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputD
118118
throw new IllegalArgumentException("no " + action + " action found");
119119
}
120120
String preProcessFunction = connectorAction.get().getPreProcessFunction();
121-
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
122-
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
121+
if (preProcessFunction == null) {
122+
// default preprocess case, consider this a batch.
123+
return Tuple.tuple(1, textDocsLength);
124+
} else if (MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
125+
|| !MLPreProcessFunction.contains(preProcessFunction)) {
126+
// Bedrock and user-defined preprocess scripts: the chunk size always equals the text docs length.
123127
return Tuple.tuple(textDocsLength, 1);
124128
}
125-
// consider as batch.
129+
// Other cases: built-in, non-bedrock preprocess functions; treat all docs as a single batch.
126130
return Tuple.tuple(1, textDocsLength);
127131
}
128132
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

+79
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,85 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre
612612
);
613613
}
614614

615+
@Test
616+
public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddingPreProcessFunction() {
617+
ConnectorAction predictAction = ConnectorAction
618+
.builder()
619+
.actionType(PREDICT)
620+
.method("POST")
621+
.url("http://openai.com/mock")
622+
.requestBody("{\"input\": ${parameters.input}}")
623+
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT)
624+
.build();
625+
Map<String, String> credential = ImmutableMap
626+
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
627+
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
628+
Connector connector = AwsConnector
629+
.awsConnectorBuilder()
630+
.name("test connector")
631+
.version("1")
632+
.protocol("aws_sigv4")
633+
.parameters(parameters)
634+
.credential(credential)
635+
.actions(Arrays.asList(predictAction))
636+
.build();
637+
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
638+
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
639+
Settings settings = Settings.builder().build();
640+
threadContext = new ThreadContext(settings);
641+
when(executor.getClient()).thenReturn(client);
642+
when(client.threadPool()).thenReturn(threadPool);
643+
when(threadPool.getThreadContext()).thenReturn(threadContext);
644+
when(executor.getScriptService()).thenReturn(scriptService);
645+
646+
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
647+
executor
648+
.executeAction(
649+
PREDICT.name(),
650+
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
651+
actionListener
652+
);
653+
}
654+
655+
@Test
656+
public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreprocessFunction() {
657+
ConnectorAction predictAction = ConnectorAction
658+
.builder()
659+
.actionType(ConnectorAction.ActionType.PREDICT)
660+
.method("POST")
661+
.url("http://openai.com/mock")
662+
.requestBody("{\"input\": ${parameters.input}}")
663+
.build();
664+
Map<String, String> credential = ImmutableMap
665+
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
666+
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
667+
Connector connector = AwsConnector
668+
.awsConnectorBuilder()
669+
.name("test connector")
670+
.version("1")
671+
.protocol("aws_sigv4")
672+
.parameters(parameters)
673+
.credential(credential)
674+
.actions(Arrays.asList(predictAction))
675+
.build();
676+
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
677+
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
678+
Settings settings = Settings.builder().build();
679+
threadContext = new ThreadContext(settings);
680+
when(executor.getClient()).thenReturn(client);
681+
when(client.threadPool()).thenReturn(threadPool);
682+
when(threadPool.getThreadContext()).thenReturn(threadContext);
683+
when(executor.getScriptService()).thenReturn(scriptService);
684+
685+
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
686+
executor
687+
.executeAction(
688+
PREDICT.name(),
689+
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
690+
actionListener
691+
);
692+
}
693+
615694
@Test
616695
public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() {
617696
ConnectorAction predictAction = ConnectorAction

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+7
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,13 @@ public Map predictTextEmbedding(String modelId) throws IOException {
896896
return result;
897897
}
898898

899+
public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException {
900+
String requestBody = TestHelper.toJsonString(input);
901+
Response response = TestHelper
902+
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
903+
return parseResponseToMap(response);
904+
}
905+
899906
public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
900907
return (modelProfile) -> {
901908
if (modelProfile.containsKey("model_state")) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.rest;
7+
8+
import java.io.IOException;
9+
import java.nio.file.Files;
10+
import java.nio.file.Path;
11+
import java.util.List;
12+
import java.util.Locale;
13+
import java.util.Map;
14+
15+
import org.junit.Before;
16+
import org.opensearch.ml.common.FunctionName;
17+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
18+
import org.opensearch.ml.common.input.MLInput;
19+
import org.opensearch.ml.common.utils.StringUtils;
20+
21+
import lombok.SneakyThrows;
22+
23+
public class RestBedRockInferenceIT extends MLCommonsRestTestCase {
24+
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
25+
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
26+
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
27+
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
28+
29+
@SneakyThrows
30+
@Before
31+
public void setup() throws IOException, InterruptedException {
32+
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
33+
Thread.sleep(20000);
34+
}
35+
36+
public void test_bedrock_embedding_model() throws Exception {
37+
// Skip test if key is null
38+
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
39+
return;
40+
}
41+
String templates = Files
42+
.readString(
43+
Path
44+
.of(
45+
RestMLPredictionAction.class
46+
.getClassLoader()
47+
.getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json")
48+
.toURI()
49+
)
50+
);
51+
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
52+
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) {
53+
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
54+
String testCaseName = templateEntry.getKey();
55+
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName);
56+
String modelId = registerRemoteModel(
57+
String
58+
.format(
59+
StringUtils.gson.toJson(templateEntry.getValue()),
60+
GITHUB_CI_AWS_REGION,
61+
AWS_ACCESS_KEY_ID,
62+
AWS_SECRET_ACCESS_KEY,
63+
AWS_SESSION_TOKEN
64+
),
65+
bedrockEmbeddingModelName,
66+
true
67+
);
68+
69+
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
70+
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
71+
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
72+
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
73+
List output = (List) inferenceResult.get("inference_results");
74+
assertEquals(errorMsg, 2, output.size());
75+
assertTrue(errorMsg, output.get(0) instanceof Map);
76+
assertTrue(errorMsg, output.get(1) instanceof Map);
77+
validateOutput(errorMsg, (Map) output.get(0));
78+
validateOutput(errorMsg, (Map) output.get(1));
79+
}
80+
}
81+
82+
private void validateOutput(String errorMsg, Map<String, Object> output) {
83+
assertTrue(errorMsg, output.containsKey("output"));
84+
assertTrue(errorMsg, output.get("output") instanceof List);
85+
List outputList = (List) output.get("output");
86+
assertEquals(errorMsg, 1, outputList.size());
87+
assertTrue(errorMsg, outputList.get(0) instanceof Map);
88+
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
89+
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
90+
}
91+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
{
2+
"without_step_size": {
3+
"name": "Amazon Bedrock Connector: embedding",
4+
"description": "The connector to bedrock Titan embedding model",
5+
"version": 1,
6+
"protocol": "aws_sigv4",
7+
"parameters": {
8+
"region": "%s",
9+
"service_name": "bedrock",
10+
"model_name": "amazon.titan-embed-text-v1"
11+
},
12+
"credential": {
13+
"access_key": "%s",
14+
"secret_key": "%s",
15+
"session_token": "%s"
16+
},
17+
"actions": [
18+
{
19+
"action_type": "predict",
20+
"method": "POST",
21+
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
22+
"headers": {
23+
"content-type": "application/json",
24+
"x-amz-content-sha256": "required"
25+
},
26+
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }",
27+
"pre_process_function": "connector.pre_process.bedrock.embedding",
28+
"post_process_function": "connector.post_process.bedrock.embedding"
29+
}
30+
]
31+
},
32+
"with_step_size": {
33+
"name": "Amazon Bedrock Connector: embedding",
34+
"description": "The connector to bedrock Titan embedding model",
35+
"version": 1,
36+
"protocol": "aws_sigv4",
37+
"parameters": {
38+
"region": "%s",
39+
"service_name": "bedrock",
40+
"model_name": "amazon.titan-embed-text-v1",
41+
"input_docs_processed_step_size": "1"
42+
},
43+
"credential": {
44+
"access_key": "%s",
45+
"secret_key": "%s",
46+
"session_token": "%s"
47+
},
48+
"actions": [
49+
{
50+
"action_type": "predict",
51+
"method": "POST",
52+
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
53+
"headers": {
54+
"content-type": "application/json",
55+
"x-amz-content-sha256": "required"
56+
},
57+
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }",
58+
"pre_process_function": "connector.pre_process.bedrock.embedding",
59+
"post_process_function": "connector.post_process.bedrock.embedding"
60+
}
61+
]
62+
}
63+
}

0 commit comments

Comments
 (0)