Skip to content

Commit cc09f1f

Browse files
add IT for remote model automatic deploy with TTL (opensearch-project#2431) (opensearch-project#2441)
* add IT for remote model automatic deploy with TTL
  Signed-off-by: Xun Zhang <xunzh@amazon.com>
* remove duplicate and unuseful remote inference ITs
  Signed-off-by: Xun Zhang <xunzh@amazon.com>
---------
Signed-off-by: Xun Zhang <xunzh@amazon.com>
(cherry picked from commit 4f7dc90)
Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent e416816 commit cc09f1f

File tree

2 files changed

+71
-43
lines changed

2 files changed

+71
-43
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+15
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import java.nio.file.Path;
2626
import java.util.Arrays;
2727
import java.util.Collections;
28+
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Locale;
3031
import java.util.Map;
@@ -846,6 +847,9 @@ public static Map parseResponseToMap(Response response) throws IOException {
846847
public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException {
847848
Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null);
848849
Map profile = parseResponseToMap(response);
850+
if (profile == null || profile.get("nodes") == null) {
851+
return new HashMap();
852+
}
849853
Map<String, Object> nodeProfiles = (Map) profile.get("nodes");
850854
for (Map.Entry<String, Object> entry : nodeProfiles.entrySet()) {
851855
Map<String, Object> modelProfiles = (Map) entry.getValue();
@@ -903,6 +907,17 @@ public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
903907
};
904908
}
905909

910+
public Consumer<Map<String, Object>> verifyRemoteModelDeployed() {
911+
return (modelProfile) -> {
912+
if (modelProfile.containsKey("model_state")) {
913+
assertEquals(MLModelState.DEPLOYED.name(), modelProfile.get("model_state"));
914+
assertTrue(((String) modelProfile.get("predictor")).startsWith("org.opensearch.ml.engine.algorithms.remote.RemoteModel@"));
915+
}
916+
List<String> workNodes = (List) modelProfile.get("worker_nodes");
917+
assertTrue(workNodes.size() > 0);
918+
};
919+
}
920+
906921
public Map undeployModel(String modelId) throws IOException {
907922
Response response = TestHelper
908923
.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+56-43
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,14 @@
88
import java.io.IOException;
99
import java.util.List;
1010
import java.util.Map;
11+
import java.util.concurrent.TimeUnit;
1112
import java.util.function.Consumer;
1213

1314
import org.apache.commons.lang3.exception.ExceptionUtils;
1415
import org.apache.http.HttpHeaders;
1516
import org.apache.http.message.BasicHeader;
1617
import org.junit.Before;
18+
import org.junit.Ignore;
1719
import org.junit.Rule;
1820
import org.junit.rules.ExpectedException;
1921
import org.opensearch.client.Response;
@@ -170,19 +172,6 @@ public void testSearchMLTasks_afterCreation() throws IOException {
170172
assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value"));
171173
}
172174

173-
public void testRegisterRemoteModel() throws IOException, InterruptedException {
174-
Response response = createConnector(completionModelConnectorEntity);
175-
Map responseMap = parseResponseToMap(response);
176-
String connectorId = (String) responseMap.get("connector_id");
177-
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
178-
responseMap = parseResponseToMap(response);
179-
String taskId = (String) responseMap.get("task_id");
180-
waitForTask(taskId, MLTaskState.COMPLETED);
181-
response = getTask(taskId);
182-
responseMap = parseResponseToMap(response);
183-
assertNotNull(responseMap.get("model_id"));
184-
}
185-
186175
public void testDeployRemoteModel() throws IOException, InterruptedException {
187176
Response response = createConnector(completionModelConnectorEntity);
188177
Map responseMap = parseResponseToMap(response);
@@ -201,25 +190,18 @@ public void testDeployRemoteModel() throws IOException, InterruptedException {
201190
waitForTask(taskId, MLTaskState.COMPLETED);
202191
}
203192

204-
public void testPredictRemoteModel() throws IOException, InterruptedException {
193+
public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException {
205194
// Skip test if key is null
206195
if (OPENAI_KEY == null) {
196+
System.out.println("OPENAI_KEY is null");
207197
return;
208198
}
209199
Response response = createConnector(completionModelConnectorEntity);
210200
Map responseMap = parseResponseToMap(response);
211201
String connectorId = (String) responseMap.get("connector_id");
212-
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
213-
responseMap = parseResponseToMap(response);
214-
String taskId = (String) responseMap.get("task_id");
215-
waitForTask(taskId, MLTaskState.COMPLETED);
216-
response = getTask(taskId);
202+
response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
217203
responseMap = parseResponseToMap(response);
218204
String modelId = (String) responseMap.get("model_id");
219-
response = deployRemoteModel(modelId);
220-
responseMap = parseResponseToMap(response);
221-
taskId = (String) responseMap.get("task_id");
222-
waitForTask(taskId, MLTaskState.COMPLETED);
223205
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
224206
response = predictRemoteModel(modelId, predictInput);
225207
responseMap = parseResponseToMap(response);
@@ -235,6 +217,10 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
235217
}
236218
responseMap = (Map) responseList.get(0);
237219
assertFalse(((String) responseMap.get("text")).isEmpty());
220+
221+
getModelProfile(modelId, verifyRemoteModelDeployed());
222+
TimeUnit.SECONDS.sleep(71);
223+
assertTrue(getModelProfile(modelId, verifyRemoteModelDeployed()).isEmpty());
238224
}
239225

240226
public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
@@ -301,26 +287,6 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
301287
});
302288
}
303289

