Skip to content

Commit c84c947

Browse files
Automated model interface generation on aws llms (opensearch-project#2689) (opensearch-project#2707)
* Automated model interface generation on aws llms Signed-off-by: b4sjoo <sicheng.song@outlook.com> * Add UTs Signed-off-by: b4sjoo <sicheng.song@outlook.com> * Add Comments and TODOs Signed-off-by: b4sjoo <sicheng.song@outlook.com> --------- Signed-off-by: b4sjoo <sicheng.song@outlook.com> (cherry picked from commit 9b413a7) Co-authored-by: Sicheng Song <sicheng.song@outlook.com>
1 parent c310660 commit c84c947

File tree

7 files changed

+880
-7
lines changed

7 files changed

+880
-7
lines changed

common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java

+623
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import org.junit.Before;
9+
import org.junit.Rule;
10+
import org.junit.Test;
11+
import org.junit.rules.ExpectedException;
12+
import org.mockito.Spy;
13+
import org.opensearch.ml.common.FunctionName;
14+
import org.opensearch.ml.common.connector.HttpConnector;
15+
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
16+
17+
import java.util.HashMap;
18+
import java.util.Map;
19+
20+
import static org.junit.Assert.assertEquals;
21+
import static org.junit.Assert.assertNull;
22+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE;
23+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE;
24+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE;
25+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE;
26+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE;
27+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE;
28+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE;
29+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE;
30+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE;
31+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector;
32+
33+
public class ModelInterfaceUtilsTest {
34+
@Spy
35+
MLRegisterModelInput registerModelInputWithInnerConnector;
36+
37+
@Spy
38+
MLRegisterModelInput registerModelInputWithStandaloneConnector;
39+
40+
@Spy
41+
public HttpConnector connector;
42+
43+
@Rule
44+
public ExpectedException exceptionRule = ExpectedException.none();
45+
46+
@Before
47+
public void setUp() throws Exception {
48+
registerModelInputWithInnerConnector = MLRegisterModelInput
49+
.builder()
50+
.modelName("test-model-with-inner-connector")
51+
.functionName(FunctionName.REMOTE)
52+
.build();
53+
54+
registerModelInputWithStandaloneConnector = MLRegisterModelInput
55+
.builder()
56+
.modelName("test-model-with-stand-alone-connector")
57+
.functionName(FunctionName.REMOTE)
58+
.build();
59+
}
60+
61+
@Test
62+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() {
63+
Map<String, String> parameters = new HashMap<>();
64+
parameters.put("service_name", "bedrock");
65+
parameters.put("model", "ai21.j2-mid-v1");
66+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
67+
68+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
69+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE);
70+
}
71+
72+
@Test
73+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE() {
74+
Map<String, String> parameters = new HashMap<>();
75+
parameters.put("service_name", "bedrock");
76+
parameters.put("model", "anthropic.claude-3-sonnet-20240229-v1:0");
77+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
78+
79+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
80+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE);
81+
}
82+
83+
@Test
84+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE() {
85+
Map<String, String> parameters = new HashMap<>();
86+
parameters.put("service_name", "bedrock");
87+
parameters.put("model", "anthropic.claude-v2");
88+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
89+
90+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
91+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE);
92+
}
93+
94+
@Test
95+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE() {
96+
Map<String, String> parameters = new HashMap<>();
97+
parameters.put("service_name", "bedrock");
98+
parameters.put("model", "cohere.embed.english-v3");
99+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
100+
101+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
102+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE);
103+
}
104+
105+
@Test
106+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE() {
107+
Map<String, String> parameters = new HashMap<>();
108+
parameters.put("service_name", "bedrock");
109+
parameters.put("model", "cohere.embed.multilingual-v3");
110+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
111+
112+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
113+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE);
114+
}
115+
116+
@Test
117+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE() {
118+
Map<String, String> parameters = new HashMap<>();
119+
parameters.put("service_name", "bedrock");
120+
parameters.put("model", "amazon.titan-embed-text-v1");
121+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
122+
123+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
124+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE);
125+
}
126+
127+
@Test
128+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE() {
129+
Map<String, String> parameters = new HashMap<>();
130+
parameters.put("service_name", "bedrock");
131+
parameters.put("model", "amazon.titan-embed-image-v1");
132+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
133+
134+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
135+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE);
136+
}
137+
138+
@Test
139+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE() {
140+
Map<String, String> parameters = new HashMap<>();
141+
parameters.put("service_name", "comprehend");
142+
parameters.put("api_name", "DetectDominantLanguage");
143+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
144+
145+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
146+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE);
147+
}
148+
149+
@Test
150+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE() {
151+
Map<String, String> parameters = new HashMap<>();
152+
parameters.put("service_name", "textract");
153+
parameters.put("api_name", "DetectDocumentText");
154+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
155+
156+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
157+
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE);
158+
}
159+
160+
@Test
161+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorServiceNameNotFound() {
162+
Map<String, String> parameters = new HashMap<>();
163+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
164+
165+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
166+
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
167+
}
168+
169+
@Test
170+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBedrockModelNameNotFound() {
171+
Map<String, String> parameters = new HashMap<>();
172+
parameters.put("service_name", "bedrock");
173+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
174+
175+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
176+
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
177+
}
178+
179+
@Test
180+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAmazonComprehendAPINameNotFound() {
181+
Map<String, String> parameters = new HashMap<>();
182+
parameters.put("service_name", "comprehend");
183+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
184+
185+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
186+
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
187+
}
188+
189+
@Test
190+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorNullParameters() {
191+
connector = HttpConnector.builder().protocol("http").build();
192+
193+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
194+
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
195+
}
196+
197+
@Test
198+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() {
199+
Map<String, String> parameters = new HashMap<>();
200+
parameters.put("service_name", "bedrock");
201+
parameters.put("model", "ai21.j2-mid-v1");
202+
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
203+
registerModelInputWithInnerConnector.setConnector(connector);
204+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector);
205+
assertEquals(registerModelInputWithInnerConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE);
206+
}
207+
208+
@Test
209+
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorNullParameters() {
210+
connector = HttpConnector.builder().protocol("http").build();
211+
registerModelInputWithInnerConnector.setConnector(connector);
212+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector);
213+
assertNull(registerModelInputWithInnerConnector.getModelInterface());
214+
}
215+
}

plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
99
import static org.opensearch.ml.common.MLTaskState.FAILED;
1010
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
11+
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector;
1112
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
1213
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
1314
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
@@ -239,7 +240,14 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
239240
if (Strings.isNotBlank(registerModelInput.getConnectorId())) {
240241
connectorAccessControlHelper.validateConnectorAccess(client, registerModelInput.getConnectorId(), ActionListener.wrap(r -> {
241242
if (Boolean.TRUE.equals(r)) {
242-
createModelGroup(registerModelInput, listener);
243+
if (registerModelInput.getModelInterface() == null) {
244+
mlModelManager.getConnector(registerModelInput.getConnectorId(), ActionListener.wrap(connector -> {
245+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput, connector);
246+
createModelGroup(registerModelInput, listener);
247+
}, listener::onFailure));
248+
} else {
249+
createModelGroup(registerModelInput, listener);
250+
}
243251
} else {
244252
listener
245253
.onFailure(
@@ -261,6 +269,9 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
261269
validateInternalConnector(registerModelInput);
262270
ActionListener<MLCreateConnectorResponse> dryRunResultListener = ActionListener.wrap(res -> {
263271
log.info("Dry run create connector successfully");
272+
if (registerModelInput.getModelInterface() == null) {
273+
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput);
274+
}
264275
createModelGroup(registerModelInput, listener);
265276
}, e -> {
266277
log.error(e.getMessage(), e);

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1620,7 +1620,7 @@ public void getController(String modelId, ActionListener<MLController> listener)
16201620
* @param connectorId connector id
16211621
* @param listener action listener
16221622
*/
1623-
private void getConnector(String connectorId, ActionListener<Connector> listener) {
1623+
public void getConnector(String connectorId, ActionListener<Connector> listener) {
16241624
GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId);
16251625
client.get(getRequest, ActionListener.wrap(r -> {
16261626
if (r != null && r.isExists()) {

plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,11 @@ protected Response registerRemoteModel(String modelGroupName, String name, Strin
366366
+ " \"description\": \"test model\",\n"
367367
+ " \"connector_id\": \""
368368
+ connectorId
369-
+ "\"\n"
369+
+ "\",\n"
370+
+ " \"interface\": {\n"
371+
+ " \"input\": {},\n"
372+
+ " \"output\": {}\n"
373+
+ " }\n"
370374
+ "}";
371375
return TestHelper
372376
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
@@ -423,7 +427,11 @@ protected Response registerRemoteModelWithLocalRegexGuardrails(String name, Stri
423427
+ " ],\n"
424428
+ " \"regex\": [\"regex1\", \"regex2\"]\n"
425429
+ " }\n"
426-
+ " }\n"
430+
+ "},\n"
431+
+ " \"interface\": {\n"
432+
+ " \"input\": {},\n"
433+
+ " \"output\": {}\n"
434+
+ " }\n"
427435
+ "}";
428436
return TestHelper
429437
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
@@ -461,6 +469,10 @@ protected Response registerRemoteModelWithModelGuardrails(String name, String co
461469
+ " \"connector_id\": \""
462470
+ connectorId
463471
+ "\",\n"
472+
+ " \"interface\": {\n"
473+
+ " \"input\": {},\n"
474+
+ " \"output\": {}\n"
475+
+ " },\n"
464476
+ " \"guardrails\": {\n"
465477
+ " \"type\": \"model\",\n"
466478
+ " \"input_guardrail\": {\n"

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+10-2
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,11 @@ public static Response registerRemoteModel(String modelGroupName, String name, S
814814
+ " \"description\": \"test model\",\n"
815815
+ " \"connector_id\": \""
816816
+ connectorId
817-
+ "\"\n"
817+
+ "\",\n"
818+
+ " \"interface\": {\n"
819+
+ " \"input\": {},\n"
820+
+ " \"output\": {}\n"
821+
+ " }\n"
818822
+ "}";
819823
return TestHelper
820824
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
@@ -856,7 +860,11 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name
856860
+ " \"deploy_setting\": "
857861
+ " { \"model_ttl_minutes\": "
858862
+ ttl
859-
+ "}\n"
863+
+ "},\n"
864+
+ " \"interface\": {\n"
865+
+ " \"input\": {},\n"
866+
+ " \"output\": {}\n"
867+
+ " }\n"
860868
+ "}";
861869
return TestHelper
862870
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);

plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,11 @@ private void setupLLMModel(String connectorId) throws IOException {
162162
+ " \"description\": \"test model\",\n"
163163
+ " \"connector_id\": \""
164164
+ connectorId
165-
+ "\"\n"
165+
+ "\",\n"
166+
+ " \"interface\": {\n"
167+
+ " \"input\": {},\n"
168+
+ " \"output\": {}\n"
169+
+ " }\n"
166170
+ "}";
167171

168172
registerModel(client(), input, response -> {

0 commit comments

Comments
 (0)