Skip to content

Commit c1bf9c5

Browse files
committed
add agent framework security it tests
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent eeba1c3 commit c1bf9c5

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+28
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
import org.opensearch.ml.common.AccessMode;
6767
import org.opensearch.ml.common.FunctionName;
6868
import org.opensearch.ml.common.MLTaskState;
69+
import org.opensearch.ml.common.agent.MLAgent;
70+
import org.opensearch.ml.common.agent.MLToolSpec;
6971
import org.opensearch.ml.common.dataset.MLInputDataset;
7072
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
7173
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
@@ -731,6 +733,32 @@ public String registerModel(String input) throws IOException {
731733
return parseTaskIdFromResponse(response);
732734
}
733735

736+
public void registerMLAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
737+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_register", null, input, null);
738+
verifyResponse(function, response);
739+
}
740+
741+
public void executeAgent(RestClient client, String agentId, String input, Consumer<Map<String, Object>> function) throws IOException {
742+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, input, null);
743+
verifyResponse(function, response);
744+
}
745+
746+
public MLAgent createCatIndexToolMLAgent() {
747+
MLToolSpec catIndexTool = MLToolSpec
748+
.builder()
749+
.type("CatIndexTool")
750+
.name("DemoCatIndexTool")
751+
.parameters(Map.of("input", "${parameters.question}"))
752+
.build();
753+
return MLAgent
754+
.builder()
755+
.name("Test_Agent_For_CatIndex_tool")
756+
.type("flow")
757+
.description("this is a test agent for the CatIndexTool")
758+
.tools(List.of(catIndexTool))
759+
.build();
760+
}
761+
734762
public void deployModel(RestClient client, MLRegisterModelInput registerModelInput, Consumer<Map<String, Object>> function)
735763
throws IOException,
736764
InterruptedException {

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

+37
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
import org.opensearch.ml.common.AccessMode;
2727
import org.opensearch.ml.common.FunctionName;
2828
import org.opensearch.ml.common.MLTaskState;
29+
import org.opensearch.ml.common.agent.MLAgent;
30+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
31+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
2932
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
3033
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
3134
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
@@ -248,6 +251,40 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248251
});
249252
}
250253

254+
public void testExecuteAgentWithFullAccess() throws IOException {
255+
MLAgent mlAgent = createCatIndexToolMLAgent();
256+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
257+
assertNotNull(registerMLAgentResult);
258+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
259+
String agentId = (String) registerMLAgentResult.get("agent_id");
260+
try {
261+
AgentMLInput agentMLInput = AgentMLInput
262+
.AgentMLInputBuilder()
263+
.agentId(agentId)
264+
.functionName(FunctionName.AGENT)
265+
.inputDataset(
266+
RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()
267+
)
268+
.build();
269+
270+
executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
271+
assertNotNull(mlExecuteTaskResponse);
272+
assertTrue(mlExecuteTaskResponse.containsKey("output"));
273+
});
274+
} catch (IOException e) {
275+
assertNull(e);
276+
}
277+
});
278+
}
279+
280+
public void testRegisterAgentWithFullAccess() throws IOException {
281+
MLAgent mlAgent = createCatIndexToolMLAgent();
282+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
283+
assertNotNull(registerMLAgentResult);
284+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
285+
});
286+
}
287+
251288
public void testTrainWithReadOnlyMLAccess() throws IOException {
252289
exceptionRule.expect(ResponseException.class);
253290
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");

0 commit comments

Comments
 (0)