|
26 | 26 | import org.opensearch.ml.common.AccessMode;
|
27 | 27 | import org.opensearch.ml.common.FunctionName;
|
28 | 28 | import org.opensearch.ml.common.MLTaskState;
|
| 29 | +import org.opensearch.ml.common.agent.MLAgent; |
| 30 | +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; |
| 31 | +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; |
29 | 32 | import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
|
30 | 33 | import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
|
31 | 34 | import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
|
@@ -248,6 +251,40 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
|
248 | 251 | });
|
249 | 252 | }
|
250 | 253 |
|
| 254 | + public void testExecuteAgentWithFullAccess() throws IOException { |
| 255 | + MLAgent mlAgent = createCatIndexToolMLAgent(); |
| 256 | + registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> { |
| 257 | + assertNotNull(registerMLAgentResult); |
| 258 | + assertTrue(registerMLAgentResult.containsKey("agent_id")); |
| 259 | + String agentId = (String) registerMLAgentResult.get("agent_id"); |
| 260 | + try { |
| 261 | + AgentMLInput agentMLInput = AgentMLInput |
| 262 | + .AgentMLInputBuilder() |
| 263 | + .agentId(agentId) |
| 264 | + .functionName(FunctionName.AGENT) |
| 265 | + .inputDataset( |
| 266 | + RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build() |
| 267 | + ) |
| 268 | + .build(); |
| 269 | + |
| 270 | + executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> { |
| 271 | + assertNotNull(mlExecuteTaskResponse); |
| 272 | + assertTrue(mlExecuteTaskResponse.containsKey("output")); |
| 273 | + }); |
| 274 | + } catch (IOException e) { |
| 275 | + assertNull(e); |
| 276 | + } |
| 277 | + }); |
| 278 | + } |
| 279 | + |
| 280 | + public void testRegisterAgentWithFullAccess() throws IOException { |
| 281 | + MLAgent mlAgent = createCatIndexToolMLAgent(); |
| 282 | + registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> { |
| 283 | + assertNotNull(registerMLAgentResult); |
| 284 | + assertTrue(registerMLAgentResult.containsKey("agent_id")); |
| 285 | + }); |
| 286 | + } |
| 287 | + |
251 | 288 | public void testTrainWithReadOnlyMLAccess() throws IOException {
|
252 | 289 | exceptionRule.expect(ResponseException.class);
|
253 | 290 | exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");
|
|
0 commit comments