|
8 | 8 | import java.io.IOException;
|
9 | 9 | import java.util.List;
|
10 | 10 | import java.util.Map;
|
| 11 | +import java.util.concurrent.TimeUnit; |
11 | 12 | import java.util.function.Consumer;
|
12 | 13 |
|
13 | 14 | import org.apache.commons.lang3.exception.ExceptionUtils;
|
@@ -237,6 +238,39 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
|
237 | 238 | assertFalse(((String) responseMap.get("text")).isEmpty());
|
238 | 239 | }
|
239 | 240 |
|
| 241 | + public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException { |
| 242 | + // Skip test if key is null |
| 243 | + if (OPENAI_KEY == null) { |
| 244 | + System.out.println("OPENAI_KEY is null"); |
| 245 | + return; |
| 246 | + } |
| 247 | + Response response = createConnector(completionModelConnectorEntity); |
| 248 | + Map responseMap = parseResponseToMap(response); |
| 249 | + String connectorId = (String) responseMap.get("connector_id"); |
| 250 | + response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1); |
| 251 | + responseMap = parseResponseToMap(response); |
| 252 | + String modelId = (String) responseMap.get("model_id"); |
| 253 | + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; |
| 254 | + response = predictRemoteModel(modelId, predictInput); |
| 255 | + responseMap = parseResponseToMap(response); |
| 256 | + List responseList = (List) responseMap.get("inference_results"); |
| 257 | + responseMap = (Map) responseList.get(0); |
| 258 | + responseList = (List) responseMap.get("output"); |
| 259 | + responseMap = (Map) responseList.get(0); |
| 260 | + responseMap = (Map) responseMap.get("dataAsMap"); |
| 261 | + responseList = (List) responseMap.get("choices"); |
| 262 | + if (responseList == null) { |
| 263 | + assertTrue(checkThrottlingOpenAI(responseMap)); |
| 264 | + return; |
| 265 | + } |
| 266 | + responseMap = (Map) responseList.get(0); |
| 267 | + assertFalse(((String) responseMap.get("text")).isEmpty()); |
| 268 | + |
| 269 | + getModelProfile(modelId, verifyRemoteModelDeployed()); |
| 270 | + TimeUnit.SECONDS.sleep(71); |
| 271 | + assertTrue(getModelProfile(modelId, verifyRemoteModelDeployed()).isEmpty()); |
| 272 | + } |
| 273 | + |
240 | 274 | public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
|
241 | 275 | throws IOException,
|
242 | 276 | InterruptedException {
|
@@ -841,6 +875,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro
|
841 | 875 | .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
|
842 | 876 | }
|
843 | 877 |
|
| 878 | + public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException { |
| 879 | + String registerModelGroupEntity = "{\n" |
| 880 | + + " \"name\": \"remote_model_group\",\n" |
| 881 | + + " \"description\": \"This is an example description\"\n" |
| 882 | + + "}"; |
| 883 | + Response response = TestHelper |
| 884 | + .makeRequest( |
| 885 | + client(), |
| 886 | + "POST", |
| 887 | + "/_plugins/_ml/model_groups/_register", |
| 888 | + null, |
| 889 | + TestHelper.toHttpEntity(registerModelGroupEntity), |
| 890 | + null |
| 891 | + ); |
| 892 | + Map responseMap = parseResponseToMap(response); |
| 893 | + assertEquals((String) responseMap.get("status"), "CREATED"); |
| 894 | + String modelGroupId = (String) responseMap.get("model_group_id"); |
| 895 | + |
| 896 | + String registerModelEntity = "{\n" |
| 897 | + + " \"name\": \"" |
| 898 | + + name |
| 899 | + + "\",\n" |
| 900 | + + " \"function_name\": \"remote\",\n" |
| 901 | + + " \"model_group_id\": \"" |
| 902 | + + modelGroupId |
| 903 | + + "\",\n" |
| 904 | + + " \"version\": \"1.0.0\",\n" |
| 905 | + + " \"description\": \"test model\",\n" |
| 906 | + + " \"connector_id\": \"" |
| 907 | + + connectorId |
| 908 | + + "\",\n" |
| 909 | + + " \"deploy_setting\": " |
| 910 | + + " { \"model_ttl_minutes\": " |
| 911 | + + ttl |
| 912 | + + "}\n" |
| 913 | + + "}"; |
| 914 | + return TestHelper |
| 915 | + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); |
| 916 | + } |
| 917 | + |
844 | 918 | public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
|
845 | 919 | String registerModelGroupEntity = "{\n"
|
846 | 920 | + " \"name\": \"remote_model_group\",\n"
|
|
0 commit comments