Skip to content

Commit c5225de

Browse files
authored
fine tune connector process function (opensearch-project#1954)
* fine tune connector process function Signed-off-by: Yaliang Wu <ylwu@amazon.com> * add unit test for process function Signed-off-by: Yaliang Wu <ylwu@amazon.com> * add license header Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent a5c500c commit c5225de

35 files changed

+1311
-149
lines changed

common/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies {
2020
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
2121
compileOnly "org.opensearch:common-utils:${common_utils_version}"
2222
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
23+
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2324

2425
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
2526
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
291291
payload = substitutor.replace(payload);
292292

293293
if (!isJson(payload)) {
294-
throw new IllegalArgumentException("Invalid JSON in payload");
294+
throw new IllegalArgumentException("Invalid payload: " + payload);
295295
}
296296
return (T) payload;
297297
}

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

+19-36
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
package org.opensearch.ml.common.connector;
77

8-
import com.google.common.collect.ImmutableList;
9-
import org.opensearch.ml.common.output.model.MLResultDataType;
8+
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
9+
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
10+
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
1011
import org.opensearch.ml.common.output.model.ModelTensor;
1112

12-
import java.util.ArrayList;
1313
import java.util.HashMap;
1414
import java.util.List;
1515
import java.util.Map;
@@ -20,58 +20,41 @@ public class MLPostProcessFunction {
2020
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
2121
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
2222
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
23+
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
2324
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
25+
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
2426

2527
private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();
2628

27-
private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
28-
29+
private static final Map<String, Function<Object, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
2930

3031
static {
32+
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
33+
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
34+
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
3135
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
3236
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
3337
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
3438
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
35-
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
36-
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
37-
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
38-
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
39-
}
40-
41-
public static Function<List<?>, List<ModelTensor>> buildModelTensorList() {
42-
return embeddings -> {
43-
List<ModelTensor> modelTensors = new ArrayList<>();
44-
if (embeddings == null) {
45-
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
46-
}
47-
if (embeddings.get(0) instanceof Number) {
48-
embeddings = ImmutableList.of(embeddings);
49-
}
50-
embeddings.forEach(embedding -> {
51-
List<Number> eachEmbedding = (List<Number>) embedding;
52-
modelTensors.add(
53-
ModelTensor
54-
.builder()
55-
.name("sentence_embedding")
56-
.dataType(MLResultDataType.FLOAT32)
57-
.shape(new long[]{eachEmbedding.size()})
58-
.data(eachEmbedding.toArray(new Number[0]))
59-
.build()
60-
);
61-
});
62-
return modelTensors;
63-
};
39+
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
40+
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
41+
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
42+
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
43+
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
44+
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
45+
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
46+
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
6447
}
6548

6649
public static String getResponseFilter(String postProcessFunction) {
6750
return JSON_PATH_EXPRESSION.get(postProcessFunction);
6851
}
6952

70-
public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
53+
public static Function<Object, List<ModelTensor>> get(String postProcessFunction) {
7154
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
7255
}
7356

7457
public static boolean contains(String postProcessFunction) {
7558
return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction);
7659
}
77-
}
60+
}

common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java

+24-19
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,48 @@
55

66
package org.opensearch.ml.common.connector;
77

8+
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
9+
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
10+
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
11+
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
12+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
13+
import org.opensearch.ml.common.input.MLInput;
14+
815
import java.util.HashMap;
9-
import java.util.List;
1016
import java.util.Map;
1117
import java.util.function.Function;
1218

1319
public class MLPreProcessFunction {
1420

15-
private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
21+
private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
1622
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
1723
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
1824
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
1925
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
26+
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
27+
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
2028

21-
private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
22-
return inputs -> Map.of("parameters", Map.of("texts", inputs));
23-
}
24-
25-
private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
26-
return inputs -> Map.of("parameters", Map.of("input", inputs));
27-
}
28-
29-
private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
30-
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
31-
}
29+
public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
30+
public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string";
3231

3332
static {
34-
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
35-
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
36-
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
37-
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
33+
CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
34+
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
35+
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
36+
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
37+
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
38+
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
39+
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
40+
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
41+
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
42+
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
3843
}
3944

4045
public static boolean contains(String functionName) {
4146
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
4247
}
4348

