@@ -237,6 +237,70 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
237
237
assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
238
238
}
239
239
240
+ public void testPredictRemoteModelWithInterface (String testCase , Consumer <Map > verifyResponse , Consumer <Exception > verifyException )
241
+ throws IOException ,
242
+ InterruptedException {
243
+ // Skip test if key is null
244
+ if (OPENAI_KEY == null ) {
245
+ return ;
246
+ }
247
+ Response response = createConnector (completionModelConnectorEntity );
248
+ Map responseMap = parseResponseToMap (response );
249
+ String connectorId = (String ) responseMap .get ("connector_id" );
250
+ response = registerRemoteModelWithInterface ("openAI-GPT-3.5 completions" , connectorId , testCase );
251
+ responseMap = parseResponseToMap (response );
252
+ String taskId = (String ) responseMap .get ("task_id" );
253
+ waitForTask (taskId , MLTaskState .COMPLETED );
254
+ response = getTask (taskId );
255
+ responseMap = parseResponseToMap (response );
256
+ String modelId = (String ) responseMap .get ("model_id" );
257
+ response = deployRemoteModel (modelId );
258
+ responseMap = parseResponseToMap (response );
259
+ taskId = (String ) responseMap .get ("task_id" );
260
+ waitForTask (taskId , MLTaskState .COMPLETED );
261
+ String predictInput = "{\n " + " \" parameters\" : {\n " + " \" prompt\" : \" Say this is a test\" \n " + " }\n " + "}" ;
262
+ try {
263
+ response = predictRemoteModel (modelId , predictInput );
264
+ responseMap = parseResponseToMap (response );
265
+ verifyResponse .accept (responseMap );
266
+ } catch (Exception e ) {
267
+ verifyException .accept (e );
268
+ }
269
+ }
270
+
271
+ public void testPredictRemoteModelWithCorrectInterface () throws IOException , InterruptedException {
272
+ testPredictRemoteModelWithInterface ("correctInterface" , (responseMap ) -> {
273
+ List responseList = (List ) responseMap .get ("inference_results" );
274
+ responseMap = (Map ) responseList .get (0 );
275
+ responseList = (List ) responseMap .get ("output" );
276
+ responseMap = (Map ) responseList .get (0 );
277
+ responseMap = (Map ) responseMap .get ("dataAsMap" );
278
+ responseList = (List ) responseMap .get ("choices" );
279
+ if (responseList == null ) {
280
+ assertTrue (checkThrottlingOpenAI (responseMap ));
281
+ return ;
282
+ }
283
+ responseMap = (Map ) responseList .get (0 );
284
+ assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
285
+ }, null );
286
+ }
287
+
288
+ public void testPredictRemoteModelWithWrongInputInterface () throws IOException , InterruptedException {
289
+ testPredictRemoteModelWithInterface ("wrongInputInterface" , null , (exception ) -> {
290
+ assertTrue (exception instanceof org .opensearch .client .ResponseException );
291
+ String stackTrace = ExceptionUtils .getStackTrace (exception );
292
+ assertTrue (stackTrace .contains ("Error validating input schema" ));
293
+ });
294
+ }
295
+
296
+ public void testPredictRemoteModelWithWrongOutputInterface () throws IOException , InterruptedException {
297
+ testPredictRemoteModelWithInterface ("wrongOutputInterface" , null , (exception ) -> {
298
+ assertTrue (exception instanceof org .opensearch .client .ResponseException );
299
+ String stackTrace = ExceptionUtils .getStackTrace (exception );
300
+ assertTrue (stackTrace .contains ("Error validating output schema" ));
301
+ });
302
+ }
303
+
240
304
public void testUndeployRemoteModel () throws IOException , InterruptedException {
241
305
Response response = createConnector (completionModelConnectorEntity );
242
306
Map responseMap = parseResponseToMap (response );
@@ -777,6 +841,183 @@ public static Response registerRemoteModel(String name, String connectorId) thro
777
841
.makeRequest (client (), "POST" , "/_plugins/_ml/models/_register" , null , TestHelper .toHttpEntity (registerModelEntity ), null );
778
842
}
779
843
844
+ public static Response registerRemoteModelWithInterface (String name , String connectorId , String testCase ) throws IOException {
845
+ String registerModelGroupEntity = "{\n "
846
+ + " \" name\" : \" remote_model_group\" ,\n "
847
+ + " \" description\" : \" This is an example description\" \n "
848
+ + "}" ;
849
+ Response response = TestHelper
850
+ .makeRequest (
851
+ client (),
852
+ "POST" ,
853
+ "/_plugins/_ml/model_groups/_register" ,
854
+ null ,
855
+ TestHelper .toHttpEntity (registerModelGroupEntity ),
856
+ null
857
+ );
858
+ Map responseMap = parseResponseToMap (response );
859
+ assertEquals ((String ) responseMap .get ("status" ), "CREATED" );
860
+ String modelGroupId = (String ) responseMap .get ("model_group_id" );
861
+
862
+ final String openaiConnectorEntityWithCorrectInterface = "{\n "
863
+ + " \" name\" : \" "
864
+ + name
865
+ + "\" ,\n "
866
+ + " \" model_group_id\" : \" "
867
+ + modelGroupId
868
+ + "\" ,\n "
869
+ + " \" function_name\" : \" remote\" ,\n "
870
+ + " \" connector_id\" : \" "
871
+ + connectorId
872
+ + "\" ,\n "
873
+ + " \" interface\" : {\n "
874
+ + " \" input\" : {\n "
875
+ + " \" properties\" : {\n "
876
+ + " \" parameters\" : {\n "
877
+ + " \" properties\" : {\n "
878
+ + " \" prompt\" : {\n "
879
+ + " \" type\" : \" string\" ,\n "
880
+ + " \" description\" : \" This is a test description field\" \n "
881
+ + " }\n "
882
+ + " }\n "
883
+ + " }\n "
884
+ + " }\n "
885
+ + " },\n "
886
+ + " \" output\" : {\n "
887
+ + " \" properties\" : {\n "
888
+ + " \" inference_results\" : {\n "
889
+ + " \" type\" : \" array\" ,\n "
890
+ + " \" items\" : {\n "
891
+ + " \" type\" : \" object\" ,\n "
892
+ + " \" properties\" : {\n "
893
+ + " \" output\" : {\n "
894
+ + " \" type\" : \" array\" ,\n "
895
+ + " \" items\" : {\n "
896
+ + " \" properties\" : {\n "
897
+ + " \" name\" : {\n "
898
+ + " \" type\" : \" string\" ,\n "
899
+ + " \" description\" : \" This is a test description field\" \n "
900
+ + " },\n "
901
+ + " \" dataAsMap\" : {\n "
902
+ + " \" type\" : \" object\" ,\n "
903
+ + " \" description\" : \" This is a test description field\" \n "
904
+ + " }\n "
905
+ + " }\n "
906
+ + " },\n "
907
+ + " \" description\" : \" This is a test description field\" \n "
908
+ + " },\n "
909
+ + " \" status_code\" : {\n "
910
+ + " \" type\" : \" integer\" ,\n "
911
+ + " \" description\" : \" This is a test description field\" \n "
912
+ + " }\n "
913
+ + " }\n "
914
+ + " },\n "
915
+ + " \" description\" : \" This is a test description field\" \n "
916
+ + " }\n "
917
+ + " }\n "
918
+ + " }\n "
919
+ + " }\n "
920
+ + "}" ;
921
+
922
+ final String openaiConnectorEntityWithWrongInputInterface = "{\n "
923
+ + " \" name\" : \" "
924
+ + name
925
+ + "\" ,\n "
926
+ + " \" model_group_id\" : \" "
927
+ + modelGroupId
928
+ + "\" ,\n "
929
+ + " \" function_name\" : \" remote\" ,\n "
930
+ + " \" connector_id\" : \" "
931
+ + connectorId
932
+ + "\" ,\n "
933
+ + " \" interface\" : {\n "
934
+ + " \" input\" : {\n "
935
+ + " \" properties\" : {\n "
936
+ + " \" parameters\" : {\n "
937
+ + " \" properties\" : {\n "
938
+ + " \" prompt\" : {\n "
939
+ + " \" type\" : \" integer\" ,\n "
940
+ + " \" description\" : \" This is a test description field\" \n "
941
+ + " }\n "
942
+ + " }\n "
943
+ + " }\n "
944
+ + " }\n "
945
+ + " }\n "
946
+ + " }\n "
947
+ + "}" ;
948
+
949
+ final String openaiConnectorEntityWithWrongOutputInterface = "{\n "
950
+ + " \" name\" : \" "
951
+ + name
952
+ + "\" ,\n "
953
+ + " \" model_group_id\" : \" "
954
+ + modelGroupId
955
+ + "\" ,\n "
956
+ + " \" function_name\" : \" remote\" ,\n "
957
+ + " \" connector_id\" : \" "
958
+ + connectorId
959
+ + "\" ,\n "
960
+ + " \" interface\" : {\n "
961
+ + " \" output\" : {\n "
962
+ + " \" properties\" : {\n "
963
+ + " \" inference_results\" : {\n "
964
+ + " \" type\" : \" array\" ,\n "
965
+ + " \" items\" : {\n "
966
+ + " \" type\" : \" object\" ,\n "
967
+ + " \" properties\" : {\n "
968
+ + " \" output\" : {\n "
969
+ + " \" type\" : \" integer\" ,\n "
970
+ + " \" description\" : \" This is a test description field\" \n "
971
+ + " },\n "
972
+ + " \" status_code\" : {\n "
973
+ + " \" type\" : \" integer\" ,\n "
974
+ + " \" description\" : \" This is a test description field\" \n "
975
+ + " }\n "
976
+ + " }\n "
977
+ + " },\n "
978
+ + " \" description\" : \" This is a test description field\" \n "
979
+ + " }\n "
980
+ + " }\n "
981
+ + " }\n "
982
+ + " }\n "
983
+ + "}" ;
984
+
985
+ switch (testCase ) {
986
+ case "correctInterface" :
987
+ return TestHelper
988
+ .makeRequest (
989
+ client (),
990
+ "POST" ,
991
+ "/_plugins/_ml/models/_register" ,
992
+ null ,
993
+ TestHelper .toHttpEntity (openaiConnectorEntityWithCorrectInterface ),
994
+ null
995
+ );
996
+ case "wrongInputInterface" :
997
+ return TestHelper
998
+ .makeRequest (
999
+ client (),
1000
+ "POST" ,
1001
+ "/_plugins/_ml/models/_register" ,
1002
+ null ,
1003
+ TestHelper .toHttpEntity (openaiConnectorEntityWithWrongInputInterface ),
1004
+ null
1005
+ );
1006
+ case "wrongOutputInterface" :
1007
+ return TestHelper
1008
+ .makeRequest (
1009
+ client (),
1010
+ "POST" ,
1011
+ "/_plugins/_ml/models/_register" ,
1012
+ null ,
1013
+ TestHelper .toHttpEntity (openaiConnectorEntityWithWrongOutputInterface ),
1014
+ null
1015
+ );
1016
+ default :
1017
+ throw new IllegalArgumentException ("Invalid test case" );
1018
+ }
1019
+ }
1020
+
780
1021
public static Response deployRemoteModel (String modelId ) throws IOException {
781
1022
return TestHelper .makeRequest (client (), "POST" , "/_plugins/_ml/models/" + modelId + "/_deploy" , null , "" , null );
782
1023
}
@@ -831,4 +1072,15 @@ public String registerRemoteModel() throws IOException {
831
1072
logger .info ("task ID created: {}" , taskId );
832
1073
return taskId ;
833
1074
}
1075
+
1076
+ public String registerRemoteModelWithInterface (String testCase ) throws IOException {
1077
+ Response response = createConnector (completionModelConnectorEntity );
1078
+ Map responseMap = parseResponseToMap (response );
1079
+ String connectorId = (String ) responseMap .get ("connector_id" );
1080
+ response = registerRemoteModelWithInterface ("openAI-GPT-3.5 completions" , connectorId , testCase );
1081
+ responseMap = parseResponseToMap (response );
1082
+ String taskId = (String ) responseMap .get ("task_id" );
1083
+ logger .info ("task ID created: {}" , taskId );
1084
+ return taskId ;
1085
+ }
834
1086
}
0 commit comments