13
13
import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .parseInputFromLLMReturn ;
14
14
15
15
import java .util .ArrayList ;
16
+ import java .util .Collection ;
16
17
import java .util .Collections ;
17
18
import java .util .HashMap ;
18
19
import java .util .List ;
29
30
import org .apache .commons .text .StringSubstitutor ;
30
31
import org .opensearch .action .ActionRequest ;
31
32
import org .opensearch .action .StepListener ;
33
+ import org .opensearch .action .support .GroupedActionListener ;
32
34
import org .opensearch .action .update .UpdateResponse ;
33
35
import org .opensearch .client .Client ;
34
36
import org .opensearch .cluster .service .ClusterService ;
35
37
import org .opensearch .common .settings .Settings ;
36
38
import org .opensearch .core .action .ActionListener ;
39
+ import org .opensearch .core .action .ActionResponse ;
37
40
import org .opensearch .core .common .Strings ;
38
41
import org .opensearch .core .xcontent .NamedXContentRegistry ;
39
42
import org .opensearch .ml .common .FunctionName ;
55
58
import org .opensearch .ml .engine .memory .ConversationIndexMemory ;
56
59
import org .opensearch .ml .engine .memory .ConversationIndexMessage ;
57
60
import org .opensearch .ml .engine .tools .MLModelTool ;
61
+ import org .opensearch .ml .memory .action .conversation .CreateInteractionResponse ;
58
62
import org .opensearch .ml .repackage .com .google .common .collect .ImmutableMap ;
59
63
import org .opensearch .ml .repackage .com .google .common .collect .Lists ;
60
64
@@ -376,6 +380,64 @@ private void runReAct(
376
380
}
377
381
if (finalAnswer != null ) {
378
382
finalAnswer = finalAnswer .trim ();
383
+ String finalAnswer2 = finalAnswer ;
384
+ // Composite execution response and reply.
385
+ final ActionListener <Boolean > executionListener = ActionListener .notifyOnce (ActionListener .wrap (r -> {
386
+ cotModelTensors
387
+ .add (
388
+ ModelTensors
389
+ .builder ()
390
+ .mlModelTensors (
391
+ Collections .singletonList (ModelTensor .builder ().name ("response" ).result (finalAnswer2 ).build ())
392
+ )
393
+ .build ()
394
+ );
395
+
396
+ List <ModelTensors > finalModelTensors = new ArrayList <>();
397
+ finalModelTensors
398
+ .add (
399
+ ModelTensors
400
+ .builder ()
401
+ .mlModelTensors (
402
+ List
403
+ .of (
404
+ ModelTensor .builder ().name (MLAgentExecutor .MEMORY_ID ).result (sessionId ).build (),
405
+ ModelTensor
406
+ .builder ()
407
+ .name (MLAgentExecutor .PARENT_INTERACTION_ID )
408
+ .result (parentInteractionId )
409
+ .build ()
410
+ )
411
+ )
412
+ .build ()
413
+ );
414
+ finalModelTensors
415
+ .add (
416
+ ModelTensors
417
+ .builder ()
418
+ .mlModelTensors (
419
+ Collections
420
+ .singletonList (
421
+ ModelTensor
422
+ .builder ()
423
+ .name ("response" )
424
+ .dataAsMap (
425
+ ImmutableMap .of ("response" , finalAnswer2 , ADDITIONAL_INFO_FIELD , additionalInfo )
426
+ )
427
+ .build ()
428
+ )
429
+ )
430
+ .build ()
431
+ );
432
+ getFinalAnswer .set (true );
433
+ if (verbose ) {
434
+ listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (cotModelTensors ).build ());
435
+ } else {
436
+ listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (finalModelTensors ).build ());
437
+ }
438
+ }, listener ::onFailure ));
439
+ // Sending execution response by internalListener is after the trace and answer saving.
440
+ final GroupedActionListener <ActionResponse > groupedListener = createGroupedListener (2 , executionListener );
379
441
if (conversationIndexMemory != null ) {
380
442
String finalAnswer1 = finalAnswer ;
381
443
// Create final trace message.
@@ -387,71 +449,23 @@ private void runReAct(
387
449
.finalAnswer (true )
388
450
.sessionId (sessionId )
389
451
.build ();
390
- conversationIndexMemory .save (msgTemp , parentInteractionId , traceNumber .addAndGet (1 ), null );
391
- // Update root interaction.
452
+ // Save last trace and update final answer in parallel.
453
+ conversationIndexMemory
454
+ .save (
455
+ msgTemp ,
456
+ parentInteractionId ,
457
+ traceNumber .addAndGet (1 ),
458
+ null ,
459
+ ActionListener .<CreateInteractionResponse >wrap (groupedListener ::onResponse , groupedListener ::onFailure )
460
+ );
392
461
conversationIndexMemory
393
462
.getMemoryManager ()
394
463
.updateInteraction (
395
464
parentInteractionId ,
396
465
ImmutableMap .of (AI_RESPONSE_FIELD , finalAnswer1 , ADDITIONAL_INFO_FIELD , additionalInfo ),
397
- ActionListener .<UpdateResponse >wrap (updateResponse -> {
398
- log .info ("Updated final answer into interaction id: {}" , parentInteractionId );
399
- log .info ("Final answer: {}" , finalAnswer1 );
400
- }, e -> log .error ("Failed to update root interaction" , e ))
466
+ ActionListener .<UpdateResponse >wrap (groupedListener ::onResponse , groupedListener ::onFailure )
401
467
);
402
468
}
403
- cotModelTensors
404
- .add (
405
- ModelTensors
406
- .builder ()
407
- .mlModelTensors (
408
- Collections .singletonList (ModelTensor .builder ().name ("response" ).result (finalAnswer ).build ())
409
- )
410
- .build ()
411
- );
412
-
413
- List <ModelTensors > finalModelTensors = new ArrayList <>();
414
- finalModelTensors
415
- .add (
416
- ModelTensors
417
- .builder ()
418
- .mlModelTensors (
419
- List
420
- .of (
421
- ModelTensor .builder ().name (MLAgentExecutor .MEMORY_ID ).result (sessionId ).build (),
422
- ModelTensor
423
- .builder ()
424
- .name (MLAgentExecutor .PARENT_INTERACTION_ID )
425
- .result (parentInteractionId )
426
- .build ()
427
- )
428
- )
429
- .build ()
430
- );
431
- finalModelTensors
432
- .add (
433
- ModelTensors
434
- .builder ()
435
- .mlModelTensors (
436
- Collections
437
- .singletonList (
438
- ModelTensor
439
- .builder ()
440
- .name ("response" )
441
- .dataAsMap (
442
- ImmutableMap .of ("response" , finalAnswer , ADDITIONAL_INFO_FIELD , additionalInfo )
443
- )
444
- .build ()
445
- )
446
- )
447
- .build ()
448
- );
449
- getFinalAnswer .set (true );
450
- if (verbose ) {
451
- listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (cotModelTensors ).build ());
452
- } else {
453
- listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (finalModelTensors ).build ());
454
- }
455
469
return ;
456
470
}
457
471
@@ -679,4 +693,27 @@ private void runReAct(
679
693
client .execute (MLPredictionTaskAction .INSTANCE , request , firstListener );
680
694
}
681
695
696
+ private GroupedActionListener <ActionResponse > createGroupedListener (final int size , final ActionListener <Boolean > listener ) {
697
+ return new GroupedActionListener <>(new ActionListener <Collection <ActionResponse >>() {
698
+ @ Override
699
+ public void onResponse (final Collection <ActionResponse > responses ) {
700
+ CreateInteractionResponse createInteractionResponse = extractResponse (responses , CreateInteractionResponse .class );
701
+ log .info ("saved message with interaction id: {}" , createInteractionResponse .getId ());
702
+ UpdateResponse updateResponse = extractResponse (responses , UpdateResponse .class );
703
+ log .info ("Updated final answer into interaction id: {}" , updateResponse .getId ());
704
+
705
+ listener .onResponse (true );
706
+ }
707
+
708
+ @ Override
709
+ public void onFailure (final Exception e ) {
710
+ listener .onFailure (e );
711
+ }
712
+ }, size );
713
+ }
714
+
715
+ @ SuppressWarnings ("unchecked" )
716
+ private static <A extends ActionResponse > A extractResponse (final Collection <? extends ActionResponse > responses , Class <A > c ) {
717
+ return (A ) responses .stream ().filter (c ::isInstance ).findFirst ().get ();
718
+ }
682
719
}
0 commit comments