diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 11b9a91f5b..7e9bdd1ab0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -15,6 +15,9 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -23,6 +26,8 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.spi.tools.Tool; public class AgentUtils { @@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) { throw new IllegalArgumentException("Model output is invalid"); } } + + public static String outputToOutputString(Object output) throws PrivilegedActionException { + String outputString; + if (output instanceof ModelTensorOutput) { + ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0); + if (outputModel.getDataAsMap() != null) { + outputString = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); + } else { + outputString = outputModel.getResult(); + } + } else if (output instanceof String) { + outputString = (String) output; + } else { + outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + } + return outputString; + } + + public static String parseInputFromLLMReturn(Map retMap) { + Object actionInput = retMap.get("action_input"); + if (actionInput instanceof Map) { + return gson.toJson(actionInput); + } else { + return String.valueOf(actionInput); + } + + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index d2c513eabd..462115d1cf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -9,9 +9,9 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -325,7 +325,7 @@ private void runReAct( } String thought = String.valueOf(dataAsMap.get("thought")); String action = String.valueOf(dataAsMap.get("action")); - String actionInput = String.valueOf(dataAsMap.get("action_input")); + String actionInput = parseInputFromLLMReturn(dataAsMap); String finalAnswer = (String) dataAsMap.get("final_answer"); if (!dataAsMap.containsKey("thought")) { String response = (String) dataAsMap.get("response"); @@ -336,7 +336,7 @@ private void runReAct( Map map = gson.fromJson(jsonBlock, Map.class); thought = String.valueOf(map.get("thought")); action = String.valueOf(map.get("action")); - actionInput = String.valueOf(map.get("action_input")); + actionInput = parseInputFromLLMReturn(map); finalAnswer = (String) map.get("final_answer"); } else { finalAnswer = response; @@ -524,9 +524,7 @@ private void runReAct( } else { MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { - String outputString = output instanceof String - ? (String) output - : AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + String outputString = outputToOutputString(output); String toolOutputKey = String.format("%s.output", toolSpec.getType()); if (additionalInfo.get(toolOutputKey) != null) { @@ -546,7 +544,13 @@ private void runReAct( .singletonList( ModelTensor .builder() - .dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output)) + .dataAsMap( + ImmutableMap + .of( + "response", + lastThought.get() + "\nObservation: " + outputToOutputString(output) + ) + ) .build() ) ) @@ -555,7 +559,7 @@ private void runReAct( String toolResponse = tmpParameters.get("prompt.tool_response"); StringSubstitutor toolResponseSubstitutor = new StringSubstitutor( - ImmutableMap.of("observation", output), + ImmutableMap.of("observation", outputToOutputString(output)), "${parameters.", "}" ); @@ -567,7 +571,7 @@ private void runReAct( .conversationIndexMessageBuilder() .type("ReAct") .question(lastActionInput.get()) - .response((String) output) + .response(outputToOutputString(output)) .finalAnswer(false) .sessionId(sessionId) .build(); @@ -582,7 +586,7 @@ private void runReAct( newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); - sessionMsgAnswerBuilder.append("\nObservation: ").append(output); + sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output)); cotModelTensors .add( ModelTensors diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index f048c62dc8..197f562bb6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -5,6 +5,9 @@ package org.opensearch.ml.engine.tools; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; import java.util.Map; import org.opensearch.action.ActionRequest; @@ -51,11 +54,12 @@ public AgentTool(Client client, String agentId) { @Override public void run(Map parameters, ActionListener listener) { + Map extractedParameters = extractInputParameters(parameters); AgentMLInput agentMLInput = AgentMLInput .AgentMLInputBuilder() .agentId(agentId) .functionName(FunctionName.AGENT) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build()) .build(); ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false); client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> { @@ -135,4 +139,18 @@ public String getDefaultVersion() { return null; } } + + private Map extractInputParameters(Map parameters) { + Map extractedParameters = new HashMap<>(); + extractedParameters.putAll(parameters); + if (parameters.containsKey("input")) { + try { + Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); + extractedParameters.putAll(chatParameters); + } catch (Exception exception) { + log.info("fail extract parameters from key 'input' due to" + exception.getMessage()); + } + } + return extractedParameters; + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 2247bf00c5..8e8f6c3235 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() { assertEquals("parsed final answer", modelTensor2.getResult()); } + @Test + public void testParsingJsonBlockFromResponse2() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentWithTools(); + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1); + ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0); + ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0); + + // Verify that the parsed values from JSON block are correctly set + assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult()); + assertEquals("Thought: parsed thought", modelTensor1.getResult()); + assertEquals("parsed final answer", modelTensor2.getResult()); + } + + @Test + public void testParsingJsonBlockFromResponse3() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentWithTools(); + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1); + ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0); + ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0); + + // Verify that the parsed values from JSON block are correctly set + assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult()); + assertEquals("Thought: parsed thought", modelTensor1.getResult()); + assertEquals("parsed final answer", modelTensor2.getResult()); + } + @Test public void testRunWithIncludeOutputNotSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -209,6 +289,35 @@ public void testRunWithIncludeOutputNotSet() { assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); } + @Test + public void testRunWithIncludeOutputMLModel() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + Mockito + .doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)) + .when(firstTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); + Mockito + .doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)) + .when(secondTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + List agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + // Respond with last tool output + assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); + } + @Test public void testRunWithIncludeOutputSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -512,6 +621,23 @@ private Answer generateToolResponse(String response) { }; } + private Answer generateToolResponseAsMLModelResult(String response, int type) { + ModelTensor modelTensor; + if (type == 1) { + modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build(); + } else { + modelTensor = ModelTensor.builder().result(response).build(); + } + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + return invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mlModelTensorOutput); + return null; + }; + } + private Answer generateToolFailure(Exception e) { return invocation -> { ActionListener listener = invocation.getArgument(1); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java index 431e609bba..02b6627a31 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION; import java.util.Arrays; @@ -24,8 +25,6 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -67,13 +66,28 @@ public void setup() { public void testAgenttestRunMethod() { Map parameters = new HashMap<>(); parameters.put("testKey", "testValue"); - AgentMLInput agentMLInput = AgentMLInput - .AgentMLInputBuilder() - .agentId("agentId") - .functionName(FunctionName.AGENT) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) - .build(); + doTestRunMethod(parameters); + } + + @Test + public void testAgentWithChatAgentInput() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + Map chatAgentInput = new HashMap<>(); + chatAgentInput.put("input", gson.toJson(parameters)); + doTestRunMethod(chatAgentInput); + assertEquals(chatAgentInput.size(), 1); + assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters + } + + @Test + public void testAgentWithChatAgentInputWrongFormat() { + Map chatAgentInput = new HashMap<>(); + chatAgentInput.put("input", "wrong format"); + doTestRunMethod(chatAgentInput); + } + private void doTestRunMethod(Map parameters) { ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();