Skip to content

Commit 0fd7e11

Browse files
committed
add agent framework security it tests
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent eeba1c3 commit 0fd7e11

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+28
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
import org.opensearch.ml.common.AccessMode;
6767
import org.opensearch.ml.common.FunctionName;
6868
import org.opensearch.ml.common.MLTaskState;
69+
import org.opensearch.ml.common.agent.MLAgent;
70+
import org.opensearch.ml.common.agent.MLToolSpec;
6971
import org.opensearch.ml.common.dataset.MLInputDataset;
7072
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
7173
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
@@ -731,6 +733,32 @@ public String registerModel(String input) throws IOException {
731733
return parseTaskIdFromResponse(response);
732734
}
733735

736+
public void registerMLAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
737+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_register", null, input, null);
738+
verifyResponse(function, response);
739+
}
740+
741+
public void executeAgent(RestClient client, String agentId, String input, Consumer<Map<String, Object>> function) throws IOException {
742+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, input, null);
743+
verifyResponse(function, response);
744+
}
745+
746+
public MLAgent createCatIndexToolMLAgent() {
747+
MLToolSpec catIndexTool = MLToolSpec
748+
.builder()
749+
.type("CatIndexTool")
750+
.name("DemoCatIndexTool")
751+
.parameters(Map.of("input", "${parameters.question}"))
752+
.build();
753+
return MLAgent
754+
.builder()
755+
.name("Test_Agent_For_CatIndex_tool")
756+
.type("flow")
757+
.description("this is a test agent for the CatIndexTool")
758+
.tools(List.of(catIndexTool))
759+
.build();
760+
}
761+
734762
public void deployModel(RestClient client, MLRegisterModelInput registerModelInput, Consumer<Map<String, Object>> function)
735763
throws IOException,
736764
InterruptedException {

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

+37
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
import org.opensearch.ml.common.AccessMode;
2727
import org.opensearch.ml.common.FunctionName;
2828
import org.opensearch.ml.common.MLTaskState;
29+
import org.opensearch.ml.common.agent.MLAgent;
30+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
31+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
2932
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
3033
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
3134
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
@@ -35,6 +38,8 @@
3538
import com.google.common.base.Throwables;
3639
import com.google.common.collect.ImmutableList;
3740

41+
import static org.opensearch.ml.common.input.execute.agent.AgentMLInput.AgentMLInputBuilder;
42+
3843
public class SecureMLRestIT extends MLCommonsRestTestCase {
3944
private String irisIndex = "iris_data_secure_ml_it";
4045

@@ -248,6 +253,38 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248253
});
249254
}
250255

256+
public void testExecuteAgentWithFullAccess() throws IOException {
257+
MLAgent mlAgent = createCatIndexToolMLAgent();
258+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
259+
assertNotNull(registerMLAgentResult);
260+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
261+
String agentId = (String) registerMLAgentResult.get("agent_id");
262+
try {
263+
AgentMLInput agentMLInput = AgentMLInput
264+
.AgentMLInputBuilder()
265+
.agentId(agentId)
266+
.functionName(FunctionName.AGENT)
267+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build())
268+
.build();
269+
270+
executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
271+
assertNotNull(mlExecuteTaskResponse);
272+
assertTrue(mlExecuteTaskResponse.containsKey("output"));
273+
});
274+
} catch (IOException e) {
275+
assertNull(e);
276+
}
277+
});
278+
}
279+
280+
public void testRegisterAgentWithFullAccess() throws IOException {
281+
MLAgent mlAgent = createCatIndexToolMLAgent();
282+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
283+
assertNotNull(registerMLAgentResult);
284+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
285+
});
286+
}
287+
251288
public void testTrainWithReadOnlyMLAccess() throws IOException {
252289
exceptionRule.expect(ResponseException.class);
253290
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");

0 commit comments

Comments
 (0)