Skip to content

Commit 1d9d3b7

Browse files
committed
option2FixBug
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 97f1199 commit 1d9d3b7

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
289289
payload = fillNullParameters(parameters, payload);
290290
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
291291
payload = substitutor.replace(payload);
292-
292+
log.info("to LLM");
293+
log.info(payload);
293294
if (!isJson(payload)) {
294295
throw new IllegalArgumentException("Invalid JSON in payload");
295296
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+41-10
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
99
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
1010
import static org.opensearch.ml.common.utils.StringUtils.gson;
11+
import static org.opensearch.ml.common.utils.StringUtils.toUTF8;
1112
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
1213

1314
import java.security.AccessController;
15+
import java.security.PrivilegedActionException;
1416
import java.security.PrivilegedExceptionAction;
1517
import java.util.ArrayList;
1618
import java.util.Collections;
@@ -325,7 +327,7 @@ private void runReAct(
325327
}
326328
String thought = String.valueOf(dataAsMap.get("thought"));
327329
String action = String.valueOf(dataAsMap.get("action"));
328-
String actionInput = String.valueOf(dataAsMap.get("action_input"));
330+
String actionInput = gson.toJson(dataAsMap.get("action_input"));
329331
String finalAnswer = (String) dataAsMap.get("final_answer");
330332
if (!dataAsMap.containsKey("thought")) {
331333
String response = (String) dataAsMap.get("response");
@@ -336,7 +338,7 @@ private void runReAct(
336338
Map map = gson.fromJson(jsonBlock, Map.class);
337339
thought = String.valueOf(map.get("thought"));
338340
action = String.valueOf(map.get("action"));
339-
actionInput = String.valueOf(map.get("action_input"));
341+
actionInput = gson.toJson(map.get("action_input"));
340342
finalAnswer = (String) map.get("final_answer");
341343
} else {
342344
finalAnswer = response;
@@ -524,9 +526,7 @@ private void runReAct(
524526
} else {
525527
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
526528
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
527-
String outputString = output instanceof String
528-
? (String) output
529-
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
529+
String outputString = outputToOutputString(output);
530530

531531
String toolOutputKey = String.format("%s.output", toolSpec.getType());
532532
if (additionalInfo.get(toolOutputKey) != null) {
@@ -546,7 +546,7 @@ private void runReAct(
546546
.singletonList(
547547
ModelTensor
548548
.builder()
549-
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output))
549+
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + outputToOutputString(output)))
550550
.build()
551551
)
552552
)
@@ -555,7 +555,7 @@ private void runReAct(
555555

556556
String toolResponse = tmpParameters.get("prompt.tool_response");
557557
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(
558-
ImmutableMap.of("observation", output),
558+
ImmutableMap.of("observation", outputToOutputString(output)),
559559
"${parameters.",
560560
"}"
561561
);
@@ -567,7 +567,7 @@ private void runReAct(
567567
.conversationIndexMessageBuilder()
568568
.type("ReAct")
569569
.question(lastActionInput.get())
570-
.response((String) output)
570+
.response(outputToOutputString(output))
571571
.finalAnswer(false)
572572
.sessionId(sessionId)
573573
.build();
@@ -582,7 +582,7 @@ private void runReAct(
582582
newPrompt.set(substitutor.replace(finalPrompt));
583583
tmpParameters.put(PROMPT, newPrompt.get());
584584

585-
sessionMsgAnswerBuilder.append("\nObservation: ").append(output);
585+
sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
586586
cotModelTensors
587587
.add(
588588
ModelTensors
@@ -629,7 +629,15 @@ private void runReAct(
629629
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
630630
}
631631
} else {
632-
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
632+
ActionRequest request2 = new MLPredictionTaskRequest(
633+
llm.getModelId(),
634+
RemoteInferenceMLInput
635+
.builder()
636+
.algorithm(FunctionName.REMOTE)
637+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
638+
.build()
639+
);
640+
client.execute(MLPredictionTaskAction.INSTANCE, request2, (ActionListener<MLTaskResponse>) nextStepListener);
633641
}
634642
}
635643
}, e -> {
@@ -652,4 +660,27 @@ private void runReAct(
652660
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
653661
}
654662

663+
private String outputToOutputString(Object output) throws PrivilegedActionException {
664+
String outputString;
665+
if (output instanceof ModelTensorOutput)
666+
{
667+
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
668+
if (outputModel.getDataAsMap() != null)
669+
{
670+
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
671+
}
672+
else {
673+
outputString = outputModel.getResult();
674+
}
675+
}
676+
else if (output instanceof String)
677+
{
678+
outputString = (String) output;
679+
}
680+
else {
681+
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
682+
}
683+
return outputString;
684+
}
685+
655686
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

+2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
117117
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
118118
}
119119
String modelResponse = responseBuilder.toString();
120+
log.info("from LLM");
121+
log.info(modelResponse);
120122
if (statusCode < 200 || statusCode >= 300) {
121123
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
122124
}

0 commit comments

Comments
 (0)