6
6
package org .opensearch .ml .engine .algorithms .remote ;
7
7
8
8
import static org .junit .Assert .assertEquals ;
9
+ import static org .mockito .ArgumentMatchers .any ;
9
10
import static org .mockito .Mockito .spy ;
10
11
import static org .mockito .Mockito .times ;
11
12
import static org .mockito .Mockito .when ;
30
31
import org .opensearch .common .settings .Settings ;
31
32
import org .opensearch .common .util .concurrent .ThreadContext ;
32
33
import org .opensearch .core .action .ActionListener ;
34
+ import org .opensearch .ingest .TestTemplateService ;
33
35
import org .opensearch .ml .common .FunctionName ;
34
36
import org .opensearch .ml .common .connector .AwsConnector ;
35
37
import org .opensearch .ml .common .connector .Connector ;
42
44
import org .opensearch .ml .common .transport .MLTaskResponse ;
43
45
import org .opensearch .ml .engine .encryptor .Encryptor ;
44
46
import org .opensearch .ml .engine .encryptor .EncryptorImpl ;
47
+ import org .opensearch .script .ScriptService ;
45
48
import org .opensearch .threadpool .ThreadPool ;
46
49
47
50
import com .google .common .collect .ImmutableList ;
@@ -67,10 +70,15 @@ public class AwsConnectorExecutorTest {
67
70
68
71
Encryptor encryptor ;
69
72
73
+ @ Mock
74
+ private ScriptService scriptService ;
75
+
70
76
@ Before
71
77
public void setUp () {
72
78
MockitoAnnotations .openMocks (this );
73
79
encryptor = new EncryptorImpl ("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=" );
80
+ when (scriptService .compile (any (), any ()))
81
+ .then (invocation -> new TestTemplateService .MockTemplateScript .Factory ("{\" result\" : \" hello world\" }" ));
74
82
}
75
83
76
84
@ Test
@@ -282,4 +290,80 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg
282
290
Mockito .verify (actionListener , times (1 )).onFailure (exceptionCaptor .capture ());
283
291
assert exceptionCaptor .getValue () instanceof IllegalArgumentException ;
284
292
}
293
+
294
+ @ Test
295
+ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction () {
296
+ ConnectorAction predictAction = ConnectorAction
297
+ .builder ()
298
+ .actionType (ConnectorAction .ActionType .PREDICT )
299
+ .method ("POST" )
300
+ .url ("http://openai.com/mock" )
301
+ .requestBody ("{\" input\" : ${parameters.input}}" )
302
+ .preProcessFunction (MLPreProcessFunction .TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT )
303
+ .build ();
304
+ Map <String , String > credential = ImmutableMap
305
+ .of (ACCESS_KEY_FIELD , encryptor .encrypt ("test_key" ), SECRET_KEY_FIELD , encryptor .encrypt ("test_secret_key" ));
306
+ Map <String , String > parameters = ImmutableMap .of (REGION_FIELD , "us-west-2" , SERVICE_NAME_FIELD , "sagemaker" );
307
+ Connector connector = AwsConnector
308
+ .awsConnectorBuilder ()
309
+ .name ("test connector" )
310
+ .version ("1" )
311
+ .protocol ("http" )
312
+ .parameters (parameters )
313
+ .credential (credential )
314
+ .build ();
315
+ connector .decrypt ((c ) -> encryptor .decrypt (c ));
316
+ AwsConnectorExecutor executor = spy (new AwsConnectorExecutor (connector ));
317
+ Settings settings = Settings .builder ().build ();
318
+ threadContext = new ThreadContext (settings );
319
+ when (executor .getClient ()).thenReturn (client );
320
+ when (client .threadPool ()).thenReturn (threadPool );
321
+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
322
+
323
+ MLInputDataset inputDataSet = TextDocsInputDataSet .builder ().docs (ImmutableList .of ("input1" , "input2" , "input3" )).build ();
324
+ executor
325
+ .executePredict (MLInput .builder ().algorithm (FunctionName .TEXT_EMBEDDING ).inputDataset (inputDataSet ).build (), actionListener );
326
+ ArgumentCaptor <Exception > exceptionArgumentCaptor = ArgumentCaptor .forClass (Exception .class );
327
+ Mockito .verify (actionListener , times (1 )).onFailure (exceptionArgumentCaptor .capture ());
328
+ assert exceptionArgumentCaptor .getValue () instanceof IllegalArgumentException ;
329
+ assert "no predict action found" .equals (exceptionArgumentCaptor .getValue ().getMessage ());
330
+ }
331
+
332
+ @ Test
333
+ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction () {
334
+ ConnectorAction predictAction = ConnectorAction
335
+ .builder ()
336
+ .actionType (ConnectorAction .ActionType .PREDICT )
337
+ .method ("POST" )
338
+ .url ("http://openai.com/mock" )
339
+ .requestBody ("{\" input\" : ${parameters.input}}" )
340
+ .preProcessFunction (
341
+ "\n StringBuilder builder = new StringBuilder();\n builder.append(\" \\ \" \" );\n String first = params.text_docs[0];\n builder.append(first);\n builder.append(\" \\ \" \" );\n def parameters = \" {\" +\" \\ \" text_inputs\\ \" :\" + builder + \" }\" ;\n return \" {\" +\" \\ \" parameters\\ \" :\" + parameters + \" }\" ;"
342
+ )
343
+ .build ();
344
+ Map <String , String > credential = ImmutableMap
345
+ .of (ACCESS_KEY_FIELD , encryptor .encrypt ("test_key" ), SECRET_KEY_FIELD , encryptor .encrypt ("test_secret_key" ));
346
+ Map <String , String > parameters = ImmutableMap .of (REGION_FIELD , "us-west-2" , SERVICE_NAME_FIELD , "sagemaker" );
347
+ Connector connector = AwsConnector
348
+ .awsConnectorBuilder ()
349
+ .name ("test connector" )
350
+ .version ("1" )
351
+ .protocol ("http" )
352
+ .parameters (parameters )
353
+ .credential (credential )
354
+ .actions (Arrays .asList (predictAction ))
355
+ .build ();
356
+ connector .decrypt ((c ) -> encryptor .decrypt (c ));
357
+ AwsConnectorExecutor executor = spy (new AwsConnectorExecutor (connector ));
358
+ Settings settings = Settings .builder ().build ();
359
+ threadContext = new ThreadContext (settings );
360
+ when (executor .getClient ()).thenReturn (client );
361
+ when (client .threadPool ()).thenReturn (threadPool );
362
+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
363
+ when (executor .getScriptService ()).thenReturn (scriptService );
364
+
365
+ MLInputDataset inputDataSet = TextDocsInputDataSet .builder ().docs (ImmutableList .of ("input1" , "input2" , "input3" )).build ();
366
+ executor
367
+ .executePredict (MLInput .builder ().algorithm (FunctionName .TEXT_EMBEDDING ).inputDataset (inputDataSet ).build (), actionListener );
368
+ }
285
369
}
0 commit comments