Skip to content

Commit 57e62e3

Browse files
committed
change argument parsing
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent e86a863 commit 57e62e3

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ private void runReAct(
326326
}
327327
String thought = String.valueOf(dataAsMap.get("thought"));
328328
String action = String.valueOf(dataAsMap.get("action"));
329-
String actionInput = gson.toJson(dataAsMap.get("action_input"));
329+
String actionInput = parseInputFromLLMReturn(dataAsMap);
330330
String finalAnswer = (String) dataAsMap.get("final_answer");
331331
if (!dataAsMap.containsKey("thought")) {
332332
String response = (String) dataAsMap.get("response");
@@ -337,7 +337,7 @@ private void runReAct(
337337
Map map = gson.fromJson(jsonBlock, Map.class);
338338
thought = String.valueOf(map.get("thought"));
339339
action = String.valueOf(map.get("action"));
340-
actionInput = gson.toJson(map.get("action_input"));
340+
actionInput = parseInputFromLLMReturn(map);
341341
finalAnswer = (String) map.get("final_answer");
342342
} else {
343343
finalAnswer = response;
@@ -675,4 +675,16 @@ private String outputToOutputString(Object output) throws PrivilegedActionExcept
675675
return outputString;
676676
}
677677

678+
private String parseInputFromLLMReturn(Map<String, ?> retMap){
679+
Object actionInput = retMap.get("action_input");
680+
if (actionInput instanceof Map)
681+
{
682+
return gson.toJson(actionInput);
683+
}
684+
else {
685+
return String.valueOf(actionInput);
686+
}
687+
688+
}
689+
678690
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public AgentTool(Client client, String agentId) {
5353

5454
@Override
5555
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
56-
parameters = extractFromChatParameters(parameters);
56+
extractFromChatParameters(parameters);
5757
AgentMLInput agentMLInput = AgentMLInput
5858
.AgentMLInputBuilder()
5959
.agentId(agentId)
@@ -139,15 +139,14 @@ public String getDefaultVersion() {
139139
}
140140
}
141141

142-
private Map<String, String> extractFromChatParameters(Map<String, String> parameters) {
142+
private void extractFromChatParameters(Map<String, String> parameters) {
143143
if (parameters.containsKey("input")) {
144144
try {
145145
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
146146
parameters.putAll(chatParameters);
147147
} finally {
148-
return parameters;
148+
return ;
149149
}
150150
}
151-
return parameters;
152151
}
153152
}

0 commit comments

Comments
 (0)