-
Notifications
You must be signed in to change notification settings - Fork 148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support embedding types #3560
support embedding types #3560
Changes from 1 commit
c3d50eb
0a19fab
266d627
bef2c18
a0c5fdd
33cc4f4
e1e1e14
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import lombok.SneakyThrows; | ||
import lombok.extern.log4j.Log4j2; | ||
import org.apache.commons.lang3.StringUtils; | ||
import org.junit.Before; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
|
||
@Log4j2 | ||
public class RestBedRockV2PostProcessFunctionInferenceIT extends MLCommonsRestTestCase { | ||
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); | ||
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); | ||
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); | ||
private static final String GITHUB_CI_AWS_REGION = "us-west-2"; | ||
private static final List<String> POST_PROCESS_FUNCTIONS = List.of( | ||
"connector.post_process.bedrock_v2.embedding.float", | ||
"connector.post_process.bedrock_v2.embedding.binary" | ||
); | ||
private static final Map<String, String> DATA_TYPE = Map.of( | ||
"connector.post_process.bedrock_v2.embedding.float", "FLOAT32", | ||
"connector.post_process.bedrock_v2.embedding.binary", "BINARY" | ||
); | ||
|
||
@SneakyThrows | ||
@Before | ||
public void setup() throws IOException, InterruptedException { | ||
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); | ||
Thread.sleep(20000); | ||
} | ||
|
||
public void test_bedrock_embedding_model() throws Exception { | ||
// Skip test if key is null | ||
if (tokenNotSet()) { | ||
return; | ||
} | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/BedRockV2ConnectorBodies.json") | ||
.toURI() | ||
) | ||
); | ||
for (String postProcessFunction : POST_PROCESS_FUNCTIONS) { | ||
String bedrockEmbeddingModelName = "bedrock embedding model: " + postProcessFunction; | ||
String modelId = registerRemoteModel( | ||
String | ||
.format( | ||
templates, | ||
GITHUB_CI_AWS_REGION, | ||
AWS_ACCESS_KEY_ID, | ||
AWS_SECRET_ACCESS_KEY, | ||
AWS_SESSION_TOKEN, | ||
StringUtils.substringAfterLast(postProcessFunction, "."), | ||
postProcessFunction | ||
), | ||
bedrockEmbeddingModelName, | ||
true | ||
); | ||
String errorMsg = String.format("failed to test: %s", postProcessFunction); | ||
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); | ||
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 2, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
assertTrue(errorMsg, output.get(1) instanceof Map); | ||
validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction)); | ||
validateOutput(errorMsg, (Map) output.get(1), DATA_TYPE.get(postProcessFunction)); | ||
} | ||
} | ||
|
||
private void validateOutput(String errorMsg, Map<String, Object> output, String dataType) { | ||
assertTrue(errorMsg, output.containsKey("output")); | ||
assertTrue(errorMsg, output.get("output") instanceof List); | ||
List outputList = (List) output.get("output"); | ||
assertEquals(errorMsg, 1, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertEquals(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type"), dataType); | ||
} | ||
|
||
private boolean tokenNotSet() { | ||
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { | ||
log.info("#### The AWS credentials are not set. Skipping test. ####"); | ||
return true; | ||
} | ||
return false; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package org.opensearch.ml.rest; | ||
|
||
import org.apache.commons.lang3.StringUtils; | ||
import org.junit.Before; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.io.IOException; | ||
import java.net.URISyntaxException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
public class RestCohereInferenceIT extends MLCommonsRestTestCase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This IT can be merged into the RemoteInferenceIT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The RemoteInferenceIT class file is already too big, I prefer we use separate file to test single feature for better maintenance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. btw, the spotless is not applied yet. |
||
private final String COHERE_KEY = Optional.ofNullable(System.getenv("COHERE_KEY")).orElse("UzRF34a6gj0OKkvHOO6FZxLItv8CNpK5dFdCaUDW"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We shouldn't expose this API key!!! cc @ylwu-amzn There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zane-neo, let's avoid adding credentials in code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's my private API key for testing, forgot to remove it in the PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. git-secrets could be helpful here. |
||
private final Map<String, String> DATA_TYPE = Map.of( | ||
"connector.post_process.cohere_v2.embedding.float", "FLOAT32", | ||
"connector.post_process.cohere_v2.embedding.int8", "INT8", | ||
"connector.post_process.cohere_v2.embedding.uint8", "UINT8", | ||
"connector.post_process.cohere_v2.embedding.binary", "BINARY", | ||
"connector.post_process.cohere_v2.embedding.ubinary", "UBINARY" | ||
); | ||
private final List<String> POST_PROCESS_FUNCTIONS = List.of( | ||
"connector.post_process.cohere_v2.embedding.float", | ||
"connector.post_process.cohere_v2.embedding.int8", | ||
"connector.post_process.cohere_v2.embedding.uint8", | ||
"connector.post_process.cohere_v2.embedding.binary", | ||
"connector.post_process.cohere_v2.embedding.ubinary"); | ||
|
||
@Before | ||
public void setup() throws IOException { | ||
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); | ||
} | ||
|
||
|
||
public void test_cohereInference_withDifferent_postProcessFunction() throws URISyntaxException, IOException, InterruptedException { | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/CohereConnectorBodies.json") | ||
.toURI() | ||
) | ||
); | ||
for (String postProcessFunction : POST_PROCESS_FUNCTIONS) { | ||
String connectorRequestBody = String.format(templates, COHERE_KEY, StringUtils.substringAfterLast(postProcessFunction, "."), postProcessFunction); | ||
String testCaseName = postProcessFunction + "_test"; | ||
String modelId = registerRemoteModel(connectorRequestBody, testCaseName, true); | ||
String errorMsg = String.format("failed to run test with test name: %s", testCaseName); | ||
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); | ||
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 1, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
validateOutput(errorMsg, (Map) output.get(0), DATA_TYPE.get(postProcessFunction)); | ||
} | ||
} | ||
|
||
private void validateOutput(String errorMsg, Map<String, Object> output, String dataType) { | ||
assertTrue(errorMsg, output.containsKey("output")); | ||
assertTrue(errorMsg, output.get("output") instanceof List); | ||
List outputList = (List) output.get("output"); | ||
assertEquals(errorMsg, 2, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data_type").equals(dataType)); | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
{ | ||
"name": "Amazon Bedrock Connector: embedding", | ||
"description": "The connector to bedrock Titan embedding model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-text-v2:0" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"embeddingTypes\": [\"%s\"] }", | ||
"pre_process_function": "connector.pre_process.bedrock.embedding", | ||
"post_process_function": "%s" | ||
} | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
{ | ||
"name": "Cohere Connector: embedding", | ||
"description": "The connector to cohere embedding model", | ||
"version": 1, | ||
"protocol": "http", | ||
"parameters": { | ||
"model_name": "embed-english-v3.0" | ||
}, | ||
"credential": { | ||
"cohere_key": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://api.cohere.com/v2/embed", | ||
"headers": { | ||
"content-type": "application/json", | ||
"Authorization": "Bearer ${credential.cohere_key}" | ||
}, | ||
"request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"END\", \"model\": \"${parameters.model_name}\", \"embedding_types\": [\"%s\"], \"input_type\": \"classification\"}", | ||
"pre_process_function": "connector.pre_process.cohere.embedding", | ||
"post_process_function": "%s" | ||
} | ||
] | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not merge this into the BedRock Inference IT to share the IT resources and reduce the overall IT run time?
https://github.com/opensearch-project/ml-commons/blob/main/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The templates of RestBedRockInferenceIT and RestBedRockV2PostProcessFunctionInferenceIT are different, so merging them would make the code messy; keeping them separate makes the code easier to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, but this means a simple test case needs to spin up new IT cluster resources. Given that the IT suites already take 40+ minutes, it's not scalable to keep adding new test classes. Organizing tests of the same category into a single test suite that shares the resources is a better approach overall. Do you have any better ideas on this matter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will take another look to find a way to merge them while keeping the code as readable as possible.