Skip to content

Commit f0b0bec

Browse files
committed
add IT for remote model automatic deploy with TTL
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 89f23d2 commit f0b0bec

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+15
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.nio.file.Path;
2828
import java.util.Arrays;
2929
import java.util.Collections;
30+
import java.util.HashMap;
3031
import java.util.List;
3132
import java.util.Locale;
3233
import java.util.Map;
@@ -861,6 +862,9 @@ public static Map parseResponseToMap(Response response) throws IOException {
861862
public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException {
862863
Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null);
863864
Map profile = parseResponseToMap(response);
865+
if (profile == null || profile.get("nodes") == null) {
866+
return new HashMap();
867+
}
864868
Map<String, Object> nodeProfiles = (Map) profile.get("nodes");
865869
for (Map.Entry<String, Object> entry : nodeProfiles.entrySet()) {
866870
Map<String, Object> modelProfiles = (Map) entry.getValue();
@@ -918,6 +922,17 @@ public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
918922
};
919923
}
920924

925+
public Consumer<Map<String, Object>> verifyRemoteModelDeployed() {
926+
return (modelProfile) -> {
927+
if (modelProfile.containsKey("model_state")) {
928+
assertEquals(MLModelState.DEPLOYED.name(), modelProfile.get("model_state"));
929+
assertTrue(((String) modelProfile.get("predictor")).startsWith("org.opensearch.ml.engine.algorithms.remote.RemoteModel@"));
930+
}
931+
List<String> workNodes = (List) modelProfile.get("worker_nodes");
932+
assertTrue(workNodes.size() > 0);
933+
};
934+
}
935+
921936
public Map undeployModel(String modelId) throws IOException {
922937
Response response = TestHelper
923938
.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+74
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.io.IOException;
99
import java.util.List;
1010
import java.util.Map;
11+
import java.util.concurrent.TimeUnit;
1112
import java.util.function.Consumer;
1213

1314
import org.apache.commons.lang3.exception.ExceptionUtils;
@@ -237,6 +238,39 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
237238
assertFalse(((String) responseMap.get("text")).isEmpty());
238239
}
239240

241+
public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException {
242+
// Skip test if key is null
243+
if (OPENAI_KEY == null) {
244+
System.out.println("OPENAI_KEY is null");
245+
return;
246+
}
247+
Response response = createConnector(completionModelConnectorEntity);
248+
Map responseMap = parseResponseToMap(response);
249+
String connectorId = (String) responseMap.get("connector_id");
250+
response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
251+
responseMap = parseResponseToMap(response);
252+
String modelId = (String) responseMap.get("model_id");
253+
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
254+
response = predictRemoteModel(modelId, predictInput);
255+
responseMap = parseResponseToMap(response);
256+
List responseList = (List) responseMap.get("inference_results");
257+
responseMap = (Map) responseList.get(0);
258+
responseList = (List) responseMap.get("output");
259+
responseMap = (Map) responseList.get(0);
260+
responseMap = (Map) responseMap.get("dataAsMap");
261+
responseList = (List) responseMap.get("choices");
262+
if (responseList == null) {
263+
assertTrue(checkThrottlingOpenAI(responseMap));
264+
return;
265+
}
266+
responseMap = (Map) responseList.get(0);
267+
assertFalse(((String) responseMap.get("text")).isEmpty());
268+
269+
getModelProfile(modelId, verifyRemoteModelDeployed());
270+
TimeUnit.SECONDS.sleep(71);
271+
assertTrue(getModelProfile(modelId, verifyRemoteModelDeployed()).isEmpty());
272+
}
273+
240274
public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
241275
throws IOException,
242276
InterruptedException {
@@ -841,6 +875,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro
841875
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
842876
}
843877

878+
public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
879+
String registerModelGroupEntity = "{\n"
880+
+ " \"name\": \"remote_model_group\",\n"
881+
+ " \"description\": \"This is an example description\"\n"
882+
+ "}";
883+
Response response = TestHelper
884+
.makeRequest(
885+
client(),
886+
"POST",
887+
"/_plugins/_ml/model_groups/_register",
888+
null,
889+
TestHelper.toHttpEntity(registerModelGroupEntity),
890+
null
891+
);
892+
Map responseMap = parseResponseToMap(response);
893+
assertEquals((String) responseMap.get("status"), "CREATED");
894+
String modelGroupId = (String) responseMap.get("model_group_id");
895+
896+
String registerModelEntity = "{\n"
897+
+ " \"name\": \""
898+
+ name
899+
+ "\",\n"
900+
+ " \"function_name\": \"remote\",\n"
901+
+ " \"model_group_id\": \""
902+
+ modelGroupId
903+
+ "\",\n"
904+
+ " \"version\": \"1.0.0\",\n"
905+
+ " \"description\": \"test model\",\n"
906+
+ " \"connector_id\": \""
907+
+ connectorId
908+
+ "\",\n"
909+
+ " \"deploy_setting\": "
910+
+ " { \"model_ttl_minutes\": "
911+
+ ttl
912+
+ "}\n"
913+
+ "}";
914+
return TestHelper
915+
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
916+
}
917+
844918
public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
845919
String registerModelGroupEntity = "{\n"
846920
+ " \"name\": \"remote_model_group\",\n"

0 commit comments

Comments
 (0)