Skip to content

Commit aedc9b3

Browse files
committed
add agent framework security it tests
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent eeba1c3 commit aedc9b3

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+45
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,14 @@
6666
import org.opensearch.ml.common.AccessMode;
6767
import org.opensearch.ml.common.FunctionName;
6868
import org.opensearch.ml.common.MLTaskState;
69+
import org.opensearch.ml.common.agent.MLAgent;
70+
import org.opensearch.ml.common.agent.MLToolSpec;
6971
import org.opensearch.ml.common.dataset.MLInputDataset;
7072
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
7173
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
74+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
7275
import org.opensearch.ml.common.input.MLInput;
76+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
7377
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
7478
import org.opensearch.ml.common.model.MLModelConfig;
7579
import org.opensearch.ml.common.model.MLModelFormat;
@@ -731,6 +735,47 @@ public String registerModel(String input) throws IOException {
731735
return parseTaskIdFromResponse(response);
732736
}
733737

738+
public void registerMLAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
739+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_register", null, input, null);
740+
verifyResponse(function, response);
741+
}
742+
743+
public void executeAgent(RestClient client, String agentId, String input, Consumer<Map<String, Object>> function) throws IOException {
744+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, input, null);
745+
verifyResponse(function, response);
746+
}
747+
748+
public void getAgent(RestClient client, String agentId, Consumer<Map<String, Object>> function) throws IOException {
749+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/agents/" + agentId, null, "", null);
750+
verifyResponse(function, response);
751+
}
752+
753+
public void searchAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
754+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_search", null, input, null);
755+
verifyResponse(function, response);
756+
}
757+
758+
public void deleteAgent(RestClient client, String agentId, Consumer<Map<String, Object>> function) throws IOException {
759+
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", null);
760+
verifyResponse(function, response);
761+
}
762+
763+
public MLAgent createCatIndexToolMLAgent() {
764+
MLToolSpec catIndexTool = MLToolSpec
765+
.builder()
766+
.type("CatIndexTool")
767+
.name("DemoCatIndexTool")
768+
.parameters(Map.of("input", "${parameters.question}"))
769+
.build();
770+
return MLAgent
771+
.builder()
772+
.name("Test_Agent_For_CatIndex_tool")
773+
.type("flow")
774+
.description("this is a test agent for the CatIndexTool")
775+
.tools(List.of(catIndexTool))
776+
.build();
777+
}
778+
734779
public void deployModel(RestClient client, MLRegisterModelInput registerModelInput, Consumer<Map<String, Object>> function)
735780
throws IOException,
736781
InterruptedException {

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

+160
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,23 @@
1818
import org.junit.Before;
1919
import org.junit.Rule;
2020
import org.junit.rules.ExpectedException;
21+
import org.opensearch.action.search.SearchRequest;
2122
import org.opensearch.client.Response;
2223
import org.opensearch.client.ResponseException;
2324
import org.opensearch.client.RestClient;
2425
import org.opensearch.commons.rest.SecureRestClientBuilder;
2526
import org.opensearch.index.query.MatchAllQueryBuilder;
27+
import org.opensearch.index.query.QueryBuilders;
2628
import org.opensearch.ml.common.AccessMode;
2729
import org.opensearch.ml.common.FunctionName;
2830
import org.opensearch.ml.common.MLTaskState;
31+
import org.opensearch.ml.common.agent.MLAgent;
32+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
33+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
2934
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
35+
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
36+
import org.opensearch.ml.common.transport.agent.MLAgentGetRequest;
37+
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
3038
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
3139
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
3240
import org.opensearch.ml.utils.TestHelper;
@@ -35,6 +43,8 @@
3543
import com.google.common.base.Throwables;
3644
import com.google.common.collect.ImmutableList;
3745

46+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
47+
3848
public class SecureMLRestIT extends MLCommonsRestTestCase {
3949
private String irisIndex = "iris_data_secure_ml_it";
4050

@@ -59,6 +69,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
5969

6070
private String modelGroupId;
6171

72+
private MLAgent mlAgent;
73+
6274
/**
6375
* Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk
6476
* @return a random password.
@@ -151,6 +163,8 @@ public void setup() throws IOException {
151163
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id");
152164
});
153165
mlRegisterModelInput = createRegisterModelInput(modelGroupId);
166+
167+
mlAgent = createCatIndexToolMLAgent();
154168
}
155169

156170
@After
@@ -248,6 +262,152 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248262
});
249263
}
250264

265+
public void testExecuteAgentWithFullAccess() throws IOException {
266+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
267+
assertNotNull(registerMLAgentResult);
268+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
269+
String agentId = (String) registerMLAgentResult.get("agent_id");
270+
try {
271+
AgentMLInput agentMLInput = AgentMLInput
272+
.AgentMLInputBuilder()
273+
.agentId(agentId)
274+
.functionName(FunctionName.AGENT)
275+
.inputDataset(
276+
RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()
277+
)
278+
.build();
279+
280+
executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
281+
assertNotNull(mlExecuteTaskResponse);
282+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
283+
});
284+
} catch (IOException e) {
285+
assertNull(e);
286+
}
287+
});
288+
}
289+
290+
public void testExecuteAgentWithReadOnlyAccess() throws IOException {
291+
exceptionRule.expect(ResponseException.class);
292+
exceptionRule.toString();
293+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/execute]");
294+
AgentMLInput agentMLInput = AgentMLInput
295+
.AgentMLInputBuilder()
296+
.agentId("test-agent")
297+
.functionName(FunctionName.AGENT)
298+
.inputDataset(
299+
RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()
300+
)
301+
.build();
302+
303+
executeAgent(mlReadOnlyClient, "test-agent", TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
304+
assertNotNull(mlExecuteTaskResponse);
305+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
306+
});
307+
}
308+
309+
public void testGetAgentWithFullAccess() throws IOException {
310+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
311+
assertNotNull(registerMLAgentResult);
312+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
313+
String agentId = (String) registerMLAgentResult.get("agent_id");
314+
try {
315+
MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest.builder().agentId(agentId).build();
316+
getAgent(mlFullAccessClient, agentId, mlGetAgentResponse -> {
317+
assertNotNull(mlGetAgentResponse);
318+
assertTrue(mlGetAgentResponse.containsKey("name"));
319+
assertEquals(mlGetAgentResponse.get("name"), "Test_Agent_For_CatIndex_tool");
320+
});
321+
} catch (IOException e) {
322+
assertNull(e);
323+
}
324+
});
325+
}
326+
327+
public void testGetAgentWithNoAccess() throws IOException {
328+
exceptionRule.expect(ResponseException.class);
329+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/get]");
330+
331+
getAgent(mlNoAccessClient, "test-agent", mlExecuteTaskResponse -> {
332+
assertNotNull(mlExecuteTaskResponse);
333+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
334+
});
335+
}
336+
337+
public void testSearchAgentWithFullAccess() throws IOException {
338+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
339+
assertNotNull(registerMLAgentResult);
340+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
341+
try {
342+
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
343+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
344+
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
345+
searchRequest.source(searchSourceBuilder);
346+
searchAgent(mlFullAccessClient, gson.toJson(searchRequest), mlSearchAgentResponse -> {
347+
assertNotNull(mlSearchAgentResponse);
348+
assertTrue(mlSearchAgentResponse.containsKey("hits"));
349+
});
350+
} catch (IOException e) {
351+
assertNull(e);
352+
}
353+
});
354+
}
355+
356+
public void testSearchAgentWithNoAccess() throws IOException {
357+
exceptionRule.expect(ResponseException.class);
358+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/search]");
359+
360+
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
361+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
362+
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
363+
searchRequest.source(searchSourceBuilder);
364+
searchAgent(mlNoAccessClient, gson.toJson(searchRequest), mlSearchAgentResponse -> {
365+
assertNotNull(mlSearchAgentResponse);
366+
assertTrue(mlSearchAgentResponse.containsKey("hits"));
367+
});
368+
}
369+
370+
public void testDeleteAgentWithFullAccess() throws IOException {
371+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
372+
assertNotNull(registerMLAgentResult);
373+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
374+
String agentId = (String) registerMLAgentResult.get("agent_id");
375+
try {
376+
deleteAgent(mlFullAccessClient, agentId, mlSearchAgentResponse -> {
377+
assertNotNull(mlSearchAgentResponse);
378+
assertTrue(mlSearchAgentResponse.containsKey("result"));
379+
assertEquals(mlSearchAgentResponse.get("result"), "deleted");
380+
});
381+
} catch (IOException e) {
382+
assertNull(e);
383+
}
384+
});
385+
}
386+
387+
public void testDeleteAgentWithNoAccess() throws IOException {
388+
exceptionRule.expect(ResponseException.class);
389+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/delete]");
390+
391+
deleteAgent(mlReadOnlyClient, "agentId", mlSearchAgentResponse -> {
392+
assertNotNull(mlSearchAgentResponse);
393+
assertTrue(mlSearchAgentResponse.containsKey("result"));
394+
assertEquals(mlSearchAgentResponse.get("result"), "deleted");
395+
});
396+
}
397+
398+
public void testRegisterAgentWithFullAccess() throws IOException {
399+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
400+
assertNotNull(registerMLAgentResult);
401+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
402+
});
403+
}
404+
405+
public void testRegisterAgentWithReadOnlyMLAccess() throws IOException {
406+
exceptionRule.expect(ResponseException.class);
407+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/register]");
408+
registerMLAgent(mlReadOnlyClient, TestHelper.toJsonString(mlAgent), null);
409+
}
410+
251411
public void testTrainWithReadOnlyMLAccess() throws IOException {
252412
exceptionRule.expect(ResponseException.class);
253413
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");

0 commit comments

Comments
 (0)