Skip to content

Commit f2d083c

Browse files
committed
add agent framework security it tests
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent eeba1c3 commit f2d083c

File tree

2 files changed

+195
-0
lines changed

2 files changed

+195
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+43
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
import org.opensearch.ml.common.AccessMode;
6767
import org.opensearch.ml.common.FunctionName;
6868
import org.opensearch.ml.common.MLTaskState;
69+
import org.opensearch.ml.common.agent.MLAgent;
70+
import org.opensearch.ml.common.agent.MLToolSpec;
6971
import org.opensearch.ml.common.dataset.MLInputDataset;
7072
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
7173
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
@@ -731,6 +733,47 @@ public String registerModel(String input) throws IOException {
731733
return parseTaskIdFromResponse(response);
732734
}
733735

736+
public void registerMLAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
737+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_register", null, input, null);
738+
verifyResponse(function, response);
739+
}
740+
741+
public void executeAgent(RestClient client, String agentId, String input, Consumer<Map<String, Object>> function) throws IOException {
742+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, input, null);
743+
verifyResponse(function, response);
744+
}
745+
746+
public void getAgent(RestClient client, String agentId, Consumer<Map<String, Object>> function) throws IOException {
747+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/agents/" + agentId, null, "", null);
748+
verifyResponse(function, response);
749+
}
750+
751+
public void searchAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
752+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_search", null, input, null);
753+
verifyResponse(function, response);
754+
}
755+
756+
public void deleteAgent(RestClient client, String agentId, Consumer<Map<String, Object>> function) throws IOException {
757+
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", null);
758+
verifyResponse(function, response);
759+
}
760+
761+
public MLAgent createCatIndexToolMLAgent() {
762+
MLToolSpec catIndexTool = MLToolSpec
763+
.builder()
764+
.type("CatIndexTool")
765+
.name("DemoCatIndexTool")
766+
.parameters(Map.of("input", "${parameters.question}"))
767+
.build();
768+
return MLAgent
769+
.builder()
770+
.name("Test_Agent_For_CatIndex_tool")
771+
.type("flow")
772+
.description("this is a test agent for the CatIndexTool")
773+
.tools(List.of(catIndexTool))
774+
.build();
775+
}
776+
734777
public void deployModel(RestClient client, MLRegisterModelInput registerModelInput, Consumer<Map<String, Object>> function)
735778
throws IOException,
736779
InterruptedException {

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

+152
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
import org.opensearch.ml.common.AccessMode;
2727
import org.opensearch.ml.common.FunctionName;
2828
import org.opensearch.ml.common.MLTaskState;
29+
import org.opensearch.ml.common.agent.MLAgent;
30+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
31+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
2932
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
33+
import org.opensearch.ml.common.transport.agent.MLAgentGetRequest;
3034
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
3135
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
3236
import org.opensearch.ml.utils.TestHelper;
@@ -59,6 +63,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
5963

6064
private String modelGroupId;
6165

66+
private MLAgent mlAgent;
67+
6268
/**
6369
* Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk
6470
* @return a random password.
@@ -151,6 +157,8 @@ public void setup() throws IOException {
151157
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id");
152158
});
153159
mlRegisterModelInput = createRegisterModelInput(modelGroupId);
160+
161+
mlAgent = createCatIndexToolMLAgent();
154162
}
155163

156164
@After
@@ -248,6 +256,150 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248256
});
249257
}
250258

259+
public void testExecuteAgentWithFullAccess() throws IOException {
260+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
261+
assertNotNull(registerMLAgentResult);
262+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
263+
String agentId = (String) registerMLAgentResult.get("agent_id");
264+
try {
265+
AgentMLInput agentMLInput = AgentMLInput
266+
.AgentMLInputBuilder()
267+
.agentId(agentId)
268+
.functionName(FunctionName.AGENT)
269+
.inputDataset(
270+
RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()
271+
)
272+
.build();
273+
274+
executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
275+
assertNotNull(mlExecuteTaskResponse);
276+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
277+
});
278+
} catch (IOException e) {
279+
assertNull(e);
280+
}
281+
});
282+
}
283+
284+
public void testExecuteAgentWithReadOnlyAccess() throws IOException {
285+
exceptionRule.expect(ResponseException.class);
286+
exceptionRule.toString();
287+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/execute]");
288+
AgentMLInput agentMLInput = AgentMLInput
289+
.AgentMLInputBuilder()
290+
.agentId("test-agent")
291+
.functionName(FunctionName.AGENT)
292+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build())
293+
.build();
294+
295+
executeAgent(mlReadOnlyClient, "test-agent", TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
296+
assertNotNull(mlExecuteTaskResponse);
297+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
298+
});
299+
}
300+
301+
public void testGetAgentWithFullAccess() throws IOException {
302+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
303+
assertNotNull(registerMLAgentResult);
304+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
305+
String agentId = (String) registerMLAgentResult.get("agent_id");
306+
try {
307+
MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest.builder().agentId(agentId).build();
308+
getAgent(mlFullAccessClient, agentId, mlGetAgentResponse -> {
309+
assertNotNull(mlGetAgentResponse);
310+
assertTrue(mlGetAgentResponse.containsKey("name"));
311+
assertEquals(mlGetAgentResponse.get("name"), "Test_Agent_For_CatIndex_tool");
312+
});
313+
} catch (IOException e) {
314+
assertNull(e);
315+
}
316+
});
317+
}
318+
319+
public void testGetAgentWithNoAccess() throws IOException {
320+
exceptionRule.expect(ResponseException.class);
321+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/get]");
322+
323+
getAgent(mlNoAccessClient, "test-agent", mlExecuteTaskResponse -> {
324+
assertNotNull(mlExecuteTaskResponse);
325+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
326+
});
327+
}
328+
329+
public void testSearchAgentWithFullAccess() throws IOException {
330+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
331+
assertNotNull(registerMLAgentResult);
332+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
333+
try {
334+
searchAgent(
335+
mlFullAccessClient,
336+
"{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}",
337+
mlSearchAgentResponse -> {
338+
assertNotNull(mlSearchAgentResponse);
339+
assertTrue(mlSearchAgentResponse.containsKey("hits"));
340+
}
341+
);
342+
} catch (IOException e) {
343+
assertNull(e);
344+
}
345+
});
346+
}
347+
348+
public void testSearchAgentWithNoAccess() throws IOException {
349+
exceptionRule.expect(ResponseException.class);
350+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/search]");
351+
352+
searchAgent(
353+
mlNoAccessClient,
354+
"{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}",
355+
mlSearchAgentResponse -> {
356+
assertNotNull(mlSearchAgentResponse);
357+
assertTrue(mlSearchAgentResponse.containsKey("hits"));
358+
}
359+
);
360+
}
361+
362+
public void testDeleteAgentWithFullAccess() throws IOException {
363+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
364+
assertNotNull(registerMLAgentResult);
365+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
366+
String agentId = (String) registerMLAgentResult.get("agent_id");
367+
try {
368+
deleteAgent(mlFullAccessClient, agentId, mlSearchAgentResponse -> {
369+
assertNotNull(mlSearchAgentResponse);
370+
assertTrue(mlSearchAgentResponse.containsKey("result"));
371+
assertEquals(mlSearchAgentResponse.get("result"), "deleted");
372+
});
373+
} catch (IOException e) {
374+
assertNull(e);
375+
}
376+
});
377+
}
378+
379+
public void testDeleteAgentWithNoAccess() throws IOException {
380+
exceptionRule.expect(ResponseException.class);
381+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/delete]");
382+
383+
deleteAgent(mlReadOnlyClient, "agentId", mlSearchAgentResponse -> {
384+
assertNotNull(mlSearchAgentResponse);
385+
assertTrue(mlSearchAgentResponse.containsKey("result"));
386+
assertEquals(mlSearchAgentResponse.get("result"), "deleted");
387+
});
388+
}
389+
390+
public void testRegisterAgentWithFullAccess() throws IOException {
391+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
392+
assertNotNull(registerMLAgentResult);
393+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
394+
});
395+
}
396+
397+
public void testRegisterAgentWithReadOnlyMLAccess() throws IOException {
398+
exceptionRule.expect(ResponseException.class);
399+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/register]");
400+
registerMLAgent(mlReadOnlyClient, TestHelper.toJsonString(mlAgent), null);
401+
}
402+
251403
public void testTrainWithReadOnlyMLAccess() throws IOException {
252404
exceptionRule.expect(ResponseException.class);
253405
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");

0 commit comments

Comments
 (0)