Skip to content

Commit 9663053

Browse files
authored
Enhance: support skip_validating_missing_parameters in connector (opensearch-project#2812)
* introduce skip parameter validation Signed-off-by: yuye-aws <yuyezhu@amazon.com> * implement ut Signed-off-by: yuye-aws <yuyezhu@amazon.com> * implement it Signed-off-by: yuye-aws <yuyezhu@amazon.com> * spotless apply Signed-off-by: yuye-aws <yuyezhu@amazon.com> --------- Signed-off-by: yuye-aws <yuyezhu@amazon.com>
1 parent a4dff63 commit 9663053

File tree

4 files changed

+320
-1
lines changed

4 files changed

+320
-1
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
public class ConnectorUtils {
5959

6060
private static final Aws4Signer signer;
61+
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";
62+
6163
static {
6264
signer = Aws4Signer.create();
6365
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8+
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
89
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
910
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
1011

@@ -189,7 +190,9 @@ default void preparePayloadAndInvoke(
189190
// override again to always prioritize the input parameter
190191
parameters.putAll(inputParameters);
191192
String payload = connector.createPayload(action, parameters);
192-
connector.validatePayload(payload);
193+
if (!Boolean.parseBoolean(parameters.getOrDefault(SKIP_VALIDATE_MISSING_PARAMETERS, "false"))) {
194+
connector.validatePayload(payload);
195+
}
193196
String userStr = getClient()
194197
.threadPool()
195198
.getThreadContext()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.algorithms.remote;
7+
8+
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.Mockito.argThat;
10+
import static org.mockito.Mockito.spy;
11+
import static org.mockito.Mockito.times;
12+
import static org.mockito.Mockito.when;
13+
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
14+
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
15+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
16+
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
17+
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
18+
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
19+
20+
import java.util.Arrays;
21+
import java.util.Map;
22+
23+
import org.junit.Assert;
24+
import org.junit.Before;
25+
import org.junit.Test;
26+
import org.mockito.Mock;
27+
import org.mockito.Mockito;
28+
import org.mockito.MockitoAnnotations;
29+
import org.opensearch.client.Client;
30+
import org.opensearch.common.collect.Tuple;
31+
import org.opensearch.common.settings.Settings;
32+
import org.opensearch.common.util.concurrent.ThreadContext;
33+
import org.opensearch.core.action.ActionListener;
34+
import org.opensearch.ingest.TestTemplateService;
35+
import org.opensearch.ml.common.FunctionName;
36+
import org.opensearch.ml.common.connector.AwsConnector;
37+
import org.opensearch.ml.common.connector.Connector;
38+
import org.opensearch.ml.common.connector.ConnectorAction;
39+
import org.opensearch.ml.common.connector.ConnectorClientConfig;
40+
import org.opensearch.ml.common.connector.RetryBackoffPolicy;
41+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
42+
import org.opensearch.ml.common.input.MLInput;
43+
import org.opensearch.ml.common.output.model.ModelTensors;
44+
import org.opensearch.ml.engine.encryptor.Encryptor;
45+
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
46+
import org.opensearch.script.ScriptService;
47+
import org.opensearch.threadpool.ThreadPool;
48+
49+
import com.google.common.collect.ImmutableMap;
50+
51+
public class RemoteConnectorExecutorTest {
52+
53+
Encryptor encryptor;
54+
55+
@Mock
56+
Client client;
57+
58+
@Mock
59+
ThreadPool threadPool;
60+
61+
@Mock
62+
private ScriptService scriptService;
63+
64+
@Mock
65+
ActionListener<Tuple<Integer, ModelTensors>> actionListener;
66+
67+
@Before
68+
public void setUp() {
69+
MockitoAnnotations.openMocks(this);
70+
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
71+
when(scriptService.compile(any(), any()))
72+
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
73+
}
74+
75+
private Connector getConnector(Map<String, String> parameters) {
76+
ConnectorAction predictAction = ConnectorAction
77+
.builder()
78+
.actionType(PREDICT)
79+
.method("POST")
80+
.url("http:///mock")
81+
.requestBody("{\"input\": \"${parameters.input}\"}")
82+
.build();
83+
Map<String, String> credential = ImmutableMap
84+
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
85+
return AwsConnector
86+
.awsConnectorBuilder()
87+
.name("test connector")
88+
.version("1")
89+
.protocol("http")
90+
.parameters(parameters)
91+
.credential(credential)
92+
.actions(Arrays.asList(predictAction))
93+
.connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT))
94+
.build();
95+
}
96+
97+
private AwsConnectorExecutor getExecutor(Connector connector) {
98+
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
99+
Settings settings = Settings.builder().build();
100+
ThreadContext threadContext = new ThreadContext(settings);
101+
when(executor.getClient()).thenReturn(client);
102+
when(client.threadPool()).thenReturn(threadPool);
103+
when(threadPool.getThreadContext()).thenReturn(threadContext);
104+
return executor;
105+
}
106+
107+
@Test
108+
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled() {
109+
Map<String, String> parameters = ImmutableMap
110+
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "false", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
111+
Connector connector = getConnector(parameters);
112+
AwsConnectorExecutor executor = getExecutor(connector);
113+
114+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
115+
.builder()
116+
.parameters(Map.of("input", "You are a ${parameters.role}"))
117+
.actionType(PREDICT)
118+
.build();
119+
String actionType = inputDataSet.getActionType().toString();
120+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();
121+
122+
Exception exception = Assert
123+
.assertThrows(
124+
IllegalArgumentException.class,
125+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
126+
);
127+
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
128+
}
129+
130+
@Test
131+
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled() {
132+
Map<String, String> parameters = ImmutableMap
133+
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "true", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
134+
Connector connector = getConnector(parameters);
135+
AwsConnectorExecutor executor = getExecutor(connector);
136+
137+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
138+
.builder()
139+
.parameters(Map.of("input", "You are a ${parameters.role}"))
140+
.actionType(PREDICT)
141+
.build();
142+
String actionType = inputDataSet.getActionType().toString();
143+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();
144+
145+
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
146+
Mockito
147+
.verify(executor, times(1))
148+
.invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any());
149+
}
150+
151+
@Test
152+
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() {
153+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
154+
Connector connector = getConnector(parameters);
155+
AwsConnectorExecutor executor = getExecutor(connector);
156+
157+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
158+
.builder()
159+
.parameters(Map.of("input", "You are a ${parameters.role}"))
160+
.actionType(PREDICT)
161+
.build();
162+
String actionType = inputDataSet.getActionType().toString();
163+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();
164+
165+
Exception exception = Assert
166+
.assertThrows(
167+
IllegalArgumentException.class,
168+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
169+
);
170+
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
171+
}
172+
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+142
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,69 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
287287
});
288288
}
289289

