Skip to content

Commit c53d586

Browse files
authored
send agent execution response after saving memory (opensearch-project#1999)
* send agent execution response after saving memory Signed-off-by: Jing Zhang <jngz@amazon.com> * spotless Signed-off-by: Jing Zhang <jngz@amazon.com> --------- Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent 2c63e8d commit c53d586

File tree

2 files changed

+142
-58
lines changed

2 files changed

+142
-58
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+95-58
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;
1414

1515
import java.util.ArrayList;
16+
import java.util.Collection;
1617
import java.util.Collections;
1718
import java.util.HashMap;
1819
import java.util.List;
@@ -29,11 +30,13 @@
2930
import org.apache.commons.text.StringSubstitutor;
3031
import org.opensearch.action.ActionRequest;
3132
import org.opensearch.action.StepListener;
33+
import org.opensearch.action.support.GroupedActionListener;
3234
import org.opensearch.action.update.UpdateResponse;
3335
import org.opensearch.client.Client;
3436
import org.opensearch.cluster.service.ClusterService;
3537
import org.opensearch.common.settings.Settings;
3638
import org.opensearch.core.action.ActionListener;
39+
import org.opensearch.core.action.ActionResponse;
3740
import org.opensearch.core.common.Strings;
3841
import org.opensearch.core.xcontent.NamedXContentRegistry;
3942
import org.opensearch.ml.common.FunctionName;
@@ -55,6 +58,7 @@
5558
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
5659
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
5760
import org.opensearch.ml.engine.tools.MLModelTool;
61+
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
5862
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
5963
import org.opensearch.ml.repackage.com.google.common.collect.Lists;
6064

@@ -376,6 +380,64 @@ private void runReAct(
376380
}
377381
if (finalAnswer != null) {
378382
finalAnswer = finalAnswer.trim();
383+
String finalAnswer2 = finalAnswer;
384+
// Composite execution response and reply.
385+
final ActionListener<Boolean> executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> {
386+
cotModelTensors
387+
.add(
388+
ModelTensors
389+
.builder()
390+
.mlModelTensors(
391+
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build())
392+
)
393+
.build()
394+
);
395+
396+
List<ModelTensors> finalModelTensors = new ArrayList<>();
397+
finalModelTensors
398+
.add(
399+
ModelTensors
400+
.builder()
401+
.mlModelTensors(
402+
List
403+
.of(
404+
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
405+
ModelTensor
406+
.builder()
407+
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
408+
.result(parentInteractionId)
409+
.build()
410+
)
411+
)
412+
.build()
413+
);
414+
finalModelTensors
415+
.add(
416+
ModelTensors
417+
.builder()
418+
.mlModelTensors(
419+
Collections
420+
.singletonList(
421+
ModelTensor
422+
.builder()
423+
.name("response")
424+
.dataAsMap(
425+
ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)
426+
)
427+
.build()
428+
)
429+
)
430+
.build()
431+
);
432+
getFinalAnswer.set(true);
433+
if (verbose) {
434+
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
435+
} else {
436+
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
437+
}
438+
}, listener::onFailure));
439+
// Sending execution response by internalListener is after the trace and answer saving.
440+
final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(2, executionListener);
379441
if (conversationIndexMemory != null) {
380442
String finalAnswer1 = finalAnswer;
381443
// Create final trace message.
@@ -387,71 +449,23 @@ private void runReAct(
387449
.finalAnswer(true)
388450
.sessionId(sessionId)
389451
.build();
390-
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null);
391-
// Update root interaction.
452+
// Save last trace and update final answer in parallel.
453+
conversationIndexMemory
454+
.save(
455+
msgTemp,
456+
parentInteractionId,
457+
traceNumber.addAndGet(1),
458+
null,
459+
ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure)
460+
);
392461
conversationIndexMemory
393462
.getMemoryManager()
394463
.updateInteraction(
395464
parentInteractionId,
396465
ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo),
397-
ActionListener.<UpdateResponse>wrap(updateResponse -> {
398-
log.info("Updated final answer into interaction id: {}", parentInteractionId);
399-
log.info("Final answer: {}", finalAnswer1);
400-
}, e -> log.error("Failed to update root interaction", e))
466+
ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure)
401467
);
402468
}
403-
cotModelTensors
404-
.add(
405-
ModelTensors
406-
.builder()
407-
.mlModelTensors(
408-
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer).build())
409-
)
410-
.build()
411-
);
412-
413-
List<ModelTensors> finalModelTensors = new ArrayList<>();
414-
finalModelTensors
415-
.add(
416-
ModelTensors
417-
.builder()
418-
.mlModelTensors(
419-
List
420-
.of(
421-
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
422-
ModelTensor
423-
.builder()
424-
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
425-
.result(parentInteractionId)
426-
.build()
427-
)
428-
)
429-
.build()
430-
);
431-
finalModelTensors
432-
.add(
433-
ModelTensors
434-
.builder()
435-
.mlModelTensors(
436-
Collections
437-
.singletonList(
438-
ModelTensor
439-
.builder()
440-
.name("response")
441-
.dataAsMap(
442-
ImmutableMap.of("response", finalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo)
443-
)
444-
.build()
445-
)
446-
)
447-
.build()
448-
);
449-
getFinalAnswer.set(true);
450-
if (verbose) {
451-
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
452-
} else {
453-
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
454-
}
455469
return;
456470
}
457471

