support embedding types #3560

Merged
@@ -21,12 +21,14 @@
public class MLPostProcessFunction {

public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String COHERE_V2_EMBEDDING_FLOAT32 = "connector.post_process.cohere_v2.embedding.float";
public static final String COHERE_V2_EMBEDDING_INT8 = "connector.post_process.cohere_v2.embedding.int8";
public static final String COHERE_V2_EMBEDDING_UINT8 = "connector.post_process.cohere_v2.embedding.uint8";
public static final String COHERE_V2_EMBEDDING_BINARY = "connector.post_process.cohere_v2.embedding.binary";
public static final String COHERE_V2_EMBEDDING_UBINARY = "connector.post_process.cohere_v2.embedding.ubinary";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String BEDROCK_V2_EMBEDDING_FLOAT = "connector.post_process.bedrock_v2.embedding.float";
public static final String BEDROCK_V2_EMBEDDING_BINARY = "connector.post_process.bedrock_v2.embedding.binary";
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
@@ -46,24 +48,29 @@ public class MLPostProcessFunction {
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_INT8, "$.embeddings.int8");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_UINT8, "$.embeddings.uint8");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_BINARY, "$.embeddings.binary");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_UBINARY, "$.embeddings.ubinary");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_FLOAT, "$.embeddingsByType.float");
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_BINARY, "$.embeddingsByType.binary");
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_INT8, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_UINT8, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_BINARY, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_UBINARY, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_FLOAT, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_BINARY, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
@@ -53,6 +53,29 @@ POST /_plugins/_ml/connectors/_create
}
```

If you're using the Bedrock v2 embedding API, you should supply `embeddingTypes` in the request body:
```json
POST /_plugins/_ml/connectors/_create
{
...
"parameters": {
...
"model": "amazon.titan-embed-text-v2:0"
},
"actions": [
{
...
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"float\"] }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock_v2.embedding.float"
}
]
}
```
For the Bedrock v2 embedding API, there are several built-in post_process_function values that can extract the embedding result into a list-of-lists-of-numbers format (see the sketch below):
1. v2 float: `connector.post_process.bedrock_v2.embedding.float`
2. v2 binary: `connector.post_process.bedrock_v2.embedding.binary`
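
As a rough sketch of what these functions do, the example below pulls the vectors out of an illustrative Titan v2 response with a JSONPath expression. It assumes the com.jayway.jsonpath library, and the response shape is inferred from the `$.embeddingsByType.<type>` expressions these functions are registered with; it is not the plugin's actual post-processing code.
```java
import com.jayway.jsonpath.JsonPath;

import java.util.List;

public class BedrockV2ExtractionSketch {
    public static void main(String[] args) {
        // Illustrative Titan v2 response when "embeddingTypes": ["float", "binary"] is requested
        // (shape assumed from the embeddingsByType paths above, trimmed for brevity).
        String response = "{"
            + "\"embeddingsByType\": {"
            + "  \"float\": [0.12, -0.03, 0.88],"
            + "  \"binary\": [1, 0, 1]"
            + "},"
            + "\"inputTextTokenCount\": 3"
            + "}";

        // connector.post_process.bedrock_v2.embedding.float reads from this path.
        List<Double> floatEmbedding = JsonPath.read(response, "$.embeddingsByType.float");
        // connector.post_process.bedrock_v2.embedding.binary reads from this path.
        List<Integer> binaryEmbedding = JsonPath.read(response, "$.embeddingsByType.binary");

        System.out.println(floatEmbedding.size() + " float dims, " + binaryEmbedding.size() + " binary dims");
    }
}
```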

If you are using Amazon OpenSearch Service, you can provide an IAM role ARN that allows access to the Bedrock service.
Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html).

@@ -59,6 +59,29 @@ POST /_plugins/_ml/connectors/_create
]
}
```
If you're using the Cohere v2 embedding API, you should pass `embedding_types` in the request body:
```json
POST /_plugins/_ml/connectors/_create
{
...
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://api.cohere.ai/v2/embed",
"request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"float\"], \"input_type\": \"${parameters.input_type}\"}",
"pre_process_function": "connector.pre_process.cohere.embedding",
"post_process_function": "connector.post_process.cohere_v2.embedding.float"
}
]
}
```
For the Cohere v2 embedding API, there are several built-in post_process_function values that can extract the embedding result into a list-of-lists-of-numbers format (see the sketch below):
1. v2 float: `connector.post_process.cohere_v2.embedding.float`
2. v2 int8: `connector.post_process.cohere_v2.embedding.int8`
3. v2 uint8: `connector.post_process.cohere_v2.embedding.uint8`
4. v2 binary: `connector.post_process.cohere_v2.embedding.binary`
5. v2 ubinary: `connector.post_process.cohere_v2.embedding.ubinary`
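
Similarly, here is a minimal sketch of the extraction these functions perform. It assumes the com.jayway.jsonpath library and an illustrative Cohere v2 response shape inferred from the `$.embeddings.<type>` expressions they are registered with; it is not the plugin's actual post-processing code.
```java
import com.jayway.jsonpath.JsonPath;

import java.util.List;

public class CohereV2ExtractionSketch {
    public static void main(String[] args) {
        // Illustrative v2/embed response for two input texts with "embedding_types": ["float", "int8"]
        // (shape assumed from the embeddings-by-type paths above, trimmed for brevity).
        String response = "{"
            + "\"embeddings\": {"
            + "  \"float\": [[0.12, -0.03], [0.56, 0.44]],"
            + "  \"int8\": [[12, -3], [56, 44]]"
            + "}"
            + "}";

        // connector.post_process.cohere_v2.embedding.float reads from "$.embeddings.float",
        // connector.post_process.cohere_v2.embedding.int8 from "$.embeddings.int8", and so on.
        List<List<Double>> floatVectors = JsonPath.read(response, "$.embeddings.float");
        List<List<Integer>> int8Vectors = JsonPath.read(response, "$.embeddings.int8");

        System.out.println(floatVectors.size() + " float vectors, " + int8Vectors.size() + " int8 vectors");
    }
}
```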

The response to this request will contain the `connector_id`; note it down.

@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.StringUtils;
import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

@Log4j2
public class RestBedRockV2PostProcessFunctionInferenceIT extends MLCommonsRestTestCase {
Collaborator: Why not merge this into the BedRock Inference IT to share the IT resources, reducing the overall IT run time?
https://github.com/opensearch-project/ml-commons/blob/main/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java

Collaborator (author): The templates of RestBedRockInferenceIT and RestBedRockV2PostProcessFunctionInferenceIT are different, so merging them would make the code look messy; keeping them separate gives a better code-reading experience.

Collaborator (@Zhangxunmt, Feb 25, 2025): Agree, but this means a simple test case would need to spin up new IT cluster resources. Given that the IT suites already take 40+ minutes, it's not scalable to keep adding new test classes, so organizing the same test categories into a single test suite that shares resources is a better approach overall. Do you have any better ideas on this matter?

Collaborator (author): OK, I will take another look and try to merge them while maintaining readability as best I can.

private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
private static final List<String> POST_PROCESS_FUNCTIONS = List.of(
"connector.post_process.bedrock_v2.embedding.float",
"connector.post_process.bedrock_v2.embedding.binary"
);
private static final Map<String, String> DATA_TYPE = Map.of(
"connector.post_process.bedrock_v2.embedding.float", "FLOAT32",
"connector.post_process.bedrock_v2.embedding.binary", "BINARY"
);

@SneakyThrows
@Before
public void setup() throws IOException, InterruptedException {
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
Thread.sleep(20000);
}

public void test_bedrock_embedding_model() throws Exception {
// Skip test if key is null
if (tokenNotSet()) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json")
.toURI()
)
);
for (String postProcessFunction : POST_PROCESS_FUNCTIONS) {
String bedrockEmbeddingModelName = "bedrock embedding model: " + postProcessFunction;
String modelId = registerRemoteModel(
String
.format(
templates,
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN,
StringUtils.substringAfterLast(postProcessFunction, "."),
postProcessFunction
),
bedrockEmbeddingModelName,
true
);
String errorMsg = String.format("failed to test: %s", postProcessFunction);
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 2, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, output.get(1) instanceof Map);
validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction));
validateOutput(errorMsg, (Map) output.get(1), DATA_TYPE.get(postProcessFunction));
}
}

private void validateOutput(String errorMsg, Map<String, Object> output, String dataType) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
List outputList = (List) output.get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type"), dataType);
}

private boolean tokenNotSet() {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
log.info("#### The AWS credentials are not set. Skipping test. ####");
return true;
}
return false;
}
}
@@ -0,0 +1,76 @@
package org.opensearch.ml.rest;

import org.apache.commons.lang3.StringUtils;
import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class RestCohereInferenceIT extends MLCommonsRestTestCase {
Collaborator: Can this IT be merged into the RemoteInferenceIT?

Collaborator (author): The RemoteInferenceIT class file is already too big; I prefer using a separate file to test a single feature, for easier maintenance.

Collaborator: By the way, spotless has not been applied yet.

private final String COHERE_KEY = Optional.ofNullable(System.getenv("COHERE_KEY")).orElse("UzRF34a6gj0OKkvHOO6FZxLItv8CNpK5dFdCaUDW");
Collaborator: We shouldn't expose this API key!!! cc @ylwu-amzn

Collaborator: @zane-neo, let's avoid adding credentials in code.

Collaborator (author): It's my private API key for testing; I forgot to remove it from the PR.

Collaborator: git-secrets could be helpful here.

private final Map<String, String> DATA_TYPE = Map.of(
"connector.post_process.cohere_v2.embedding.float", "FLOAT32",
"connector.post_process.cohere_v2.embedding.int8", "INT8",
"connector.post_process.cohere_v2.embedding.uint8", "UINT8",
"connector.post_process.cohere_v2.embedding.binary", "BINARY",
"connector.post_process.cohere_v2.embedding.ubinary", "UBINARY"
);
private final List<String> POST_PROCESS_FUNCTIONS = List.of(
"connector.post_process.cohere_v2.embedding.float",
"connector.post_process.cohere_v2.embedding.int8",
"connector.post_process.cohere_v2.embedding.uint8",
"connector.post_process.cohere_v2.embedding.binary",
"connector.post_process.cohere_v2.embedding.ubinary");

@Before
public void setup() throws IOException {
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$"));
}


public void test_cohereInference_withDifferent_postProcessFunction() throws URISyntaxException, IOException, InterruptedException {
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/CohereConnectorBodies.json")
.toURI()
)
);
for (String postProcessFunction : POST_PROCESS_FUNCTIONS) {
String connectorRequestBody = String.format(templates, COHERE_KEY, StringUtils.substringAfterLast(postProcessFunction, "."), postProcessFunction);
String testCaseName = postProcessFunction + "_test";
String modelId = registerRemoteModel(connectorRequestBody, testCaseName, true);
String errorMsg = String.format("failed to run test with test name: %s", testCaseName);
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 1, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction));
}
}

private void validateOutput(String errorMsg, Map<String, Object> output, String dataType) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
List outputList = (List) output.get("output");
assertEquals(errorMsg, 2, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type").equals(dataType));
}

}
@@ -0,0 +1,30 @@
{
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v2:0"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"%s\"] }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "%s"
}
]
}
@@ -0,0 +1,26 @@
{
"name": "Cohere Connector: embedding",
"description": "The connector to cohere embedding model",
"version": 1,
"protocol": "http",
"parameters": {
"model_name": "embed-english-v3.0"
},
"credential": {
"cohere_key": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://api.cohere.com/v2/embed",
"headers": {
"content-type": "application/json",
"Authorization": "Bearer ${credential.cohere_key}"
},
"request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"%s\"], \"input_type\": \"classification\"}",
"pre_process_function": "connector.pre_process.cohere.embedding",
"post_process_function": "%s"
}
]
}