Skip to content

Commit a732741

Browse files
(cherry picked from commit a98dbbf) Co-authored-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 937cf87 commit a732741

File tree

1 file changed

+252
-0
lines changed

1 file changed

+252
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+252
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,70 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
237237
assertFalse(((String) responseMap.get("text")).isEmpty());
238238
}
239239

240+
public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
241+
throws IOException,
242+
InterruptedException {
243+
// Skip test if key is null
244+
if (OPENAI_KEY == null) {
245+
return;
246+
}
247+
Response response = createConnector(completionModelConnectorEntity);
248+
Map responseMap = parseResponseToMap(response);
249+
String connectorId = (String) responseMap.get("connector_id");
250+
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, testCase);
251+
responseMap = parseResponseToMap(response);
252+
String taskId = (String) responseMap.get("task_id");
253+
waitForTask(taskId, MLTaskState.COMPLETED);
254+
response = getTask(taskId);
255+
responseMap = parseResponseToMap(response);
256+
String modelId = (String) responseMap.get("model_id");
257+
response = deployRemoteModel(modelId);
258+
responseMap = parseResponseToMap(response);
259+
taskId = (String) responseMap.get("task_id");
260+
waitForTask(taskId, MLTaskState.COMPLETED);
261+
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
262+
try {
263+
response = predictRemoteModel(modelId, predictInput);
264+
responseMap = parseResponseToMap(response);
265+
verifyResponse.accept(responseMap);
266+
} catch (Exception e) {
267+
verifyException.accept(e);
268+
}
269+
}
270+
271+
public void testPredictRemoteModelWithCorrectInterface() throws IOException, InterruptedException {
272+
testPredictRemoteModelWithInterface("correctInterface", (responseMap) -> {
273+
List responseList = (List) responseMap.get("inference_results");
274+
responseMap = (Map) responseList.get(0);
275+
responseList = (List) responseMap.get("output");
276+
responseMap = (Map) responseList.get(0);
277+
responseMap = (Map) responseMap.get("dataAsMap");
278+
responseList = (List) responseMap.get("choices");
279+
if (responseList == null) {
280+
assertTrue(checkThrottlingOpenAI(responseMap));
281+
return;
282+
}
283+
responseMap = (Map) responseList.get(0);
284+
assertFalse(((String) responseMap.get("text")).isEmpty());
285+
}, null);
286+
}
287+
288+
public void testPredictRemoteModelWithWrongInputInterface() throws IOException, InterruptedException {
289+
testPredictRemoteModelWithInterface("wrongInputInterface", null, (exception) -> {
290+
assertTrue(exception instanceof org.opensearch.client.ResponseException);
291+
String stackTrace = ExceptionUtils.getStackTrace(exception);
292+
assertTrue(stackTrace.contains("Error validating input schema"));
293+
});
294+
}
295+
296+
public void testPredictRemoteModelWithWrongOutputInterface() throws IOException, InterruptedException {
297+
testPredictRemoteModelWithInterface("wrongOutputInterface", null, (exception) -> {
298+
assertTrue(exception instanceof org.opensearch.client.ResponseException);
299+
String stackTrace = ExceptionUtils.getStackTrace(exception);
300+
assertTrue(stackTrace.contains("Error validating output schema"));
301+
});
302+
}
303+
240304
public void testUndeployRemoteModel() throws IOException, InterruptedException {
241305
Response response = createConnector(completionModelConnectorEntity);
242306
Map responseMap = parseResponseToMap(response);
@@ -777,6 +841,183 @@ public static Response registerRemoteModel(String name, String connectorId) thro
777841
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
778842
}
779843

