Skip to content

Commit 6b7af6f

Browse files
committedMar 22, 2024
add agent framework security it tests
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent eeba1c3 commit 6b7af6f

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed
 

‎plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

+45
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,14 @@
6666
import org.opensearch.ml.common.AccessMode;
6767
import org.opensearch.ml.common.FunctionName;
6868
import org.opensearch.ml.common.MLTaskState;
69+
import org.opensearch.ml.common.agent.MLAgent;
70+
import org.opensearch.ml.common.agent.MLToolSpec;
6971
import org.opensearch.ml.common.dataset.MLInputDataset;
7072
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
7173
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
74+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
7275
import org.opensearch.ml.common.input.MLInput;
76+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
7377
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
7478
import org.opensearch.ml.common.model.MLModelConfig;
7579
import org.opensearch.ml.common.model.MLModelFormat;
@@ -731,6 +735,47 @@ public String registerModel(String input) throws IOException {
731735
return parseTaskIdFromResponse(response);
732736
}
733737

738+
public void registerMLAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
739+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_register", null, input, null);
740+
verifyResponse(function, response);
741+
}
742+
743+
public void executeAgent(RestClient client, String agentId, String input, Consumer<Map<String, Object>> function) throws IOException {
744+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, input, null);
745+
verifyResponse(function, response);
746+
}
747+
748+
public void getAgent(RestClient client, String agentId, Consumer<Map<String, Object>> function) throws IOException {
749+
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/agents/" + agentId, null, "", null);
750+
verifyResponse(function, response);
751+
}
752+
753+
public void searchAgent(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
754+
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/agents/_search", null, input, null);
755+
verifyResponse(function, response);
756+
}
757+
758+
public void deleteAgent(RestClient client, String agentId, Consumer<Map<String, Object>> function) throws IOException {
759+
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", null);
760+
verifyResponse(function, response);
761+
}
762+
763+
public MLAgent createCatIndexToolMLAgent() {
764+
MLToolSpec catIndexTool = MLToolSpec
765+
.builder()
766+
.type("CatIndexTool")
767+
.name("DemoCatIndexTool")
768+
.parameters(Map.of("input", "${parameters.question}"))
769+
.build();
770+
return MLAgent
771+
.builder()
772+
.name("Test_Agent_For_CatIndex_tool")
773+
.type("flow")
774+
.description("this is a test agent for the CatIndexTool")
775+
.tools(List.of(catIndexTool))
776+
.build();
777+
}
778+
734779
public void deployModel(RestClient client, MLRegisterModelInput registerModelInput, Consumer<Map<String, Object>> function)
735780
throws IOException,
736781
InterruptedException {

‎plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

+159
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,23 @@
1818
import org.junit.Before;
1919
import org.junit.Rule;
2020
import org.junit.rules.ExpectedException;
21+
import org.opensearch.action.search.SearchRequest;
2122
import org.opensearch.client.Response;
2223
import org.opensearch.client.ResponseException;
2324
import org.opensearch.client.RestClient;
2425
import org.opensearch.commons.rest.SecureRestClientBuilder;
2526
import org.opensearch.index.query.MatchAllQueryBuilder;
27+
import org.opensearch.index.query.QueryBuilders;
2628
import org.opensearch.ml.common.AccessMode;
2729
import org.opensearch.ml.common.FunctionName;
2830
import org.opensearch.ml.common.MLTaskState;
31+
import org.opensearch.ml.common.agent.MLAgent;
32+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
33+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
2934
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
35+
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
36+
import org.opensearch.ml.common.transport.agent.MLAgentGetRequest;
37+
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
3038
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
3139
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
3240
import org.opensearch.ml.utils.TestHelper;
@@ -35,6 +43,8 @@
3543
import com.google.common.base.Throwables;
3644
import com.google.common.collect.ImmutableList;
3745

46+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
47+
3848
public class SecureMLRestIT extends MLCommonsRestTestCase {
3949
private String irisIndex = "iris_data_secure_ml_it";
4050

@@ -59,6 +69,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
5969

6070
private String modelGroupId;
6171

72+
private MLAgent mlAgent;
73+
6274
/**
6375
* Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk
6476
* @return a random password.
@@ -151,6 +163,8 @@ public void setup() throws IOException {
151163
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id");
152164
});
153165
mlRegisterModelInput = createRegisterModelInput(modelGroupId);
166+
167+
mlAgent = createCatIndexToolMLAgent();
154168
}
155169

156170
@After
@@ -248,6 +262,151 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
248262
});
249263
}
250264

265+
public void testExecuteAgentWithFullAccess() throws IOException {
266+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
267+
assertNotNull(registerMLAgentResult);
268+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
269+
String agentId = (String) registerMLAgentResult.get("agent_id");
270+
try {
271+
AgentMLInput agentMLInput = AgentMLInput
272+
.AgentMLInputBuilder()
273+
.agentId(agentId)
274+
.functionName(FunctionName.AGENT)
275+
.inputDataset(
276+
RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()
277+
)
278+
.build();
279+
280+
executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
281+
assertNotNull(mlExecuteTaskResponse);
282+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
283+
});
284+
} catch (IOException e) {
285+
assertNull(e);
286+
}
287+
});
288+
}
289+
290+
public void testExecuteAgentWithReadOnlyAccess() throws IOException {
291+
exceptionRule.expect(RuntimeException.class);
292+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/execute]");
293+
AgentMLInput agentMLInput = AgentMLInput
294+
.AgentMLInputBuilder()
295+
.agentId("test-agent")
296+
.functionName(FunctionName.AGENT)
297+
.inputDataset(
298+
RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()
299+
)
300+
.build();
301+
302+
executeAgent(mlReadOnlyClient, "test-agent", TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> {
303+
assertNotNull(mlExecuteTaskResponse);
304+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
305+
});
306+
}
307+
308+
public void testGetAgentWithFullAccess() throws IOException {
309+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
310+
assertNotNull(registerMLAgentResult);
311+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
312+
String agentId = (String) registerMLAgentResult.get("agent_id");
313+
try {
314+
MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest.builder().agentId(agentId).build();
315+
getAgent(mlFullAccessClient, agentId, mlGetAgentResponse -> {
316+
assertNotNull(mlGetAgentResponse);
317+
assertTrue(mlGetAgentResponse.containsKey("name"));
318+
assertEquals(mlGetAgentResponse.get("name"), "Test_Agent_For_CatIndex_tool");
319+
});
320+
} catch (IOException e) {
321+
assertNull(e);
322+
}
323+
});
324+
}
325+
326+
public void testGetAgentWithNoAccess() throws IOException {
327+
exceptionRule.expect(RuntimeException.class);
328+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/get]");
329+
330+
getAgent(mlNoAccessClient, "test-agent", mlExecuteTaskResponse -> {
331+
assertNotNull(mlExecuteTaskResponse);
332+
assertTrue(mlExecuteTaskResponse.containsKey("inference_results"));
333+
});
334+
}
335+
336+
public void testSearchAgentWithFullAccess() throws IOException {
337+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
338+
assertNotNull(registerMLAgentResult);
339+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
340+
try {
341+
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
342+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
343+
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
344+
searchRequest.source(searchSourceBuilder);
345+
searchAgent(mlFullAccessClient, gson.toJson(searchRequest), mlSearchAgentResponse -> {
346+
assertNotNull(mlSearchAgentResponse);
347+
assertTrue(mlSearchAgentResponse.containsKey("hits"));
348+
});
349+
} catch (IOException e) {
350+
assertNull(e);
351+
}
352+
});
353+
}
354+
355+
public void testSearchAgentWithNoAccess() throws IOException {
356+
exceptionRule.expect(RuntimeException.class);
357+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/search]");
358+
359+
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
360+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
361+
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
362+
searchRequest.source(searchSourceBuilder);
363+
searchAgent(mlNoAccessClient, gson.toJson(searchRequest), mlSearchAgentResponse -> {
364+
assertNotNull(mlSearchAgentResponse);
365+
assertTrue(mlSearchAgentResponse.containsKey("hits"));
366+
});
367+
}
368+
369+
public void testDeleteAgentWithFullAccess() throws IOException {
370+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
371+
assertNotNull(registerMLAgentResult);
372+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
373+
String agentId = (String) registerMLAgentResult.get("agent_id");
374+
try {
375+
deleteAgent(mlFullAccessClient, agentId, mlSearchAgentResponse -> {
376+
assertNotNull(mlSearchAgentResponse);
377+
assertTrue(mlSearchAgentResponse.containsKey("result"));
378+
assertEquals(mlSearchAgentResponse.get("result"), "deleted");
379+
});
380+
} catch (IOException e) {
381+
assertNull(e);
382+
}
383+
});
384+
}
385+
386+
public void testDeleteAgentWithNoAccess() throws IOException {
387+
exceptionRule.expect(RuntimeException.class);
388+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/delete]");
389+
390+
deleteAgent(mlReadOnlyClient, "agentId", mlSearchAgentResponse -> {
391+
assertNotNull(mlSearchAgentResponse);
392+
assertTrue(mlSearchAgentResponse.containsKey("result"));
393+
assertEquals(mlSearchAgentResponse.get("result"), "deleted");
394+
});
395+
}
396+
397+
public void testRegisterAgentWithFullAccess() throws IOException {
398+
registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> {
399+
assertNotNull(registerMLAgentResult);
400+
assertTrue(registerMLAgentResult.containsKey("agent_id"));
401+
});
402+
}
403+
404+
public void testRegisterAgentWithReadOnlyMLAccess() throws IOException {
405+
exceptionRule.expect(ResponseException.class);
406+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/agents/register]");
407+
registerMLAgent(mlReadOnlyClient, TestHelper.toJsonString(mlAgent), null);
408+
}
409+
251410
public void testTrainWithReadOnlyMLAccess() throws IOException {
252411
exceptionRule.expect(ResponseException.class);
253412
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");

0 commit comments

Comments
 (0)