From 97f1199c971f23ab8a27a3fc26b1585490ddbc4a Mon Sep 17 00:00:00 2001 From: xinyual Date: Sat, 27 Jan 2024 11:12:38 +0800 Subject: [PATCH 01/14] add logs Signed-off-by: xinyual --- .../opensearch/ml/engine/tools/AgentTool.java | 19 +++++++++++++++++++ .../ml/engine/tools/MLModelTool.java | 2 ++ 2 files changed, 21 insertions(+) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index f048c62dc8..ab614640f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -24,6 +24,8 @@ import lombok.Setter; import lombok.extern.log4j.Log4j2; +import static org.opensearch.ml.common.utils.StringUtils.gson; + /** * This tool supports running any Agent. */ @@ -51,6 +53,11 @@ public AgentTool(Client client, String agentId) { @Override public void run(Map parameters, ActionListener listener) { + log.info("Agent tool before"); + log.info(parameters); + parameters = extractFromChatParameters(parameters); + log.info("Agent tool after"); + log.info(parameters); AgentMLInput agentMLInput = AgentMLInput .AgentMLInputBuilder() .agentId(agentId) @@ -135,4 +142,16 @@ public String getDefaultVersion() { return null; } } + + private Map extractFromChatParameters(Map parameters) { + if (parameters.containsKey("input")) { + try { + Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); + parameters.putAll(chatParameters); + } finally { + return parameters; + } + } + return parameters; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index cba0d2ee6a..ef0d05a788 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -65,6 +65,8 @@ public MLModelTool(Client client, String modelId) { @Override public void run(Map parameters, ActionListener listener) { + log.info("ML input"); + log.info(parameters); RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); ActionRequest request = new MLPredictionTaskRequest( modelId, From 1d9d3b7d983768cb3a3ba33fef6f07c0572d98d0 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 29 Jan 2024 10:36:27 +0800 Subject: [PATCH 02/14] option2FixBug Signed-off-by: xinyual --- .../ml/common/connector/HttpConnector.java | 3 +- .../algorithms/agent/MLChatAgentRunner.java | 51 +++++++++++++++---- .../remote/AwsConnectorExecutor.java | 2 + 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ef0e4bf4a1..5e7191836b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -289,7 +289,8 @@ public T createPredictPayload(Map parameters) { payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); - + log.info("to LLM"); + log.info(payload); if (!isJson(payload)) { throw new IllegalArgumentException("Invalid JSON in payload"); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index d2c513eabd..093b17c5bb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -8,9 +8,11 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.toUTF8; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; import java.security.AccessController; +import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collections; @@ -325,7 +327,7 @@ private void runReAct( } String thought = String.valueOf(dataAsMap.get("thought")); String action = String.valueOf(dataAsMap.get("action")); - String actionInput = String.valueOf(dataAsMap.get("action_input")); + String actionInput = gson.toJson(dataAsMap.get("action_input")); String finalAnswer = (String) dataAsMap.get("final_answer"); if (!dataAsMap.containsKey("thought")) { String response = (String) dataAsMap.get("response"); @@ -336,7 +338,7 @@ private void runReAct( Map map = gson.fromJson(jsonBlock, Map.class); thought = String.valueOf(map.get("thought")); action = String.valueOf(map.get("action")); - actionInput = String.valueOf(map.get("action_input")); + actionInput = gson.toJson(map.get("action_input")); finalAnswer = (String) map.get("final_answer"); } else { finalAnswer = response; @@ -524,9 +526,7 @@ private void runReAct( } else { MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { - String outputString = output instanceof String - ? (String) output - : AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + String outputString = outputToOutputString(output); String toolOutputKey = String.format("%s.output", toolSpec.getType()); if (additionalInfo.get(toolOutputKey) != null) { @@ -546,7 +546,7 @@ private void runReAct( .singletonList( ModelTensor .builder() - .dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output)) + .dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + outputToOutputString(output))) .build() ) ) @@ -555,7 +555,7 @@ private void runReAct( String toolResponse = tmpParameters.get("prompt.tool_response"); StringSubstitutor toolResponseSubstitutor = new StringSubstitutor( - ImmutableMap.of("observation", output), + ImmutableMap.of("observation", outputToOutputString(output)), "${parameters.", "}" ); @@ -567,7 +567,7 @@ private void runReAct( .conversationIndexMessageBuilder() .type("ReAct") .question(lastActionInput.get()) - .response((String) output) + .response(outputToOutputString(output)) .finalAnswer(false) .sessionId(sessionId) .build(); @@ -582,7 +582,7 @@ private void runReAct( newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); - sessionMsgAnswerBuilder.append("\nObservation: ").append(output); + sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output)); cotModelTensors .add( ModelTensors @@ -629,7 +629,15 @@ private void runReAct( listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); } } else { - client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); + ActionRequest request2 = new MLPredictionTaskRequest( + llm.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) + .build() + ); + client.execute(MLPredictionTaskAction.INSTANCE, request2, (ActionListener) nextStepListener); } } }, e -> { @@ -652,4 +660,27 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } + private String outputToOutputString(Object output) throws PrivilegedActionException { + String outputString; + if (output instanceof ModelTensorOutput) + { + ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0); + if (outputModel.getDataAsMap() != null) + { + outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); + } + else { + outputString = outputModel.getResult(); + } + } + else if (output instanceof String) + { + outputString = (String) output; + } + else { + outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + } + return outputString; + } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 0e8169ac64..b7ef5c8d7a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -117,6 +117,8 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); } String modelResponse = responseBuilder.toString(); + log.info("from LLM"); + log.info(modelResponse); if (statusCode < 200 || statusCode >= 300) { throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); } From 20aa5bb8431bc1dcf627ff04dc0f9fc23ca61e3e Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 29 Jan 2024 10:40:49 +0800 Subject: [PATCH 03/14] remove useless log Signed-off-by: xinyual --- .../java/org/opensearch/ml/common/connector/HttpConnector.java | 2 -- .../ml/engine/algorithms/remote/AwsConnectorExecutor.java | 2 -- 2 files changed, 4 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 5e7191836b..d3b4892e3e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -289,8 +289,6 @@ public T createPredictPayload(Map parameters) { payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); - log.info("to LLM"); - log.info(payload); if (!isJson(payload)) { throw new IllegalArgumentException("Invalid JSON in payload"); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index b7ef5c8d7a..0e8169ac64 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -117,8 +117,6 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); } String modelResponse = responseBuilder.toString(); - log.info("from LLM"); - log.info(modelResponse); if (statusCode < 200 || statusCode >= 300) { throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); } From 70ad2234ae1538bafcab88d5e5122f57fb978946 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 29 Jan 2024 10:45:34 +0800 Subject: [PATCH 04/14] remove useless log Signed-off-by: xinyual --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 10 +--------- .../java/org/opensearch/ml/engine/tools/AgentTool.java | 4 ---- .../org/opensearch/ml/engine/tools/MLModelTool.java | 2 -- 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 093b17c5bb..1101d17de8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -629,15 +629,7 @@ private void runReAct( listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); } } else { - ActionRequest request2 = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) - .build() - ); - client.execute(MLPredictionTaskAction.INSTANCE, request2, (ActionListener) nextStepListener); + client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); } } }, e -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index ab614640f0..4a45d47a86 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -53,11 +53,7 @@ public AgentTool(Client client, String agentId) { @Override public void run(Map parameters, ActionListener listener) { - log.info("Agent tool before"); - log.info(parameters); parameters = extractFromChatParameters(parameters); - log.info("Agent tool after"); - log.info(parameters); AgentMLInput agentMLInput = AgentMLInput .AgentMLInputBuilder() .agentId(agentId) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index ef0d05a788..cba0d2ee6a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -65,8 +65,6 @@ public MLModelTool(Client client, String modelId) { @Override public void run(Map parameters, ActionListener listener) { - log.info("ML input"); - log.info(parameters); RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); ActionRequest request = new MLPredictionTaskRequest( modelId, From e86a8630eeff08739ca8813a2877cd72e175ad80 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 29 Jan 2024 11:01:36 +0800 Subject: [PATCH 05/14] fix spot less Signed-off-by: xinyual --- .../algorithms/agent/MLChatAgentRunner.java | 28 +++++++++---------- .../opensearch/ml/engine/tools/AgentTool.java | 4 +-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 1101d17de8..2fd10c0621 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -8,7 +8,6 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.common.utils.StringUtils.toUTF8; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; import java.security.AccessController; @@ -546,7 +545,13 @@ private void runReAct( .singletonList( ModelTensor .builder() - .dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + outputToOutputString(output))) + .dataAsMap( + ImmutableMap + .of( + "response", + lastThought.get() + "\nObservation: " + outputToOutputString(output) + ) + ) .build() ) ) @@ -654,22 +659,17 @@ private void runReAct( private String outputToOutputString(Object output) throws PrivilegedActionException { String outputString; - if (output instanceof ModelTensorOutput) - { + if (output instanceof ModelTensorOutput) { ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0); - if (outputModel.getDataAsMap() != null) - { - outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); - } - else { + if (outputModel.getDataAsMap() != null) { + outputString = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); + } else { outputString = outputModel.getResult(); } - } - else if (output instanceof String) - { + } else if (output instanceof String) { outputString = (String) output; - } - else { + } else { outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); } return outputString; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index 4a45d47a86..9e4f843530 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -5,6 +5,8 @@ package org.opensearch.ml.engine.tools; +import static org.opensearch.ml.common.utils.StringUtils.gson; + import java.util.Map; import org.opensearch.action.ActionRequest; @@ -24,8 +26,6 @@ import lombok.Setter; import lombok.extern.log4j.Log4j2; -import static org.opensearch.ml.common.utils.StringUtils.gson; - /** * This tool supports running any Agent. */ From 57e62e3c91537afeabc831fbc7363e6687dd5919 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 29 Jan 2024 16:00:40 +0800 Subject: [PATCH 06/14] change argument parsing Signed-off-by: xinyual --- .../algorithms/agent/MLChatAgentRunner.java | 16 ++++++++++++++-- .../opensearch/ml/engine/tools/AgentTool.java | 7 +++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 2fd10c0621..91441cf7e7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -326,7 +326,7 @@ private void runReAct( } String thought = String.valueOf(dataAsMap.get("thought")); String action = String.valueOf(dataAsMap.get("action")); - String actionInput = gson.toJson(dataAsMap.get("action_input")); + String actionInput = parseInputFromLLMReturn(dataAsMap); String finalAnswer = (String) dataAsMap.get("final_answer"); if (!dataAsMap.containsKey("thought")) { String response = (String) dataAsMap.get("response"); @@ -337,7 +337,7 @@ private void runReAct( Map map = gson.fromJson(jsonBlock, Map.class); thought = String.valueOf(map.get("thought")); action = String.valueOf(map.get("action")); - actionInput = gson.toJson(map.get("action_input")); + actionInput = parseInputFromLLMReturn(map); finalAnswer = (String) map.get("final_answer"); } else { finalAnswer = response; @@ -675,4 +675,16 @@ private String outputToOutputString(Object output) throws PrivilegedActionExcept return outputString; } + private String parseInputFromLLMReturn(Map retMap){ + Object actionInput = retMap.get("action_input"); + if (actionInput instanceof Map) + { + return gson.toJson(actionInput); + } + else { + return String.valueOf(actionInput); + } + + } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index 9e4f843530..d86a085ada 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -53,7 +53,7 @@ public AgentTool(Client client, String agentId) { @Override public void run(Map parameters, ActionListener listener) { - parameters = extractFromChatParameters(parameters); + extractFromChatParameters(parameters); AgentMLInput agentMLInput = AgentMLInput .AgentMLInputBuilder() .agentId(agentId) @@ -139,15 +139,14 @@ public String getDefaultVersion() { } } - private Map extractFromChatParameters(Map parameters) { + private void extractFromChatParameters(Map parameters) { if (parameters.containsKey("input")) { try { Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); parameters.putAll(chatParameters); } finally { - return parameters; + return ; } } - return parameters; } } From fd458c7546c90406ab7f977769b3f8b029858bc6 Mon Sep 17 00:00:00 2001 From: xinyual Date: Tue, 30 Jan 2024 09:18:26 +0800 Subject: [PATCH 07/14] move common function to utils Signed-off-by: xinyual --- .../engine/algorithms/agent/AgentUtils.java | 33 +++++++++++++++++ .../algorithms/agent/MLChatAgentRunner.java | 35 +------------------ .../opensearch/ml/engine/tools/AgentTool.java | 4 ++- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 11b9a91f5b..7e9bdd1ab0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -15,6 +15,9 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -23,6 +26,8 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.spi.tools.Tool; public class AgentUtils { @@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) { throw new IllegalArgumentException("Model output is invalid"); } } + + public static String outputToOutputString(Object output) throws PrivilegedActionException { + String outputString; + if (output instanceof ModelTensorOutput) { + ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0); + if (outputModel.getDataAsMap() != null) { + outputString = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); + } else { + outputString = outputModel.getResult(); + } + } else if (output instanceof String) { + outputString = (String) output; + } else { + outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + } + return outputString; + } + + public static String parseInputFromLLMReturn(Map retMap) { + Object actionInput = retMap.get("action_input"); + if (actionInput instanceof Map) { + return gson.toJson(actionInput); + } else { + return String.valueOf(actionInput); + } + + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 91441cf7e7..61ab663c76 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -8,11 +8,8 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.*; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -657,34 +654,4 @@ private void runReAct( client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private String outputToOutputString(Object output) throws PrivilegedActionException { - String outputString; - if (output instanceof ModelTensorOutput) { - ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0); - if (outputModel.getDataAsMap() != null) { - outputString = AccessController - .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); - } else { - outputString = outputModel.getResult(); - } - } else if (output instanceof String) { - outputString = (String) output; - } else { - outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); - } - return outputString; - } - - private String parseInputFromLLMReturn(Map retMap){ - Object actionInput = retMap.get("action_input"); - if (actionInput instanceof Map) - { - return gson.toJson(actionInput); - } - else { - return String.valueOf(actionInput); - } - - } - } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index d86a085ada..bfbda1ed5b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -144,8 +144,10 @@ private void extractFromChatParameters(Map parameters) { try { Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); parameters.putAll(chatParameters); + } catch (Exception exception) { + log.info("fail extract parameters from key 'input' due to" + exception.getMessage()); } finally { - return ; + return; } } } From a30111c9f3a4a9ce8cb0ccf3d0baf2233bfe7a80 Mon Sep 17 00:00:00 2001 From: xinyual Date: Tue, 30 Jan 2024 09:21:16 +0800 Subject: [PATCH 08/14] checkout for typo Signed-off-by: xinyual --- .../java/org/opensearch/ml/common/connector/HttpConnector.java | 1 + 1 file changed, 1 insertion(+) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index d3b4892e3e..ef0e4bf4a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -289,6 +289,7 @@ public T createPredictPayload(Map parameters) { payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); + if (!isJson(payload)) { throw new IllegalArgumentException("Invalid JSON in payload"); } From d289673a16b11def41bfbd448d402b95ed30c7f0 Mon Sep 17 00:00:00 2001 From: xinyual Date: Thu, 1 Feb 2024 14:18:18 +0800 Subject: [PATCH 09/14] remove useless code Signed-off-by: xinyual --- .../src/main/java/org/opensearch/ml/engine/tools/AgentTool.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index bfbda1ed5b..a200d16632 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -146,8 +146,6 @@ private void extractFromChatParameters(Map parameters) { parameters.putAll(chatParameters); } catch (Exception exception) { log.info("fail extract parameters from key 'input' due to" + exception.getMessage()); - } finally { - return; } } } From 3b29152942ff88f33464ef75a3afeb6965153996 Mon Sep 17 00:00:00 2001 From: xinyual Date: Thu, 1 Feb 2024 15:55:27 +0800 Subject: [PATCH 10/14] add UTs Signed-off-by: xinyual --- .../agent/MLChatAgentRunnerTest.java | 81 +++++++++++++++++++ .../ml/engine/tools/AgentToolTests.java | 19 +++-- 2 files changed, 94 insertions(+), 6 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 2247bf00c5..faafe0ac3b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -188,6 +188,46 @@ public void testParsingJsonBlockFromResponse() { assertEquals("parsed final answer", modelTensor2.getResult()); } + @Test + public void testParsingJsonBlockFromResponse2() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentWithTools(); + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1); + ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0); + ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0); + + // Verify that the parsed values from JSON block are correctly set + assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult()); + assertEquals("Thought: parsed thought", modelTensor1.getResult()); + assertEquals("parsed final answer", modelTensor2.getResult()); + } + @Test public void testRunWithIncludeOutputNotSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -209,6 +249,29 @@ public void testRunWithIncludeOutputNotSet() { assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); } + @Test + public void testRunWithIncludeOutputMLModel() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + Mockito.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)).when(firstTool).run(Mockito.anyMap(), toolListenerCaptor.capture()); + Mockito.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)).when(secondTool).run(Mockito.anyMap(), toolListenerCaptor.capture()); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + List agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + // Respond with last tool output + assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); + } + @Test public void testRunWithIncludeOutputSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -512,6 +575,24 @@ private Answer generateToolResponse(String response) { }; } + private Answer generateToolResponseAsMLModelResult(String response, int type) { + ModelTensor modelTensor; + if (type == 1) { + modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build(); + } + else { + modelTensor = ModelTensor.builder().result(response).build(); + } + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + return invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mlModelTensorOutput); + return null; + }; + } + private Answer generateToolFailure(Exception e) { return invocation -> { ActionListener listener = invocation.getArgument(1); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java index 431e609bba..e16738985f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION; import java.util.Arrays; @@ -67,13 +68,19 @@ public void setup() { public void testAgenttestRunMethod() { Map parameters = new HashMap<>(); parameters.put("testKey", "testValue"); - AgentMLInput agentMLInput = AgentMLInput - .AgentMLInputBuilder() - .agentId("agentId") - .functionName(FunctionName.AGENT) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) - .build(); + doTestRunMethod(parameters); + } + + @Test + public void testAgentWithChatAgentInput() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + Map chatAgentInput = ImmutableMap.of("input", gson.toJson(parameters)); + doTestRunMethod(chatAgentInput); + } + private void doTestRunMethod(Map parameters) + { ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); From a9a20ed0fb0da8ef2eba48a6ec1e675d71e98fe8 Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 2 Feb 2024 09:25:18 +0800 Subject: [PATCH 11/14] apply spotless Signed-off-by: xinyual --- .../agent/MLChatAgentRunnerTest.java | 31 +++++++++++-------- .../ml/engine/tools/AgentToolTests.java | 5 +-- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index faafe0ac3b..ac61cf26eb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -192,15 +192,15 @@ public void testParsingJsonBlockFromResponse() { public void testParsingJsonBlockFromResponse2() { // Prepare the response with JSON block String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " - + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; + + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; // Mock LLM response to not contain "thought" but contain "response" with JSON block Map llmResponse = new HashMap<>(); llmResponse.put("response", responseWithJsonBlock); doAnswer(getLLMAnswer(llmResponse)) - .when(client) - .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); // Create an MLAgent and run the MLChatAgentRunner MLAgent mlAgent = createMLAgentWithTools(); @@ -252,17 +252,23 @@ public void testRunWithIncludeOutputNotSet() { @Test public void testRunWithIncludeOutputMLModel() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - Mockito.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)).when(firstTool).run(Mockito.anyMap(), toolListenerCaptor.capture()); - Mockito.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)).when(secondTool).run(Mockito.anyMap(), toolListenerCaptor.capture()); + Mockito + .doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)) + .when(firstTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); + Mockito + .doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)) + .when(secondTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); final MLAgent mlAgent = MLAgent - .builder() - .name("TestAgent") - .llm(llmSpec) - .memory(mlMemorySpec) - .tools(Arrays.asList(firstToolSpec, secondToolSpec)) - .build(); + .builder() + .name("TestAgent") + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -579,8 +585,7 @@ private Answer generateToolResponseAsMLModelResult(String response, int type) { ModelTensor modelTensor; if (type == 1) { modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build(); - } - else { + } else { modelTensor = ModelTensor.builder().result(response).build(); } ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java index e16738985f..854c153345 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -25,8 +25,6 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -79,8 +77,7 @@ public void testAgentWithChatAgentInput() { doTestRunMethod(chatAgentInput); } - private void doTestRunMethod(Map parameters) - { + private void doTestRunMethod(Map parameters) { ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); From c7439f496f77130c1b7ee71ed6f34b4e46f5586e Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 2 Feb 2024 10:37:19 +0800 Subject: [PATCH 12/14] add more uts Signed-off-by: xinyual --- .../agent/MLChatAgentRunnerTest.java | 40 +++++++++++++++++++ .../ml/engine/tools/AgentToolTests.java | 10 ++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index ac61cf26eb..8e8f6c3235 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -228,6 +228,46 @@ public void testParsingJsonBlockFromResponse2() { assertEquals("parsed final answer", modelTensor2.getResult()); } + @Test + public void testParsingJsonBlockFromResponse3() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentWithTools(); + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1); + ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0); + ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0); + + // Verify that the parsed values from JSON block are correctly set + assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult()); + assertEquals("Thought: parsed thought", modelTensor1.getResult()); + assertEquals("parsed final answer", modelTensor2.getResult()); + } + @Test public void testRunWithIncludeOutputNotSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java index 854c153345..86a7de2ce6 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -73,7 +73,15 @@ public void testAgenttestRunMethod() { public void testAgentWithChatAgentInput() { Map parameters = new HashMap<>(); parameters.put("testKey", "testValue"); - Map chatAgentInput = ImmutableMap.of("input", gson.toJson(parameters)); + Map chatAgentInput = new HashMap<>(); + chatAgentInput.put("input", gson.toJson(parameters)); + doTestRunMethod(chatAgentInput); + } + + @Test + public void testAgentWithChatAgentInputWrongFormat() { + Map chatAgentInput = new HashMap<>(); + chatAgentInput.put("input", "wrong format"); doTestRunMethod(chatAgentInput); } From f14a50905248ac18b547aae789d2c501cfaaf366 Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 2 Feb 2024 12:01:35 +0800 Subject: [PATCH 13/14] modify import Signed-off-by: xinyual --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 61ab663c76..462115d1cf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -8,7 +8,9 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.*; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; import java.util.ArrayList; import java.util.Collections; From dea4cf634dd728abb31e3b87d9dd4409d79d3438 Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 2 Feb 2024 15:55:22 +0800 Subject: [PATCH 14/14] protect original parameters Signed-off-by: xinyual --- .../org/opensearch/ml/engine/tools/AgentTool.java | 12 ++++++++---- .../opensearch/ml/engine/tools/AgentToolTests.java | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index a200d16632..197f562bb6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.utils.StringUtils.gson; +import java.util.HashMap; import java.util.Map; import org.opensearch.action.ActionRequest; @@ -53,12 +54,12 @@ public AgentTool(Client client, String agentId) { @Override public void run(Map parameters, ActionListener listener) { - extractFromChatParameters(parameters); + Map extractedParameters = extractInputParameters(parameters); AgentMLInput agentMLInput = AgentMLInput .AgentMLInputBuilder() .agentId(agentId) .functionName(FunctionName.AGENT) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build()) .build(); ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false); client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> { @@ -139,14 +140,17 @@ public String getDefaultVersion() { } } - private void extractFromChatParameters(Map parameters) { + private Map extractInputParameters(Map parameters) { + Map extractedParameters = new HashMap<>(); + extractedParameters.putAll(parameters); if (parameters.containsKey("input")) { try { Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); - parameters.putAll(chatParameters); + extractedParameters.putAll(chatParameters); } catch (Exception exception) { log.info("fail extract parameters from key 'input' due to" + exception.getMessage()); } } + return extractedParameters; } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java index 86a7de2ce6..02b6627a31 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -76,6 +76,8 @@ public void testAgentWithChatAgentInput() { Map chatAgentInput = new HashMap<>(); chatAgentInput.put("input", gson.toJson(parameters)); doTestRunMethod(chatAgentInput); + assertEquals(chatAgentInput.size(), 1); + assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters } @Test