@@ -679,4 +693,27 @@ private void runReAct(
679693
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
680694
}
681695

696+
private GroupedActionListener<ActionResponse> createGroupedListener(final int size, final ActionListener<Boolean> listener) {
697+
return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
698+
@Override
699+
public void onResponse(final Collection<ActionResponse> responses) {
700+
CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class);
701+
log.info("saved message with interaction id: {}", createInteractionResponse.getId());
702+
UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class);
703+
log.info("Updated final answer into interaction id: {}", updateResponse.getId());
704+
705+
listener.onResponse(true);
706+
}
707+
708+
@Override
709+
public void onFailure(final Exception e) {
710+
listener.onFailure(e);
711+
}
712+
}, size);
713+
}
714+
715+
@SuppressWarnings("unchecked")
716+
private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) {
717+
return (A) responses.stream().filter(c::isInstance).findFirst().get();
718+
}
682719
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+47
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.opensearch.action.ActionRequest;
3636
import org.opensearch.action.ActionType;
3737
import org.opensearch.action.StepListener;
38+
import org.opensearch.action.update.UpdateResponse;
3839
import org.opensearch.client.Client;
3940
import org.opensearch.cluster.service.ClusterService;
4041
import org.opensearch.common.settings.Settings;
@@ -53,6 +54,7 @@
5354
import org.opensearch.ml.common.transport.MLTaskResponse;
5455
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
5556
import org.opensearch.ml.engine.memory.MLMemoryManager;
57+
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
5658
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
5759

5860
public class MLChatAgentRunnerTest {
@@ -97,13 +99,21 @@ public class MLChatAgentRunnerTest {
9799
private ConversationIndexMemory conversationIndexMemory;
98100
@Mock
99101
private MLMemoryManager mlMemoryManager;
102+
@Mock
103+
private CreateInteractionResponse createInteractionResponse;
104+
@Mock
105+
private UpdateResponse updateResponse;
100106

101107
@Mock
102108
private ConversationIndexMemory.Factory memoryFactory;
103109
@Captor
104110
private ArgumentCaptor<ActionListener<ConversationIndexMemory>> memoryFactoryCapture;
105111
@Captor
106112
private ArgumentCaptor<ActionListener<List<Interaction>>> memoryInteractionCapture;
113+
@Captor
114+
private ArgumentCaptor<ActionListener<CreateInteractionResponse>> conversationIndexMemoryCapture;
115+
@Captor
116+
private ArgumentCaptor<ActionListener<UpdateResponse>> mlMemoryManagerCapture;
107117

108118
@Before
109119
@SuppressWarnings("unchecked")
@@ -127,6 +137,18 @@ public void setup() {
127137
listener.onResponse(conversationIndexMemory);
128138
return null;
129139
}).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture());
140+
when(createInteractionResponse.getId()).thenReturn("create_interaction_id");
141+
doAnswer(invocation -> {
142+
ActionListener<CreateInteractionResponse> listener = invocation.getArgument(4);
143+
listener.onResponse(createInteractionResponse);
144+
return null;
145+
}).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture());
146+
when(updateResponse.getId()).thenReturn("update_interaction_id");
147+
doAnswer(invocation -> {
148+
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
149+
listener.onResponse(updateResponse);
150+
return null;
151+
}).when(mlMemoryManager).updateInteraction(any(), any(), mlMemoryManagerCapture.capture());
130152

131153
mlChatAgentRunner = new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap);
132154
when(firstToolFactory.create(Mockito.anyMap())).thenReturn(firstTool);
@@ -668,6 +690,31 @@ public void testToolParameters() {
668690
assertNotNull(modelTensorOutput);
669691
}
670692

693+
@Test
694+
public void testSaveLastTraceFailure() {
695+
// Mock tool validation to return true.
696+
when(firstTool.validate(any())).thenReturn(true);
697+
698+
// Create an MLAgent with tools
699+
MLAgent mlAgent = createMLAgentWithTools();
700+
701+
// Create parameters for the agent
702+
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");
703+
704+
doAnswer(invocation -> {
705+
ActionListener<CreateInteractionResponse> listener = invocation.getArgument(4);
706+
listener.onFailure(new IllegalArgumentException());
707+
return null;
708+
}).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture());
709+
// Run the MLChatAgentRunner
710+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
711+
712+
// Verify that the tool's run method was called
713+
verify(firstTool).run(any(), any());
714+
715+
Mockito.verify(agentActionListener).onFailure(any(IllegalArgumentException.class));
716+
}
717+
671718
// Helper methods to create MLAgent and parameters
672719
private MLAgent createMLAgentWithTools() {
673720
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();

0 commit comments

Comments
 (0)