8
8
import java .io .IOException ;
9
9
import java .util .List ;
10
10
import java .util .Map ;
11
+ import java .util .concurrent .TimeUnit ;
11
12
import java .util .function .Consumer ;
12
13
13
14
import org .apache .commons .lang3 .exception .ExceptionUtils ;
14
15
import org .apache .http .HttpHeaders ;
15
16
import org .apache .http .message .BasicHeader ;
16
17
import org .junit .Before ;
18
+ import org .junit .Ignore ;
17
19
import org .junit .Rule ;
18
20
import org .junit .rules .ExpectedException ;
19
21
import org .opensearch .client .Response ;
@@ -170,19 +172,6 @@ public void testSearchMLTasks_afterCreation() throws IOException {
170
172
assertEquals ((Double ) 1.0 , (Double ) ((Map ) ((Map ) responseMap .get ("hits" )).get ("total" )).get ("value" ));
171
173
}
172
174
173
- public void testRegisterRemoteModel () throws IOException , InterruptedException {
174
- Response response = createConnector (completionModelConnectorEntity );
175
- Map responseMap = parseResponseToMap (response );
176
- String connectorId = (String ) responseMap .get ("connector_id" );
177
- response = registerRemoteModel ("openAI-GPT-3.5 completions" , connectorId );
178
- responseMap = parseResponseToMap (response );
179
- String taskId = (String ) responseMap .get ("task_id" );
180
- waitForTask (taskId , MLTaskState .COMPLETED );
181
- response = getTask (taskId );
182
- responseMap = parseResponseToMap (response );
183
- assertNotNull (responseMap .get ("model_id" ));
184
- }
185
-
186
175
public void testDeployRemoteModel () throws IOException , InterruptedException {
187
176
Response response = createConnector (completionModelConnectorEntity );
188
177
Map responseMap = parseResponseToMap (response );
@@ -201,25 +190,18 @@ public void testDeployRemoteModel() throws IOException, InterruptedException {
201
190
waitForTask (taskId , MLTaskState .COMPLETED );
202
191
}
203
192
204
- public void testPredictRemoteModel () throws IOException , InterruptedException {
193
+ public void testPredictWithAutoDeployAndTTL_RemoteModel () throws IOException , InterruptedException {
205
194
// Skip test if key is null
206
195
if (OPENAI_KEY == null ) {
196
+ System .out .println ("OPENAI_KEY is null" );
207
197
return ;
208
198
}
209
199
Response response = createConnector (completionModelConnectorEntity );
210
200
Map responseMap = parseResponseToMap (response );
211
201
String connectorId = (String ) responseMap .get ("connector_id" );
212
- response = registerRemoteModel ("openAI-GPT-3.5 completions" , connectorId );
213
- responseMap = parseResponseToMap (response );
214
- String taskId = (String ) responseMap .get ("task_id" );
215
- waitForTask (taskId , MLTaskState .COMPLETED );
216
- response = getTask (taskId );
202
+ response = registerRemoteModelWithTTL ("openAI-GPT-3.5 completions" , connectorId , 1 );
217
203
responseMap = parseResponseToMap (response );
218
204
String modelId = (String ) responseMap .get ("model_id" );
219
- response = deployRemoteModel (modelId );
220
- responseMap = parseResponseToMap (response );
221
- taskId = (String ) responseMap .get ("task_id" );
222
- waitForTask (taskId , MLTaskState .COMPLETED );
223
205
String predictInput = "{\n " + " \" parameters\" : {\n " + " \" prompt\" : \" Say this is a test\" \n " + " }\n " + "}" ;
224
206
response = predictRemoteModel (modelId , predictInput );
225
207
responseMap = parseResponseToMap (response );
@@ -235,6 +217,10 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
235
217
}
236
218
responseMap = (Map ) responseList .get (0 );
237
219
assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
220
+
221
+ getModelProfile (modelId , verifyRemoteModelDeployed ());
222
+ TimeUnit .SECONDS .sleep (71 );
223
+ assertTrue (getModelProfile (modelId , verifyRemoteModelDeployed ()).isEmpty ());
238
224
}
239
225
240
226
public void testPredictRemoteModelWithInterface (String testCase , Consumer <Map > verifyResponse , Consumer <Exception > verifyException )
@@ -301,26 +287,6 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
301
287
});
302
288
}
303
289
304
- public void testUndeployRemoteModel () throws IOException , InterruptedException {
305
- Response response = createConnector (completionModelConnectorEntity );
306
- Map responseMap = parseResponseToMap (response );
307
- String connectorId = (String ) responseMap .get ("connector_id" );
308
- response = registerRemoteModel ("openAI-GPT-3.5 completions" , connectorId );
309
- responseMap = parseResponseToMap (response );
310
- String taskId = (String ) responseMap .get ("task_id" );
311
- waitForTask (taskId , MLTaskState .COMPLETED );
312
- response = getTask (taskId );
313
- responseMap = parseResponseToMap (response );
314
- String modelId = (String ) responseMap .get ("model_id" );
315
- response = deployRemoteModel (modelId );
316
- responseMap = parseResponseToMap (response );
317
- taskId = (String ) responseMap .get ("task_id" );
318
- waitForTask (taskId , MLTaskState .COMPLETED );
319
- response = undeployRemoteModel (modelId );
320
- responseMap = parseResponseToMap (response );
321
- assertTrue (responseMap .toString ().contains ("undeployed" ));
322
- }
323
-
324
290
public void testOpenAIChatCompletionModel () throws IOException , InterruptedException {
325
291
// Skip test if key is null
326
292
if (OPENAI_KEY == null ) {
@@ -384,8 +350,13 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
384
350
responseMap = parseResponseToMap (response );
385
351
// TODO handle throttling error
386
352
assertNotNull (responseMap );
353
+
354
+ response = undeployRemoteModel (modelId );
355
+ responseMap = parseResponseToMap (response );
356
+ assertTrue (responseMap .toString ().contains ("undeployed" ));
387
357
}
388
358
359
+ @ Ignore
389
360
public void testOpenAIEditsModel () throws IOException , InterruptedException {
390
361
// Skip test if key is null
391
362
if (OPENAI_KEY == null ) {
@@ -457,6 +428,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
457
428
assertFalse (((String ) responseMap .get ("content" )).isEmpty ());
458
429
}
459
430
431
+ @ Ignore
460
432
public void testOpenAIModerationsModel () throws IOException , InterruptedException {
461
433
// Skip test if key is null
462
434
if (OPENAI_KEY == null ) {
@@ -687,6 +659,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
687
659
assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
688
660
}
689
661
662
+ @ Ignore
690
663
public void testCohereClassifyModel () throws IOException , InterruptedException {
691
664
// Skip test if key is null
692
665
if (COHERE_KEY == null ) {
@@ -841,6 +814,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro
841
814
.makeRequest (client (), "POST" , "/_plugins/_ml/models/_register" , null , TestHelper .toHttpEntity (registerModelEntity ), null );
842
815
}
843
816
817
+ public static Response registerRemoteModelWithTTL (String name , String connectorId , int ttl ) throws IOException {
818
+ String registerModelGroupEntity = "{\n "
819
+ + " \" name\" : \" remote_model_group\" ,\n "
820
+ + " \" description\" : \" This is an example description\" \n "
821
+ + "}" ;
822
+ Response response = TestHelper
823
+ .makeRequest (
824
+ client (),
825
+ "POST" ,
826
+ "/_plugins/_ml/model_groups/_register" ,
827
+ null ,
828
+ TestHelper .toHttpEntity (registerModelGroupEntity ),
829
+ null
830
+ );
831
+ Map responseMap = parseResponseToMap (response );
832
+ assertEquals ((String ) responseMap .get ("status" ), "CREATED" );
833
+ String modelGroupId = (String ) responseMap .get ("model_group_id" );
834
+
835
+ String registerModelEntity = "{\n "
836
+ + " \" name\" : \" "
837
+ + name
838
+ + "\" ,\n "
839
+ + " \" function_name\" : \" remote\" ,\n "
840
+ + " \" model_group_id\" : \" "
841
+ + modelGroupId
842
+ + "\" ,\n "
843
+ + " \" version\" : \" 1.0.0\" ,\n "
844
+ + " \" description\" : \" test model\" ,\n "
845
+ + " \" connector_id\" : \" "
846
+ + connectorId
847
+ + "\" ,\n "
848
+ + " \" deploy_setting\" : "
849
+ + " { \" model_ttl_minutes\" : "
850
+ + ttl
851
+ + "}\n "
852
+ + "}" ;
853
+ return TestHelper
854
+ .makeRequest (client (), "POST" , "/_plugins/_ml/models/_register" , null , TestHelper .toHttpEntity (registerModelEntity ), null );
855
+ }
856
+
844
857
public static Response registerRemoteModelWithInterface (String name , String connectorId , String testCase ) throws IOException {
845
858
String registerModelGroupEntity = "{\n "
846
859
+ " \" name\" : \" remote_model_group\" ,\n "
0 commit comments