18
18
import org .junit .Before ;
19
19
import org .junit .Rule ;
20
20
import org .junit .rules .ExpectedException ;
21
+ import org .opensearch .action .search .SearchRequest ;
21
22
import org .opensearch .client .Response ;
22
23
import org .opensearch .client .ResponseException ;
23
24
import org .opensearch .client .RestClient ;
24
25
import org .opensearch .commons .rest .SecureRestClientBuilder ;
25
26
import org .opensearch .index .query .MatchAllQueryBuilder ;
27
+ import org .opensearch .index .query .QueryBuilders ;
26
28
import org .opensearch .ml .common .AccessMode ;
27
29
import org .opensearch .ml .common .FunctionName ;
28
30
import org .opensearch .ml .common .MLTaskState ;
31
+ import org .opensearch .ml .common .agent .MLAgent ;
32
+ import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
33
+ import org .opensearch .ml .common .input .execute .agent .AgentMLInput ;
29
34
import org .opensearch .ml .common .input .parameter .clustering .KMeansParams ;
35
+ import org .opensearch .ml .common .transport .agent .MLAgentDeleteRequest ;
36
+ import org .opensearch .ml .common .transport .agent .MLAgentGetRequest ;
37
+ import org .opensearch .ml .common .transport .agent .MLRegisterAgentRequest ;
30
38
import org .opensearch .ml .common .transport .model_group .MLRegisterModelGroupInput ;
31
39
import org .opensearch .ml .common .transport .register .MLRegisterModelInput ;
32
40
import org .opensearch .ml .utils .TestHelper ;
35
43
import com .google .common .base .Throwables ;
36
44
import com .google .common .collect .ImmutableList ;
37
45
46
+ import static org .opensearch .ml .common .CommonValue .ML_AGENT_INDEX ;
47
+
38
48
public class SecureMLRestIT extends MLCommonsRestTestCase {
39
49
private String irisIndex = "iris_data_secure_ml_it" ;
40
50
@@ -59,6 +69,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
59
69
60
70
private String modelGroupId ;
61
71
72
+ private MLAgent mlAgent ;
73
+
62
74
/**
63
75
* Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk
64
76
* @return a random password.
@@ -151,6 +163,8 @@ public void setup() throws IOException {
151
163
this .modelGroupId = (String ) registerModelGroupResult .get ("model_group_id" );
152
164
});
153
165
mlRegisterModelInput = createRegisterModelInput (modelGroupId );
166
+
167
+ mlAgent = createCatIndexToolMLAgent ();
154
168
}
155
169
156
170
@ After
@@ -248,6 +262,152 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248
262
});
249
263
}
250
264
265
+ public void testExecuteAgentWithFullAccess () throws IOException {
266
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
267
+ assertNotNull (registerMLAgentResult );
268
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
269
+ String agentId = (String ) registerMLAgentResult .get ("agent_id" );
270
+ try {
271
+ AgentMLInput agentMLInput = AgentMLInput
272
+ .AgentMLInputBuilder ()
273
+ .agentId (agentId )
274
+ .functionName (FunctionName .AGENT )
275
+ .inputDataset (
276
+ RemoteInferenceInputDataSet .builder ().parameters (Map .of ("question" , "How many indices do I have?" )).build ()
277
+ )
278
+ .build ();
279
+
280
+ executeAgent (mlFullAccessClient , agentId , TestHelper .toJsonString (agentMLInput ), mlExecuteTaskResponse -> {
281
+ assertNotNull (mlExecuteTaskResponse );
282
+ assertTrue (mlExecuteTaskResponse .containsKey ("inference_results" ));
283
+ });
284
+ } catch (IOException e ) {
285
+ assertNull (e );
286
+ }
287
+ });
288
+ }
289
+
290
+ public void testExecuteAgentWithReadOnlyAccess () throws IOException {
291
+ exceptionRule .expect (ResponseException .class );
292
+ exceptionRule .toString ();
293
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/execute]" );
294
+ AgentMLInput agentMLInput = AgentMLInput
295
+ .AgentMLInputBuilder ()
296
+ .agentId ("test-agent" )
297
+ .functionName (FunctionName .AGENT )
298
+ .inputDataset (
299
+ RemoteInferenceInputDataSet .builder ().parameters (Map .of ("question" , "How many indices do I have?" )).build ()
300
+ )
301
+ .build ();
302
+
303
+ executeAgent (mlReadOnlyClient , "test-agent" , TestHelper .toJsonString (agentMLInput ), mlExecuteTaskResponse -> {
304
+ assertNotNull (mlExecuteTaskResponse );
305
+ assertTrue (mlExecuteTaskResponse .containsKey ("inference_results" ));
306
+ });
307
+ }
308
+
309
+ public void testGetAgentWithFullAccess () throws IOException {
310
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
311
+ assertNotNull (registerMLAgentResult );
312
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
313
+ String agentId = (String ) registerMLAgentResult .get ("agent_id" );
314
+ try {
315
+ MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest .builder ().agentId (agentId ).build ();
316
+ getAgent (mlFullAccessClient , agentId , mlGetAgentResponse -> {
317
+ assertNotNull (mlGetAgentResponse );
318
+ assertTrue (mlGetAgentResponse .containsKey ("name" ));
319
+ assertEquals (mlGetAgentResponse .get ("name" ), "Test_Agent_For_CatIndex_tool" );
320
+ });
321
+ } catch (IOException e ) {
322
+ assertNull (e );
323
+ }
324
+ });
325
+ }
326
+
327
+ public void testGetAgentWithNoAccess () throws IOException {
328
+ exceptionRule .expect (ResponseException .class );
329
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/get]" );
330
+
331
+ getAgent (mlNoAccessClient , "test-agent" , mlExecuteTaskResponse -> {
332
+ assertNotNull (mlExecuteTaskResponse );
333
+ assertTrue (mlExecuteTaskResponse .containsKey ("inference_results" ));
334
+ });
335
+ }
336
+
337
+ public void testSearchAgentWithFullAccess () throws IOException {
338
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
339
+ assertNotNull (registerMLAgentResult );
340
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
341
+ try {
342
+ SearchRequest searchRequest = new SearchRequest (ML_AGENT_INDEX );
343
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
344
+ searchSourceBuilder .query (QueryBuilders .matchAllQuery ());
345
+ searchRequest .source (searchSourceBuilder );
346
+ searchAgent (mlFullAccessClient , gson .toJson (searchRequest ), mlSearchAgentResponse -> {
347
+ assertNotNull (mlSearchAgentResponse );
348
+ assertTrue (mlSearchAgentResponse .containsKey ("hits" ));
349
+ });
350
+ } catch (IOException e ) {
351
+ assertNull (e );
352
+ }
353
+ });
354
+ }
355
+
356
+ public void testSearchAgentWithNoAccess () throws IOException {
357
+ exceptionRule .expect (ResponseException .class );
358
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/search]" );
359
+
360
+ SearchRequest searchRequest = new SearchRequest (ML_AGENT_INDEX );
361
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
362
+ searchSourceBuilder .query (QueryBuilders .matchAllQuery ());
363
+ searchRequest .source (searchSourceBuilder );
364
+ searchAgent (mlNoAccessClient , gson .toJson (searchRequest ), mlSearchAgentResponse -> {
365
+ assertNotNull (mlSearchAgentResponse );
366
+ assertTrue (mlSearchAgentResponse .containsKey ("hits" ));
367
+ });
368
+ }
369
+
370
+ public void testDeleteAgentWithFullAccess () throws IOException {
371
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
372
+ assertNotNull (registerMLAgentResult );
373
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
374
+ String agentId = (String ) registerMLAgentResult .get ("agent_id" );
375
+ try {
376
+ deleteAgent (mlFullAccessClient , agentId , mlSearchAgentResponse -> {
377
+ assertNotNull (mlSearchAgentResponse );
378
+ assertTrue (mlSearchAgentResponse .containsKey ("result" ));
379
+ assertEquals (mlSearchAgentResponse .get ("result" ), "deleted" );
380
+ });
381
+ } catch (IOException e ) {
382
+ assertNull (e );
383
+ }
384
+ });
385
+ }
386
+
387
+ public void testDeleteAgentWithNoAccess () throws IOException {
388
+ exceptionRule .expect (ResponseException .class );
389
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/delete]" );
390
+
391
+ deleteAgent (mlReadOnlyClient , "agentId" , mlSearchAgentResponse -> {
392
+ assertNotNull (mlSearchAgentResponse );
393
+ assertTrue (mlSearchAgentResponse .containsKey ("result" ));
394
+ assertEquals (mlSearchAgentResponse .get ("result" ), "deleted" );
395
+ });
396
+ }
397
+
398
+ public void testRegisterAgentWithFullAccess () throws IOException {
399
+ registerMLAgent (mlFullAccessClient , TestHelper .toJsonString (mlAgent ), registerMLAgentResult -> {
400
+ assertNotNull (registerMLAgentResult );
401
+ assertTrue (registerMLAgentResult .containsKey ("agent_id" ));
402
+ });
403
+ }
404
+
405
+ public void testRegisterAgentWithReadOnlyMLAccess () throws IOException {
406
+ exceptionRule .expect (ResponseException .class );
407
+ exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/agents/register]" );
408
+ registerMLAgent (mlReadOnlyClient , TestHelper .toJsonString (mlAgent ), null );
409
+ }
410
+
251
411
public void testTrainWithReadOnlyMLAccess () throws IOException {
252
412
exceptionRule .expect (ResponseException .class );
253
413
exceptionRule .expectMessage ("no permissions for [cluster:admin/opensearch/ml/train]" );
0 commit comments