844+
public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
845+
String registerModelGroupEntity = "{\n"
846+
+ " \"name\": \"remote_model_group\",\n"
847+
+ " \"description\": \"This is an example description\"\n"
848+
+ "}";
849+
Response response = TestHelper
850+
.makeRequest(
851+
client(),
852+
"POST",
853+
"/_plugins/_ml/model_groups/_register",
854+
null,
855+
TestHelper.toHttpEntity(registerModelGroupEntity),
856+
null
857+
);
858+
Map responseMap = parseResponseToMap(response);
859+
assertEquals((String) responseMap.get("status"), "CREATED");
860+
String modelGroupId = (String) responseMap.get("model_group_id");
861+
862+
final String openaiConnectorEntityWithCorrectInterface = "{\n"
863+
+ " \"name\": \""
864+
+ name
865+
+ "\",\n"
866+
+ " \"model_group_id\": \""
867+
+ modelGroupId
868+
+ "\",\n"
869+
+ " \"function_name\": \"remote\",\n"
870+
+ " \"connector_id\": \""
871+
+ connectorId
872+
+ "\",\n"
873+
+ " \"interface\": {\n"
874+
+ " \"input\": {\n"
875+
+ " \"properties\": {\n"
876+
+ " \"parameters\": {\n"
877+
+ " \"properties\": {\n"
878+
+ " \"prompt\": {\n"
879+
+ " \"type\": \"string\",\n"
880+
+ " \"description\": \"This is a test description field\"\n"
881+
+ " }\n"
882+
+ " }\n"
883+
+ " }\n"
884+
+ " }\n"
885+
+ " },\n"
886+
+ " \"output\": {\n"
887+
+ " \"properties\": {\n"
888+
+ " \"inference_results\": {\n"
889+
+ " \"type\": \"array\",\n"
890+
+ " \"items\": {\n"
891+
+ " \"type\": \"object\",\n"
892+
+ " \"properties\": {\n"
893+
+ " \"output\": {\n"
894+
+ " \"type\": \"array\",\n"
895+
+ " \"items\": {\n"
896+
+ " \"properties\": {\n"
897+
+ " \"name\": {\n"
898+
+ " \"type\": \"string\",\n"
899+
+ " \"description\": \"This is a test description field\"\n"
900+
+ " },\n"
901+
+ " \"dataAsMap\": {\n"
902+
+ " \"type\": \"object\",\n"
903+
+ " \"description\": \"This is a test description field\"\n"
904+
+ " }\n"
905+
+ " }\n"
906+
+ " },\n"
907+
+ " \"description\": \"This is a test description field\"\n"
908+
+ " },\n"
909+
+ " \"status_code\": {\n"
910+
+ " \"type\": \"integer\",\n"
911+
+ " \"description\": \"This is a test description field\"\n"
912+
+ " }\n"
913+
+ " }\n"
914+
+ " },\n"
915+
+ " \"description\": \"This is a test description field\"\n"
916+
+ " }\n"
917+
+ " }\n"
918+
+ " }\n"
919+
+ " }\n"
920+
+ "}";
921+
922+
final String openaiConnectorEntityWithWrongInputInterface = "{\n"
923+
+ " \"name\": \""
924+
+ name
925+
+ "\",\n"
926+
+ " \"model_group_id\": \""
927+
+ modelGroupId
928+
+ "\",\n"
929+
+ " \"function_name\": \"remote\",\n"
930+
+ " \"connector_id\": \""
931+
+ connectorId
932+
+ "\",\n"
933+
+ " \"interface\": {\n"
934+
+ " \"input\": {\n"
935+
+ " \"properties\": {\n"
936+
+ " \"parameters\": {\n"
937+
+ " \"properties\": {\n"
938+
+ " \"prompt\": {\n"
939+
+ " \"type\": \"integer\",\n"
940+
+ " \"description\": \"This is a test description field\"\n"
941+
+ " }\n"
942+
+ " }\n"
943+
+ " }\n"
944+
+ " }\n"
945+
+ " }\n"
946+
+ " }\n"
947+
+ "}";
948+
949+
final String openaiConnectorEntityWithWrongOutputInterface = "{\n"
950+
+ " \"name\": \""
951+
+ name
952+
+ "\",\n"
953+
+ " \"model_group_id\": \""
954+
+ modelGroupId
955+
+ "\",\n"
956+
+ " \"function_name\": \"remote\",\n"
957+
+ " \"connector_id\": \""
958+
+ connectorId
959+
+ "\",\n"
960+
+ " \"interface\": {\n"
961+
+ " \"output\": {\n"
962+
+ " \"properties\": {\n"
963+
+ " \"inference_results\": {\n"
964+
+ " \"type\": \"array\",\n"
965+
+ " \"items\": {\n"
966+
+ " \"type\": \"object\",\n"
967+
+ " \"properties\": {\n"
968+
+ " \"output\": {\n"
969+
+ " \"type\": \"integer\",\n"
970+
+ " \"description\": \"This is a test description field\"\n"
971+
+ " },\n"
972+
+ " \"status_code\": {\n"
973+
+ " \"type\": \"integer\",\n"
974+
+ " \"description\": \"This is a test description field\"\n"
975+
+ " }\n"
976+
+ " }\n"
977+
+ " },\n"
978+
+ " \"description\": \"This is a test description field\"\n"
979+
+ " }\n"
980+
+ " }\n"
981+
+ " }\n"
982+
+ " }\n"
983+
+ "}";
984+
985+
switch (testCase) {
986+
case "correctInterface":
987+
return TestHelper
988+
.makeRequest(
989+
client(),
990+
"POST",
991+
"/_plugins/_ml/models/_register",
992+
null,
993+
TestHelper.toHttpEntity(openaiConnectorEntityWithCorrectInterface),
994+
null
995+
);
996+
case "wrongInputInterface":
997+
return TestHelper
998+
.makeRequest(
999+
client(),
1000+
"POST",
1001+
"/_plugins/_ml/models/_register",
1002+
null,
1003+
TestHelper.toHttpEntity(openaiConnectorEntityWithWrongInputInterface),
1004+
null
1005+
);
1006+
case "wrongOutputInterface":
1007+
return TestHelper
1008+
.makeRequest(
1009+
client(),
1010+
"POST",
1011+
"/_plugins/_ml/models/_register",
1012+
null,
1013+
TestHelper.toHttpEntity(openaiConnectorEntityWithWrongOutputInterface),
1014+
null
1015+
);
1016+
default:
1017+
throw new IllegalArgumentException("Invalid test case");
1018+
}
1019+
}
1020+
7801021
public static Response deployRemoteModel(String modelId) throws IOException {
7811022
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null);
7821023
}
@@ -831,4 +1072,15 @@ public String registerRemoteModel() throws IOException {
8311072
logger.info("task ID created: {}", taskId);
8321073
return taskId;
8331074
}
1075+
1076+
public String registerRemoteModelWithInterface(String testCase) throws IOException {
1077+
Response response = createConnector(completionModelConnectorEntity);
1078+
Map responseMap = parseResponseToMap(response);
1079+
String connectorId = (String) responseMap.get("connector_id");
1080+
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, testCase);
1081+
responseMap = parseResponseToMap(response);
1082+
String taskId = (String) responseMap.get("task_id");
1083+
logger.info("task ID created: {}", taskId);
1084+
return taskId;
1085+
}
8341086
}

0 commit comments

Comments
 (0)