|
26 | 26 | import org.opensearch.ml.common.AccessMode;
|
27 | 27 | import org.opensearch.ml.common.FunctionName;
|
28 | 28 | import org.opensearch.ml.common.MLTaskState;
|
| 29 | +import org.opensearch.ml.common.agent.MLAgent; |
| 30 | +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; |
| 31 | +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; |
29 | 32 | import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
|
30 | 33 | import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
|
31 | 34 | import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
|
|
35 | 38 | import com.google.common.base.Throwables;
|
36 | 39 | import com.google.common.collect.ImmutableList;
|
37 | 40 |
|
| 41 | +import static org.opensearch.ml.common.input.execute.agent.AgentMLInput.AgentMLInputBuilder; |
| 42 | + |
38 | 43 | public class SecureMLRestIT extends MLCommonsRestTestCase {
|
39 | 44 | private String irisIndex = "iris_data_secure_ml_it";
|
40 | 45 |
|
@@ -248,6 +253,38 @@ public void testDeployModelWithFullAccess() throws IOException, InterruptedExcep
|
248 | 253 | });
|
249 | 254 | }
|
250 | 255 |
|
| 256 | + public void testExecuteAgentWithFullAccess() throws IOException { |
| 257 | + MLAgent mlAgent = createCatIndexToolMLAgent(); |
| 258 | + registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> { |
| 259 | + assertNotNull(registerMLAgentResult); |
| 260 | + assertTrue(registerMLAgentResult.containsKey("agent_id")); |
| 261 | + String agentId = (String) registerMLAgentResult.get("agent_id"); |
| 262 | + try { |
| 263 | + AgentMLInput agentMLInput = AgentMLInput |
| 264 | + .AgentMLInputBuilder() |
| 265 | + .agentId(agentId) |
| 266 | + .functionName(FunctionName.AGENT) |
| 267 | + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(Map.of("question", "How many indices do I have?")).build()) |
| 268 | + .build(); |
| 269 | + |
| 270 | + executeAgent(mlFullAccessClient, agentId, TestHelper.toJsonString(agentMLInput), mlExecuteTaskResponse -> { |
| 271 | + assertNotNull(mlExecuteTaskResponse); |
| 272 | + assertTrue(mlExecuteTaskResponse.containsKey("output")); |
| 273 | + }); |
| 274 | + } catch (IOException e) { |
| 275 | + assertNull(e); |
| 276 | + } |
| 277 | + }); |
| 278 | + } |
| 279 | + |
| 280 | + public void testRegisterAgentWithFullAccess() throws IOException { |
| 281 | + MLAgent mlAgent = createCatIndexToolMLAgent(); |
| 282 | + registerMLAgent(mlFullAccessClient, TestHelper.toJsonString(mlAgent), registerMLAgentResult -> { |
| 283 | + assertNotNull(registerMLAgentResult); |
| 284 | + assertTrue(registerMLAgentResult.containsKey("agent_id")); |
| 285 | + }); |
| 286 | + } |
| 287 | + |
251 | 288 | public void testTrainWithReadOnlyMLAccess() throws IOException {
|
252 | 289 | exceptionRule.expect(ResponseException.class);
|
253 | 290 | exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");
|
|
0 commit comments