304-
public void testUndeployRemoteModel() throws IOException, InterruptedException {
305-
Response response = createConnector(completionModelConnectorEntity);
306-
Map responseMap = parseResponseToMap(response);
307-
String connectorId = (String) responseMap.get("connector_id");
308-
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
309-
responseMap = parseResponseToMap(response);
310-
String taskId = (String) responseMap.get("task_id");
311-
waitForTask(taskId, MLTaskState.COMPLETED);
312-
response = getTask(taskId);
313-
responseMap = parseResponseToMap(response);
314-
String modelId = (String) responseMap.get("model_id");
315-
response = deployRemoteModel(modelId);
316-
responseMap = parseResponseToMap(response);
317-
taskId = (String) responseMap.get("task_id");
318-
waitForTask(taskId, MLTaskState.COMPLETED);
319-
response = undeployRemoteModel(modelId);
320-
responseMap = parseResponseToMap(response);
321-
assertTrue(responseMap.toString().contains("undeployed"));
322-
}
323-
324290
public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
325291
// Skip test if key is null
326292
if (OPENAI_KEY == null) {
@@ -384,8 +350,13 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
384350
responseMap = parseResponseToMap(response);
385351
// TODO handle throttling error
386352
assertNotNull(responseMap);
353+
354+
response = undeployRemoteModel(modelId);
355+
responseMap = parseResponseToMap(response);
356+
assertTrue(responseMap.toString().contains("undeployed"));
387357
}
388358

359+
@Ignore
389360
public void testOpenAIEditsModel() throws IOException, InterruptedException {
390361
// Skip test if key is null
391362
if (OPENAI_KEY == null) {
@@ -457,6 +428,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
457428
assertFalse(((String) responseMap.get("content")).isEmpty());
458429
}
459430

431+
@Ignore
460432
public void testOpenAIModerationsModel() throws IOException, InterruptedException {
461433
// Skip test if key is null
462434
if (OPENAI_KEY == null) {
@@ -687,6 +659,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
687659
assertFalse(((String) responseMap.get("text")).isEmpty());
688660
}
689661

662+
@Ignore
690663
public void testCohereClassifyModel() throws IOException, InterruptedException {
691664
// Skip test if key is null
692665
if (COHERE_KEY == null) {
@@ -841,6 +814,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro
841814
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
842815
}
843816

817+
public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
818+
String registerModelGroupEntity = "{\n"
819+
+ " \"name\": \"remote_model_group\",\n"
820+
+ " \"description\": \"This is an example description\"\n"
821+
+ "}";
822+
Response response = TestHelper
823+
.makeRequest(
824+
client(),
825+
"POST",
826+
"/_plugins/_ml/model_groups/_register",
827+
null,
828+
TestHelper.toHttpEntity(registerModelGroupEntity),
829+
null
830+
);
831+
Map responseMap = parseResponseToMap(response);
832+
assertEquals((String) responseMap.get("status"), "CREATED");
833+
String modelGroupId = (String) responseMap.get("model_group_id");
834+
835+
String registerModelEntity = "{\n"
836+
+ " \"name\": \""
837+
+ name
838+
+ "\",\n"
839+
+ " \"function_name\": \"remote\",\n"
840+
+ " \"model_group_id\": \""
841+
+ modelGroupId
842+
+ "\",\n"
843+
+ " \"version\": \"1.0.0\",\n"
844+
+ " \"description\": \"test model\",\n"
845+
+ " \"connector_id\": \""
846+
+ connectorId
847+
+ "\",\n"
848+
+ " \"deploy_setting\": "
849+
+ " { \"model_ttl_minutes\": "
850+
+ ttl
851+
+ "}\n"
852+
+ "}";
853+
return TestHelper
854+
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
855+
}
856+
844857
public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
845858
String registerModelGroupEntity = "{\n"
846859
+ " \"name\": \"remote_model_group\",\n"

0 commit comments

Comments
 (0)