Skip to content

Commit 152c5e2

Browse files
Add process function for bedrock (opensearch-project#1554) (opensearch-project#1948)
* Add process function for bedrock

  Signed-off-by: zane-neo <zaniu@amazon.com>

* Merge two functions together

  Signed-off-by: zane-neo <zaniu@amazon.com>

* change method name back

  Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix compile issue after rebase

  Signed-off-by: zane-neo <zaniu@amazon.com>

* format code

  Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix compile issue after merging methods

  Signed-off-by: zane-neo <zaniu@amazon.com>

---------

Signed-off-by: zane-neo <zaniu@amazon.com>
(cherry picked from commit 33977a1)
Co-authored-by: zane-neo <zaniu@amazon.com>
1 parent bd8b6ab commit 152c5e2

File tree

5 files changed

+36
-23
lines changed

5 files changed

+36
-23
lines changed

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

+22-13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.common.connector;
77

8+
import com.google.common.collect.ImmutableList;
89
import org.opensearch.ml.common.output.model.MLResultDataType;
910
import org.opensearch.ml.common.output.model.ModelTensor;
1011

@@ -18,38 +19,46 @@ public class MLPostProcessFunction {
1819

1920
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
2021
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
21-
22+
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
2223
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
2324

2425
private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();
2526

26-
private static final Map<String, Function<List<List<Float>>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
27+
private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
2728

2829

2930
static {
3031
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
3132
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
3233
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
34+
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
3335
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
3436
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
3537
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
38+
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
3639
}
3740

38-
public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
41+
public static Function<List<?>, List<ModelTensor>> buildModelTensorList() {
3942
return embeddings -> {
4043
List<ModelTensor> modelTensors = new ArrayList<>();
4144
if (embeddings == null) {
4245
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
4346
}
44-
embeddings.forEach(embedding -> modelTensors.add(
45-
ModelTensor
46-
.builder()
47-
.name("sentence_embedding")
48-
.dataType(MLResultDataType.FLOAT32)
49-
.shape(new long[]{embedding.size()})
50-
.data(embedding.toArray(new Number[0]))
51-
.build()
52-
));
47+
if (embeddings.get(0) instanceof Number) {
48+
embeddings = ImmutableList.of(embeddings);
49+
}
50+
embeddings.forEach(embedding -> {
51+
List<Number> eachEmbedding = (List<Number>) embedding;
52+
modelTensors.add(
53+
ModelTensor
54+
.builder()
55+
.name("sentence_embedding")
56+
.dataType(MLResultDataType.FLOAT32)
57+
.shape(new long[]{eachEmbedding.size()})
58+
.data(eachEmbedding.toArray(new Number[0]))
59+
.build()
60+
);
61+
});
5362
return modelTensors;
5463
};
5564
}
@@ -58,7 +67,7 @@ public static String getResponseFilter(String postProcessFunction) {
5867
return JSON_PATH_EXPRESSION.get(postProcessFunction);
5968
}
6069

61-
public static Function<List<List<Float>>, List<ModelTensor>> get(String postProcessFunction) {
70+
public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
6271
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
6372
}
6473

common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public class MLPreProcessFunction {
1515
private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
1616
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
1717
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
18-
18+
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
1919
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
2020

2121
private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
@@ -26,17 +26,22 @@ private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPr
2626
return inputs -> Map.of("parameters", Map.of("input", inputs));
2727
}
2828

29+
private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
30+
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
31+
}
32+
2933
static {
3034
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
3135
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
3236
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
37+
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
3338
}
3439

3540
public static boolean contains(String functionName) {
3641
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
3742
}
3843

39-
public static Function<List<String>, Map<String, Object>> get(String postProcessFunction) {
40-
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
44+
public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
45+
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
4146
}
4247
}

common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import java.util.Collections;
1515
import java.util.List;
1616

17+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING;
18+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING;
1719
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;
1820

1921
public class MLPostProcessFunctionTest {
@@ -29,13 +31,13 @@ public void contains() {
2931

3032
@Test
3133
public void get() {
32-
Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING));
34+
Assert.assertNotNull(MLPostProcessFunction.get(COHERE_EMBEDDING));
3335
Assert.assertNull(MLPostProcessFunction.get("wrong value"));
3436
}
3537

3638
@Test
3739
public void test_getResponseFilter() {
38-
assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING);
40+
assert null != MLPostProcessFunction.getResponseFilter(BEDROCK_EMBEDDING);
3941
assert null == MLPostProcessFunction.getResponseFilter("wrong value");
4042
}
4143

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public static ModelTensors processOutput(
175175
// in this case, we can use jsonpath to build a List<List<Float>> result from model response.
176176
if (StringUtils.isBlank(responseFilter))
177177
responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction);
178-
List<List<Float>> vectors = JsonPath.read(modelResponse, responseFilter);
178+
List<?> vectors = JsonPath.read(modelResponse, responseFilter);
179179
List<ModelTensor> processedResponse = executeBuildInPostProcessFunction(
180180
vectors,
181181
MLPostProcessFunction.get(postProcessFunction)

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@ public static Optional<String> executePreprocessFunction(
3030
return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences)));
3131
}
3232

33-
public static List<ModelTensor> executeBuildInPostProcessFunction(
34-
List<List<Float>> vectors,
35-
Function<List<List<Float>>, List<ModelTensor>> function
36-
) {
33+
public static List<ModelTensor> executeBuildInPostProcessFunction(List<?> vectors, Function<List<?>, List<ModelTensor>> function) {
3734
return function.apply(vectors);
3835
}
3936

0 commit comments

Comments
 (0)