18
18
import org .junit .Before ;
19
19
import org .junit .Rule ;
20
20
import org .junit .rules .ExpectedException ;
21
+ import org .opensearch .action .search .SearchRequest ;
21
22
import org .opensearch .client .Response ;
22
23
import org .opensearch .client .ResponseException ;
23
24
import org .opensearch .client .RestClient ;
24
25
import org .opensearch .commons .rest .SecureRestClientBuilder ;
25
26
import org .opensearch .index .query .MatchAllQueryBuilder ;
27
+ import org .opensearch .index .query .QueryBuilders ;
26
28
import org .opensearch .ml .common .AccessMode ;
27
29
import org .opensearch .ml .common .FunctionName ;
28
30
import org .opensearch .ml .common .MLTaskState ;
31
+ import org .opensearch .ml .common .agent .MLAgent ;
32
+ import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
33
+ import org .opensearch .ml .common .input .execute .agent .AgentMLInput ;
29
34
import org .opensearch .ml .common .input .parameter .clustering .KMeansParams ;
35
+ import org .opensearch .ml .common .transport .agent .MLAgentDeleteRequest ;
36
+ import org .opensearch .ml .common .transport .agent .MLAgentGetRequest ;
37
+ import org .opensearch .ml .common .transport .agent .MLRegisterAgentRequest ;
30
38
import org .opensearch .ml .common .transport .model_group .MLRegisterModelGroupInput ;
31
39
import org .opensearch .ml .common .transport .register .MLRegisterModelInput ;
32
40
import org .opensearch .ml .utils .TestHelper ;
35
43
import com .google .common .base .Throwables ;
36
44
import com .google .common .collect .ImmutableList ;
37
45
46
+ import static org .opensearch .ml .common .CommonValue .ML_AGENT_INDEX ;
47
+
38
48
public class SecureMLRestIT extends MLCommonsRestTestCase {
39
49
private String irisIndex = "iris_data_secure_ml_it" ;
40
50
@@ -59,6 +69,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
59
69
60
70
private String modelGroupId ;
61
71
72
+ private MLAgent mlAgent ;
73
+
62
74
/**
63
75
* Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk
64
76
* @return a random password.
@@ -151,6 +163,8 @@ public void setup() throws IOException {
151
163
this .modelGroupId = (String ) registerModelGroupResult .get ("model_group_id" );
152
164
});
153
165
mlRegisterModelInput = createRegisterModelInput (modelGroupId );
166
+
167
+ mlAgent = createCatIndexToolMLAgent ();
154
168
}
155
169
156
170
@ After
@@ -248,6 +262,151 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248
262
});
249
263
}
250
264
265
+ public void testExecuteAgentWithFullAccess () throws IOException {
266
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
267
+ assertNotNull (registerMLAgentResult );
268
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
269
+ String agentId = (String ) registerMLAgentResult .get ("agent_id" );
270
+ try {
271
+ AgentMLInput agentMLInput = AgentMLInput
272
+ .AgentMLInputBuilder ()
273
+ .agentId (agentId )
274
+ .functionName (FunctionName .AGENT )
275
+ .inputDataset (
276
+ RemoteInferenceInputDataSet .builder ().parameters (Map .of ("question" , "How many indices do I have?" )).build ()
277
+ )
278
+ .build ();
279
+
280
+ executeAgent (mlFullAccessClient , agentId , TestHelper .toJsonString (agentMLInput ), mlExecuteTaskResponse -> {
281
+ assertNotNull (mlExecuteTaskResponse );
282
+ assertTrue (mlExecuteTaskResponse .containsKey ("inference_results" ));
283
+ });
284
+ } catch (IOException e ) {
285
+ assertNull (e );
286
+ }
287
+ });
288
+ }
289
+
290
+ public void testExecuteAgentWithReadOnlyAccess () throws IOException {
291
+ exceptionRule .expect (RuntimeException .class );
292
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/execute]" );
293
+ AgentMLInput agentMLInput = AgentMLInput
294
+ .AgentMLInputBuilder ()
295
+ .agentId ("test-agent" )
296
+ .functionName (FunctionName .AGENT )
297
+ .inputDataset (
298
+ RemoteInferenceInputDataSet .builder ().parameters (Map .of ("question" , "How many indices do I have?" )).build ()
299
+ )
300
+ .build ();
301
+
302
+ executeAgent (mlReadOnlyClient , "test-agent" , TestHelper .toJsonString (agentMLInput ), mlExecuteTaskResponse -> {
303
+ assertNotNull (mlExecuteTaskResponse );
304
+ assertTrue (mlExecuteTaskResponse .containsKey ("inference_results" ));
305
+ });
306
+ }
307
+
308
+ public void testGetAgentWithFullAccess () throws IOException {
309
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
310
+ assertNotNull (registerMLAgentResult );
311
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
312
+ String agentId = (String ) registerMLAgentResult .get ("agent_id" );
313
+ try {
314
+ MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest .builder ().agentId (agentId ).build ();
315
+ getAgent (mlFullAccessClient , agentId , mlGetAgentResponse -> {
316
+ assertNotNull (mlGetAgentResponse );
317
+ assertTrue (mlGetAgentResponse .containsKey ("name" ));
318
+ assertEquals (mlGetAgentResponse .get ("name" ), "Test_Agent_For_CatIndex_tool" );
319
+ });
320
+ } catch (IOException e ) {
321
+ assertNull (e );
322
+ }
323
+ });
324
+ }
325
+
326
+ public void testGetAgentWithNoAccess () throws IOException {
327
+ exceptionRule .expect (RuntimeException .class );
328
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/get]" );
329
+
330
+ getAgent (mlNoAccessClient , "test-agent" , mlExecuteTaskResponse -> {
331
+ assertNotNull (mlExecuteTaskResponse );
332
+ assertTrue (mlExecuteTaskResponse .containsKey ("inference_results" ));
333
+ });
334
+ }
335
+
336
+ public void testSearchAgentWithFullAccess () throws IOException {
337
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
338
+ assertNotNull (registerMLAgentResult );
339
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
340
+ try {
341
+ SearchRequest searchRequest = new SearchRequest (ML_AGENT_INDEX );
342
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
343
+ searchSourceBuilder .query (QueryBuilders .matchAllQuery ());
344
+ searchRequest .source (searchSourceBuilder );
345
+ searchAgent (mlFullAccessClient , gson .toJson (searchRequest ), mlSearchAgentResponse -> {
346
+ assertNotNull (mlSearchAgentResponse );
347
+ assertTrue (mlSearchAgentResponse .containsKey ("hits" ));
348
+ });
349
+ } catch (IOException e ) {
350
+ assertNull (e );
351
+ }
352
+ });
353
+ }
354
+
355
+ public void testSearchAgentWithNoAccess () throws IOException {
356
+ exceptionRule .expect (RuntimeException .class );
357
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/search]" );
358
+
359
+ SearchRequest searchRequest = new SearchRequest (ML_AGENT_INDEX );
360
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
361
+ searchSourceBuilder .query (QueryBuilders .matchAllQuery ());
362
+ searchRequest .source (searchSourceBuilder );
363
+ searchAgent (mlNoAccessClient , gson .toJson (searchRequest ), mlSearchAgentResponse -> {
364
+ assertNotNull (mlSearchAgentResponse );
365
+ assertTrue (mlSearchAgentResponse .containsKey ("hits" ));
366
+ });
367
+ }
368
+
369
+ public void testDeleteAgentWithFullAccess () throws IOException {
370
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
371
+ assertNotNull (registerMLAgentResult );
372
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
373
+ String agentId = (String ) registerMLAgentResult .get ("agent_id" );
374
+ try {
375
+ deleteAgent (mlFullAccessClient , agentId , mlSearchAgentResponse -> {
376
+ assertNotNull (mlSearchAgentResponse );
377
+ assertTrue (mlSearchAgentResponse .containsKey ("result" ));
378
+ assertEquals (mlSearchAgentResponse .get ("result" ), "deleted" );
379
+ });
380
+ } catch (IOException e ) {
381
+ assertNull (e );
382
+ }
383
+ });
384
+ }
385
+
386
+ public void testDeleteAgentWithNoAccess () throws IOException {
387
+ exceptionRule .expect (RuntimeException .class );
388
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/delete]" );
389
+
390
+ deleteAgent (mlReadOnlyClient , "agentId" , mlSearchAgentResponse -> {
391
+ assertNotNull (mlSearchAgentResponse );
392
+ assertTrue (mlSearchAgentResponse .containsKey ("result" ));
393
+ assertEquals (mlSearchAgentResponse .get ("result" ), "deleted" );
394
+ });
395
+ }
396
+
397
+ public void testRegisterAgentWithFullAccess () throws IOException {
398
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
399
+ assertNotNull (registerMLAgentResult );
400
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
401
+ });
402
+ }
403
+
404
+ public void testRegisterAgentWithReadOnlyMLAccess () throws IOException {
405
+ exceptionRule .expect (ResponseException .class );
406
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/register]" );
407
+ registerMLAgent (mlReadOnlyClient , TestHelper .toJsonString (mlAgent ), null );
408
+ }
409
+
251
410
public void testTrainWithReadOnlyMLAccess () throws IOException {
252
411
exceptionRule .expect (ResponseException .class );
253
412
exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/train]" );
0 commit comments