@@ -287,6 +287,69 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
287
287
});
288
288
}
289
289
290
+ public void testPredictRemoteModelWithSkipValidatingMissingParameter (
291
+ String testCase ,
292
+ Consumer <Map > verifyResponse ,
293
+ Consumer <Exception > verifyException
294
+ ) throws IOException ,
295
+ InterruptedException {
296
+ // Skip test if key is null
297
+ if (OPENAI_KEY == null ) {
298
+ return ;
299
+ }
300
+ Response response = createConnector (this .getConnectorBodyBySkipValidatingMissingParameter (testCase ));
301
+ Map responseMap = parseResponseToMap (response );
302
+ String connectorId = (String ) responseMap .get ("connector_id" );
303
+ response = registerRemoteModelWithInterface ("openAI-GPT-3.5 completions" , connectorId , "correctInterface" );
304
+ responseMap = parseResponseToMap (response );
305
+ String taskId = (String ) responseMap .get ("task_id" );
306
+ waitForTask (taskId , MLTaskState .COMPLETED );
307
+ response = getTask (taskId );
308
+ responseMap = parseResponseToMap (response );
309
+ String modelId = (String ) responseMap .get ("model_id" );
310
+ response = deployRemoteModel (modelId );
311
+ responseMap = parseResponseToMap (response );
312
+ taskId = (String ) responseMap .get ("task_id" );
313
+ waitForTask (taskId , MLTaskState .COMPLETED );
314
+ String predictInput = "{\n " + " \" parameters\" : {\n " + " \" prompt\" : \" Say this is a ${parameters.test}\" \n " + " }\n " + "}" ;
315
+ try {
316
+ response = predictRemoteModel (modelId , predictInput );
317
+ responseMap = parseResponseToMap (response );
318
+ verifyResponse .accept (responseMap );
319
+ } catch (Exception e ) {
320
+ verifyException .accept (e );
321
+ }
322
+ }
323
+
324
+ public void testPredictRemoteModelWithSkipValidatingMissingParameterMissing () throws IOException , InterruptedException {
325
+ testPredictRemoteModelWithSkipValidatingMissingParameter ("missing" , null , (exception ) -> {
326
+ assertTrue (exception .getMessage ().contains ("Some parameter placeholder not filled in payload: test" ));
327
+ });
328
+ }
329
+
330
+ public void testPredictRemoteModelWithSkipValidatingMissingParameterEnabled () throws IOException , InterruptedException {
331
+ testPredictRemoteModelWithSkipValidatingMissingParameter ("enabled" , (responseMap ) -> {
332
+ List responseList = (List ) responseMap .get ("inference_results" );
333
+ responseMap = (Map ) responseList .get (0 );
334
+ responseList = (List ) responseMap .get ("output" );
335
+ responseMap = (Map ) responseList .get (0 );
336
+ responseMap = (Map ) responseMap .get ("dataAsMap" );
337
+ responseList = (List ) responseMap .get ("choices" );
338
+ if (responseList == null ) {
339
+ assertTrue (checkThrottlingOpenAI (responseMap ));
340
+ return ;
341
+ }
342
+ responseMap = (Map ) responseList .get (0 );
343
+ assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
344
+ }, null );
345
+ }
346
+
347
+ public void testPredictRemoteModelWithSkipValidatingMissingParameterDisabled () throws IOException , InterruptedException {
348
+ testPredictRemoteModelWithSkipValidatingMissingParameter ("disabled" , null , (exception ) -> {
349
+ assertTrue (exception .getMessage ().contains ("Some parameter placeholder not filled in payload: test" ));
350
+ });
351
+ }
352
+
290
353
public void testOpenAIChatCompletionModel () throws IOException , InterruptedException {
291
354
// Skip test if key is null
292
355
if (OPENAI_KEY == null ) {
@@ -870,6 +933,85 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name
870
933
.makeRequest (client (), "POST" , "/_plugins/_ml/models/_register" , null , TestHelper .toHttpEntity (registerModelEntity ), null );
871
934
}
872
935
936
+ private String getConnectorBodyBySkipValidatingMissingParameter (String testCase ) {
937
+ return switch (testCase ) {
938
+ case "missing" -> completionModelConnectorEntity ;
939
+ case "enabled" -> "{\n "
940
+ + "\" name\" : \" OpenAI Connector\" ,\n "
941
+ + "\" description\" : \" The connector to public OpenAI model service for GPT 3.5\" ,\n "
942
+ + "\" version\" : 1,\n "
943
+ + "\" client_config\" : {\n "
944
+ + " \" max_connection\" : 20,\n "
945
+ + " \" connection_timeout\" : 50000,\n "
946
+ + " \" read_timeout\" : 50000\n "
947
+ + " },\n "
948
+ + "\" protocol\" : \" http\" ,\n "
949
+ + "\" parameters\" : {\n "
950
+ + " \" endpoint\" : \" api.openai.com\" ,\n "
951
+ + " \" auth\" : \" API_Key\" ,\n "
952
+ + " \" content_type\" : \" application/json\" ,\n "
953
+ + " \" max_tokens\" : 7,\n "
954
+ + " \" temperature\" : 0,\n "
955
+ + " \" model\" : \" gpt-3.5-turbo-instruct\" ,\n "
956
+ + " \" skip_validating_missing_parameters\" : \" true\" \n "
957
+ + " },\n "
958
+ + " \" credential\" : {\n "
959
+ + " \" openAI_key\" : \" "
960
+ + this .OPENAI_KEY
961
+ + "\" \n "
962
+ + " },\n "
963
+ + " \" actions\" : [\n "
964
+ + " {"
965
+ + " \" action_type\" : \" predict\" ,\n "
966
+ + " \" method\" : \" POST\" ,\n "
967
+ + " \" url\" : \" https://${parameters.endpoint}/v1/completions\" ,\n "
968
+ + " \" headers\" : {\n "
969
+ + " \" Authorization\" : \" Bearer ${credential.openAI_key}\" \n "
970
+ + " },\n "
971
+ + " \" request_body\" : \" { \\ \" model\\ \" : \\ \" ${parameters.model}\\ \" , \\ \" prompt\\ \" : \\ \" ${parameters.prompt}\\ \" , \\ \" max_tokens\\ \" : ${parameters.max_tokens}, \\ \" temperature\\ \" : ${parameters.temperature} }\" \n "
972
+ + " }\n "
973
+ + " ]\n "
974
+ + "}" ;
975
+ case "disabled" -> "{\n "
976
+ + "\" name\" : \" OpenAI Connector\" ,\n "
977
+ + "\" description\" : \" The connector to public OpenAI model service for GPT 3.5\" ,\n "
978
+ + "\" version\" : 1,\n "
979
+ + "\" client_config\" : {\n "
980
+ + " \" max_connection\" : 20,\n "
981
+ + " \" connection_timeout\" : 50000,\n "
982
+ + " \" read_timeout\" : 50000\n "
983
+ + " },\n "
984
+ + "\" protocol\" : \" http\" ,\n "
985
+ + "\" parameters\" : {\n "
986
+ + " \" endpoint\" : \" api.openai.com\" ,\n "
987
+ + " \" auth\" : \" API_Key\" ,\n "
988
+ + " \" content_type\" : \" application/json\" ,\n "
989
+ + " \" max_tokens\" : 7,\n "
990
+ + " \" temperature\" : 0,\n "
991
+ + " \" model\" : \" gpt-3.5-turbo-instruct\" ,\n "
992
+ + " \" skip_validating_missing_parameters\" : \" false\" \n "
993
+ + " },\n "
994
+ + " \" credential\" : {\n "
995
+ + " \" openAI_key\" : \" "
996
+ + this .OPENAI_KEY
997
+ + "\" \n "
998
+ + " },\n "
999
+ + " \" actions\" : [\n "
1000
+ + " {"
1001
+ + " \" action_type\" : \" predict\" ,\n "
1002
+ + " \" method\" : \" POST\" ,\n "
1003
+ + " \" url\" : \" https://${parameters.endpoint}/v1/completions\" ,\n "
1004
+ + " \" headers\" : {\n "
1005
+ + " \" Authorization\" : \" Bearer ${credential.openAI_key}\" \n "
1006
+ + " },\n "
1007
+ + " \" request_body\" : \" { \\ \" model\\ \" : \\ \" ${parameters.model}\\ \" , \\ \" prompt\\ \" : \\ \" ${parameters.prompt}\\ \" , \\ \" max_tokens\\ \" : ${parameters.max_tokens}, \\ \" temperature\\ \" : ${parameters.temperature} }\" \n "
1008
+ + " }\n "
1009
+ + " ]\n "
1010
+ + "}" ;
1011
+ default -> throw new IllegalArgumentException ("Invalid test case" );
1012
+ };
1013
+ }
1014
+
873
1015
public static Response registerRemoteModelWithInterface (String name , String connectorId , String testCase ) throws IOException {
874
1016
String registerModelGroupEntity = "{\n "
875
1017
+ " \" name\" : \" remote_model_group\" ,\n "
0 commit comments