44-
public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
45-
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
49+
public static Function<MLInput, RemoteInferenceInputDataSet> get(String postProcessFunction) {
50+
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
4651
}
4752
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import org.opensearch.ml.common.output.model.MLResultDataType;
9+
import org.opensearch.ml.common.output.model.ModelTensor;
10+
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
14+
public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {
15+
16+
@Override
17+
public void validate(Object input) {
18+
if (!(input instanceof List)) {
19+
throw new IllegalArgumentException("Post process function input is not a List.");
20+
}
21+
22+
List<?> outerList = (List<?>) input;
23+
24+
if (!outerList.isEmpty() && !(((List<?>)input).get(0) instanceof Number)) {
25+
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
26+
}
27+
}
28+
29+
@Override
30+
public List<ModelTensor> process(List<Float> embedding) {
31+
List<ModelTensor> modelTensors = new ArrayList<>();
32+
modelTensors.add(
33+
ModelTensor
34+
.builder()
35+
.name("sentence_embedding")
36+
.dataType(MLResultDataType.FLOAT32)
37+
.shape(new long[]{embedding.size()})
38+
.data(embedding.toArray(new Number[0]))
39+
.build());
40+
return modelTensors;
41+
}
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import org.opensearch.ml.common.output.model.MLResultDataType;
9+
import org.opensearch.ml.common.output.model.ModelTensor;
10+
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
import java.util.Map;
14+
15+
public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {
16+
17+
@Override
18+
public void validate(Object input) {
19+
if (!(input instanceof List)) {
20+
throw new IllegalArgumentException("Post process function input is not a List.");
21+
}
22+
List<?> outerList = (List<?>) input;
23+
if (!outerList.isEmpty()) {
24+
if (!(outerList.get(0) instanceof Map)) {
25+
throw new IllegalArgumentException("Post process function input is not a List of Map.");
26+
}
27+
Map innerMap = (Map) outerList.get(0);
28+
29+
if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevance_score")) {
30+
throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
31+
}
32+
}
33+
}
34+
35+
@Override
36+
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
37+
List<ModelTensor> modelTensors = new ArrayList<>();
38+
39+
if (rerankResults.size() > 0) {
40+
Double[] scores = new Double[rerankResults.size()];
41+
for (int i = 0; i < rerankResults.size(); i++) {
42+
Integer index = (Integer) rerankResults.get(i).get("index");
43+
scores[index] = (Double) rerankResults.get(i).get("relevance_score");
44+
}
45+
46+
for (int i = 0; i < scores.length; i++) {
47+
modelTensors.add(ModelTensor.builder()
48+
.name("similarity")
49+
.shape(new long[]{1})
50+
.data(new Number[]{scores[i]})
51+
.dataType(MLResultDataType.FLOAT32)
52+
.build());
53+
}
54+
}
55+
return modelTensors;
56+
}
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import org.opensearch.ml.common.output.model.ModelTensor;
9+
10+
import java.util.List;
11+
import java.util.function.Function;
12+
13+
public abstract class ConnectorPostProcessFunction<T> implements Function<Object, List<ModelTensor>> {
14+
15+
@Override
16+
public List<ModelTensor> apply(Object input) {
17+
if (input == null) {
18+
throw new IllegalArgumentException("Can't run post process function as model output is null");
19+
}
20+
validate(input);
21+
return process((T)input);
22+
}
23+
24+
public abstract void validate(Object input);
25+
26+
public abstract List<ModelTensor> process(T input);
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import org.opensearch.ml.common.output.model.MLResultDataType;
9+
import org.opensearch.ml.common.output.model.ModelTensor;
10+
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
14+
public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<List<Float>>> {
15+
16+
@Override
17+
public void validate(Object input) {
18+
if (!(input instanceof List)) {
19+
throw new IllegalArgumentException("Post process function input is not a List.");
20+
}
21+
22+
List<?> outerList = (List<?>) input;
23+
24+
if (!outerList.isEmpty()) {
25+
if (!(outerList.get(0) instanceof List)) {
26+
throw new IllegalArgumentException("The embedding should be a non-empty List containing List of Float values.");
27+
}
28+
List<?> innerList = (List<?>) outerList.get(0);
29+
30+
if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) {
31+
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
32+
}
33+
}
34+
}
35+
36+
@Override
37+
public List<ModelTensor> process(List<List<Float>> embeddings) {
38+
List<ModelTensor> modelTensors = new ArrayList<>();
39+
embeddings.forEach(embedding -> modelTensors.add(
40+
ModelTensor
41+
.builder()
42+
.name("sentence_embedding")
43+
.dataType(MLResultDataType.FLOAT32)
44+
.shape(new long[]{embedding.size()})
45+
.data(embedding.toArray(new Number[0]))
46+
.build()
47+
));
48+
return modelTensors;
49+
}
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
9+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
10+
import org.opensearch.ml.common.input.MLInput;
11+
12+
import java.util.Map;
13+
14+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
15+
16+
17+
public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
18+
19+
public BedrockEmbeddingPreProcessFunction() {
20+
this.returnDirectlyForRemoteInferenceInput = true;
21+
}
22+
23+
@Override
24+
public void validate(MLInput mlInput) {
25+
validateTextDocsInput(mlInput);
26+
}
27+
28+
@Override
29+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
30+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
31+
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0)));
32+
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
33+
}
34+
}

0 commit comments

Comments
 (0)