8
8
import static org .opensearch .ml .common .conversation .ActionConstants .ADDITIONAL_INFO_FIELD ;
9
9
import static org .opensearch .ml .common .conversation .ActionConstants .AI_RESPONSE_FIELD ;
10
10
import static org .opensearch .ml .common .utils .StringUtils .gson ;
11
+ import static org .opensearch .ml .common .utils .StringUtils .toUTF8 ;
11
12
import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .extractModelResponseJson ;
12
13
13
14
import java .security .AccessController ;
15
+ import java .security .PrivilegedActionException ;
14
16
import java .security .PrivilegedExceptionAction ;
15
17
import java .util .ArrayList ;
16
18
import java .util .Collections ;
@@ -325,7 +327,7 @@ private void runReAct(
325
327
}
326
328
String thought = String .valueOf (dataAsMap .get ("thought" ));
327
329
String action = String .valueOf (dataAsMap .get ("action" ));
328
- String actionInput = String . valueOf (dataAsMap .get ("action_input" ));
330
+ String actionInput = gson . toJson (dataAsMap .get ("action_input" ));
329
331
String finalAnswer = (String ) dataAsMap .get ("final_answer" );
330
332
if (!dataAsMap .containsKey ("thought" )) {
331
333
String response = (String ) dataAsMap .get ("response" );
@@ -336,7 +338,7 @@ private void runReAct(
336
338
Map map = gson .fromJson (jsonBlock , Map .class );
337
339
thought = String .valueOf (map .get ("thought" ));
338
340
action = String .valueOf (map .get ("action" ));
339
- actionInput = String . valueOf (map .get ("action_input" ));
341
+ actionInput = gson . toJson (map .get ("action_input" ));
340
342
finalAnswer = (String ) map .get ("final_answer" );
341
343
} else {
342
344
finalAnswer = response ;
@@ -524,9 +526,7 @@ private void runReAct(
524
526
} else {
525
527
MLToolSpec toolSpec = toolSpecMap .get (lastAction .get ());
526
528
if (toolSpec != null && toolSpec .isIncludeOutputInAgentResponse ()) {
527
- String outputString = output instanceof String
528
- ? (String ) output
529
- : AccessController .doPrivileged ((PrivilegedExceptionAction <String >) () -> gson .toJson (output ));
529
+ String outputString = outputToOutputString (output );
530
530
531
531
String toolOutputKey = String .format ("%s.output" , toolSpec .getType ());
532
532
if (additionalInfo .get (toolOutputKey ) != null ) {
@@ -546,7 +546,7 @@ private void runReAct(
546
546
.singletonList (
547
547
ModelTensor
548
548
.builder ()
549
- .dataAsMap (ImmutableMap .of ("response" , lastThought .get () + "\n Observation: " + output ))
549
+ .dataAsMap (ImmutableMap .of ("response" , lastThought .get () + "\n Observation: " + outputToOutputString ( output ) ))
550
550
.build ()
551
551
)
552
552
)
@@ -555,7 +555,7 @@ private void runReAct(
555
555
556
556
String toolResponse = tmpParameters .get ("prompt.tool_response" );
557
557
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor (
558
- ImmutableMap .of ("observation" , output ),
558
+ ImmutableMap .of ("observation" , outputToOutputString ( output ) ),
559
559
"${parameters." ,
560
560
"}"
561
561
);
@@ -567,7 +567,7 @@ private void runReAct(
567
567
.conversationIndexMessageBuilder ()
568
568
.type ("ReAct" )
569
569
.question (lastActionInput .get ())
570
- .response (( String ) output )
570
+ .response (outputToOutputString ( output ) )
571
571
.finalAnswer (false )
572
572
.sessionId (sessionId )
573
573
.build ();
@@ -582,7 +582,7 @@ private void runReAct(
582
582
newPrompt .set (substitutor .replace (finalPrompt ));
583
583
tmpParameters .put (PROMPT , newPrompt .get ());
584
584
585
- sessionMsgAnswerBuilder .append ("\n Observation: " ).append (output );
585
+ sessionMsgAnswerBuilder .append ("\n Observation: " ).append (outputToOutputString ( output ) );
586
586
cotModelTensors
587
587
.add (
588
588
ModelTensors
@@ -629,7 +629,15 @@ private void runReAct(
629
629
listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (finalModelTensors ).build ());
630
630
}
631
631
} else {
632
- client .execute (MLPredictionTaskAction .INSTANCE , request , (ActionListener <MLTaskResponse >) nextStepListener );
632
+ ActionRequest request2 = new MLPredictionTaskRequest (
633
+ llm .getModelId (),
634
+ RemoteInferenceMLInput
635
+ .builder ()
636
+ .algorithm (FunctionName .REMOTE )
637
+ .inputDataset (RemoteInferenceInputDataSet .builder ().parameters (tmpParameters ).build ())
638
+ .build ()
639
+ );
640
+ client .execute (MLPredictionTaskAction .INSTANCE , request2 , (ActionListener <MLTaskResponse >) nextStepListener );
633
641
}
634
642
}
635
643
}, e -> {
@@ -652,4 +660,27 @@ private void runReAct(
652
660
client .execute (MLPredictionTaskAction .INSTANCE , request , firstListener );
653
661
}
654
662
663
+ private String outputToOutputString (Object output ) throws PrivilegedActionException {
664
+ String outputString ;
665
+ if (output instanceof ModelTensorOutput )
666
+ {
667
+ ModelTensor outputModel = ((ModelTensorOutput ) output ).getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 );
668
+ if (outputModel .getDataAsMap () != null )
669
+ {
670
+ outputString = AccessController .doPrivileged ((PrivilegedExceptionAction <String >) () -> gson .toJson (outputModel .getDataAsMap ()));
671
+ }
672
+ else {
673
+ outputString = outputModel .getResult ();
674
+ }
675
+ }
676
+ else if (output instanceof String )
677
+ {
678
+ outputString = (String ) output ;
679
+ }
680
+ else {
681
+ outputString = AccessController .doPrivileged ((PrivilegedExceptionAction <String >) () -> gson .toJson (output ));
682
+ }
683
+ return outputString ;
684
+ }
685
+
655
686
}
0 commit comments