290+
public void testPredictRemoteModelWithSkipValidatingMissingParameter(
291+
String testCase,
292+
Consumer<Map> verifyResponse,
293+
Consumer<Exception> verifyException
294+
) throws IOException,
295+
InterruptedException {
296+
// Skip test if key is null
297+
if (OPENAI_KEY == null) {
298+
return;
299+
}
300+
Response response = createConnector(this.getConnectorBodyBySkipValidatingMissingParameter(testCase));
301+
Map responseMap = parseResponseToMap(response);
302+
String connectorId = (String) responseMap.get("connector_id");
303+
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface");
304+
responseMap = parseResponseToMap(response);
305+
String taskId = (String) responseMap.get("task_id");
306+
waitForTask(taskId, MLTaskState.COMPLETED);
307+
response = getTask(taskId);
308+
responseMap = parseResponseToMap(response);
309+
String modelId = (String) responseMap.get("model_id");
310+
response = deployRemoteModel(modelId);
311+
responseMap = parseResponseToMap(response);
312+
taskId = (String) responseMap.get("task_id");
313+
waitForTask(taskId, MLTaskState.COMPLETED);
314+
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a ${parameters.test}\"\n" + " }\n" + "}";
315+
try {
316+
response = predictRemoteModel(modelId, predictInput);
317+
responseMap = parseResponseToMap(response);
318+
verifyResponse.accept(responseMap);
319+
} catch (Exception e) {
320+
verifyException.accept(e);
321+
}
322+
}
323+
324+
public void testPredictRemoteModelWithSkipValidatingMissingParameterMissing() throws IOException, InterruptedException {
325+
testPredictRemoteModelWithSkipValidatingMissingParameter("missing", null, (exception) -> {
326+
assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test"));
327+
});
328+
}
329+
330+
public void testPredictRemoteModelWithSkipValidatingMissingParameterEnabled() throws IOException, InterruptedException {
331+
testPredictRemoteModelWithSkipValidatingMissingParameter("enabled", (responseMap) -> {
332+
List responseList = (List) responseMap.get("inference_results");
333+
responseMap = (Map) responseList.get(0);
334+
responseList = (List) responseMap.get("output");
335+
responseMap = (Map) responseList.get(0);
336+
responseMap = (Map) responseMap.get("dataAsMap");
337+
responseList = (List) responseMap.get("choices");
338+
if (responseList == null) {
339+
assertTrue(checkThrottlingOpenAI(responseMap));
340+
return;
341+
}
342+
responseMap = (Map) responseList.get(0);
343+
assertFalse(((String) responseMap.get("text")).isEmpty());
344+
}, null);
345+
}
346+
347+
public void testPredictRemoteModelWithSkipValidatingMissingParameterDisabled() throws IOException, InterruptedException {
348+
testPredictRemoteModelWithSkipValidatingMissingParameter("disabled", null, (exception) -> {
349+
assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test"));
350+
});
351+
}
352+
290353
public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
291354
// Skip test if key is null
292355
if (OPENAI_KEY == null) {
@@ -870,6 +933,85 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name
870933
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
871934
}
872935

