Skip to content

Commit fd458c7

Browse files
committed
move common function to utils
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 57e62e3 commit fd458c7

File tree

3 files changed

+37
-35
lines changed

3 files changed

+37
-35
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

+33
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
1616
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
1717

18+
import java.security.AccessController;
19+
import java.security.PrivilegedActionException;
20+
import java.security.PrivilegedExceptionAction;
1821
import java.util.HashMap;
1922
import java.util.List;
2023
import java.util.Map;
@@ -23,6 +26,8 @@
2326
import java.util.regex.Pattern;
2427

2528
import org.apache.commons.text.StringSubstitutor;
29+
import org.opensearch.ml.common.output.model.ModelTensor;
30+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2631
import org.opensearch.ml.common.spi.tools.Tool;
2732

2833
public class AgentUtils {
@@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) {
152157
throw new IllegalArgumentException("Model output is invalid");
153158
}
154159
}
160+
161+
public static String outputToOutputString(Object output) throws PrivilegedActionException {
162+
String outputString;
163+
if (output instanceof ModelTensorOutput) {
164+
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
165+
if (outputModel.getDataAsMap() != null) {
166+
outputString = AccessController
167+
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
168+
} else {
169+
outputString = outputModel.getResult();
170+
}
171+
} else if (output instanceof String) {
172+
outputString = (String) output;
173+
} else {
174+
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
175+
}
176+
return outputString;
177+
}
178+
179+
public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
180+
Object actionInput = retMap.get("action_input");
181+
if (actionInput instanceof Map) {
182+
return gson.toJson(actionInput);
183+
} else {
184+
return String.valueOf(actionInput);
185+
}
186+
187+
}
155188
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+1-34
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
99
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
1010
import static org.opensearch.ml.common.utils.StringUtils.gson;
11-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
11+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.*;
1212

13-
import java.security.AccessController;
14-
import java.security.PrivilegedActionException;
15-
import java.security.PrivilegedExceptionAction;
1613
import java.util.ArrayList;
1714
import java.util.Collections;
1815
import java.util.HashMap;
@@ -657,34 +654,4 @@ private void runReAct(
657654
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
658655
}
659656

660-
private String outputToOutputString(Object output) throws PrivilegedActionException {
661-
String outputString;
662-
if (output instanceof ModelTensorOutput) {
663-
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
664-
if (outputModel.getDataAsMap() != null) {
665-
outputString = AccessController
666-
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
667-
} else {
668-
outputString = outputModel.getResult();
669-
}
670-
} else if (output instanceof String) {
671-
outputString = (String) output;
672-
} else {
673-
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
674-
}
675-
return outputString;
676-
}
677-
678-
private String parseInputFromLLMReturn(Map<String, ?> retMap){
679-
Object actionInput = retMap.get("action_input");
680-
if (actionInput instanceof Map)
681-
{
682-
return gson.toJson(actionInput);
683-
}
684-
else {
685-
return String.valueOf(actionInput);
686-
}
687-
688-
}
689-
690657
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,10 @@ private void extractFromChatParameters(Map<String, String> parameters) {
144144
try {
145145
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
146146
parameters.putAll(chatParameters);
147+
} catch (Exception exception) {
148+
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
147149
} finally {
148-
return ;
150+
return;
149151
}
150152
}
151153
}

0 commit comments

Comments
 (0)