Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argument pass #1941

Merged
merged 14 commits into from
Feb 2, 2024
Original file line number Diff line number Diff line change
@@ -289,7 +289,6 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid JSON in payload");
}
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
@@ -325,7 +326,7 @@ private void runReAct(
}
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = String.valueOf(dataAsMap.get("action_input"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");
if (!dataAsMap.containsKey("thought")) {
String response = (String) dataAsMap.get("response");
@@ -336,7 +337,7 @@ private void runReAct(
Map map = gson.fromJson(jsonBlock, Map.class);
thought = String.valueOf(map.get("thought"));
action = String.valueOf(map.get("action"));
actionInput = String.valueOf(map.get("action_input"));
actionInput = parseInputFromLLMReturn(map);
finalAnswer = (String) map.get("final_answer");
} else {
finalAnswer = response;
@@ -524,9 +525,7 @@ private void runReAct(
} else {
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
String outputString = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
String outputString = outputToOutputString(output);

String toolOutputKey = String.format("%s.output", toolSpec.getType());
if (additionalInfo.get(toolOutputKey) != null) {
@@ -546,7 +545,13 @@ private void runReAct(
.singletonList(
ModelTensor
.builder()
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output))
.dataAsMap(
ImmutableMap
.of(
"response",
lastThought.get() + "\nObservation: " + outputToOutputString(output)
)
)
.build()
)
)
@@ -555,7 +560,7 @@ private void runReAct(

String toolResponse = tmpParameters.get("prompt.tool_response");
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(
ImmutableMap.of("observation", output),
ImmutableMap.of("observation", outputToOutputString(output)),
"${parameters.",
"}"
);
@@ -567,7 +572,7 @@ private void runReAct(
.conversationIndexMessageBuilder()
.type("ReAct")
.question(lastActionInput.get())
.response((String) output)
.response(outputToOutputString(output))
.finalAnswer(false)
.sessionId(sessionId)
.build();
@@ -582,7 +587,7 @@ private void runReAct(
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());

sessionMsgAnswerBuilder.append("\nObservation: ").append(output);
sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
cotModelTensors
.add(
ModelTensors
@@ -652,4 +657,34 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

private String outputToOutputString(Object output) throws PrivilegedActionException {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these utility methods to handle format transformation? If so, is it possible to move them to a utility class under utils that can be reused somewhere else?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already do it.

String outputString;
if (output instanceof ModelTensorOutput) {
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
if (outputModel.getDataAsMap() != null) {
outputString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
} else {
outputString = outputModel.getResult();
}
} else if (output instanceof String) {
outputString = (String) output;
} else {
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
}
return outputString;
}

private String parseInputFromLLMReturn(Map<String, ?> retMap){
Object actionInput = retMap.get("action_input");
if (actionInput instanceof Map)
{
return gson.toJson(actionInput);
}
else {
return String.valueOf(actionInput);
}

}

}
Original file line number Diff line number Diff line change
@@ -5,6 +5,8 @@

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Map;

import org.opensearch.action.ActionRequest;
@@ -51,6 +53,7 @@ public AgentTool(Client client, String agentId) {

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
extractFromChatParameters(parameters);
AgentMLInput agentMLInput = AgentMLInput
.AgentMLInputBuilder()
.agentId(agentId)
@@ -135,4 +138,15 @@ public String getDefaultVersion() {
return null;
}
}

private void extractFromChatParameters(Map<String, String> parameters) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for chat only?

if (parameters.containsKey("input")) {
try {
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
parameters.putAll(chatParameters);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do this ? Looks it's possible to override the original params in parameters if they have same key

} finally {
return ;
}
}
}
}