936+
private String getConnectorBodyBySkipValidatingMissingParameter(String testCase) {
937+
return switch (testCase) {
938+
case "missing" -> completionModelConnectorEntity;
939+
case "enabled" -> "{\n"
940+
+ "\"name\": \"OpenAI Connector\",\n"
941+
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
942+
+ "\"version\": 1,\n"
943+
+ "\"client_config\": {\n"
944+
+ " \"max_connection\": 20,\n"
945+
+ " \"connection_timeout\": 50000,\n"
946+
+ " \"read_timeout\": 50000\n"
947+
+ " },\n"
948+
+ "\"protocol\": \"http\",\n"
949+
+ "\"parameters\": {\n"
950+
+ " \"endpoint\": \"api.openai.com\",\n"
951+
+ " \"auth\": \"API_Key\",\n"
952+
+ " \"content_type\": \"application/json\",\n"
953+
+ " \"max_tokens\": 7,\n"
954+
+ " \"temperature\": 0,\n"
955+
+ " \"model\": \"gpt-3.5-turbo-instruct\",\n"
956+
+ " \"skip_validating_missing_parameters\": \"true\"\n"
957+
+ " },\n"
958+
+ " \"credential\": {\n"
959+
+ " \"openAI_key\": \""
960+
+ this.OPENAI_KEY
961+
+ "\"\n"
962+
+ " },\n"
963+
+ " \"actions\": [\n"
964+
+ " {"
965+
+ " \"action_type\": \"predict\",\n"
966+
+ " \"method\": \"POST\",\n"
967+
+ " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n"
968+
+ " \"headers\": {\n"
969+
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
970+
+ " },\n"
971+
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
972+
+ " }\n"
973+
+ " ]\n"
974+
+ "}";
975+
case "disabled" -> "{\n"
976+
+ "\"name\": \"OpenAI Connector\",\n"
977+
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
978+
+ "\"version\": 1,\n"
979+
+ "\"client_config\": {\n"
980+
+ " \"max_connection\": 20,\n"
981+
+ " \"connection_timeout\": 50000,\n"
982+
+ " \"read_timeout\": 50000\n"
983+
+ " },\n"
984+
+ "\"protocol\": \"http\",\n"
985+
+ "\"parameters\": {\n"
986+
+ " \"endpoint\": \"api.openai.com\",\n"
987+
+ " \"auth\": \"API_Key\",\n"
988+
+ " \"content_type\": \"application/json\",\n"
989+
+ " \"max_tokens\": 7,\n"
990+
+ " \"temperature\": 0,\n"
991+
+ " \"model\": \"gpt-3.5-turbo-instruct\",\n"
992+
+ " \"skip_validating_missing_parameters\": \"false\"\n"
993+
+ " },\n"
994+
+ " \"credential\": {\n"
995+
+ " \"openAI_key\": \""
996+
+ this.OPENAI_KEY
997+
+ "\"\n"
998+
+ " },\n"
999+
+ " \"actions\": [\n"
1000+
+ " {"
1001+
+ " \"action_type\": \"predict\",\n"
1002+
+ " \"method\": \"POST\",\n"
1003+
+ " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n"
1004+
+ " \"headers\": {\n"
1005+
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
1006+
+ " },\n"
1007+
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
1008+
+ " }\n"
1009+
+ " ]\n"
1010+
+ "}";
1011+
default -> throw new IllegalArgumentException("Invalid test case");
1012+
};
1013+
}
1014+
8731015
public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
8741016
String registerModelGroupEntity = "{\n"
8751017
+ " \"name\": \"remote_model_group\",\n"

0 commit comments

Comments
 (0)