Skip to content

Commit b84b130

Browse files
authored
add memory id and interation id for non-verbose (opensearch-project#2004)
Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent a62ecc1 commit b84b130

File tree

2 files changed

+77
-3
lines changed

2 files changed

+77
-3
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+34
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,23 @@ private void runReAct(
411411
);
412412

413413
List<ModelTensors> finalModelTensors = new ArrayList<>();
414+
finalModelTensors
415+
.add(
416+
ModelTensors
417+
.builder()
418+
.mlModelTensors(
419+
List
420+
.of(
421+
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
422+
ModelTensor
423+
.builder()
424+
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
425+
.result(parentInteractionId)
426+
.build()
427+
)
428+
)
429+
.build()
430+
);
414431
finalModelTensors
415432
.add(
416433
ModelTensors
@@ -603,6 +620,23 @@ private void runReAct(
603620
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
604621
} else {
605622
List<ModelTensors> finalModelTensors = new ArrayList<>();
623+
finalModelTensors
624+
.add(
625+
ModelTensors
626+
.builder()
627+
.mlModelTensors(
628+
List
629+
.of(
630+
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
631+
ModelTensor
632+
.builder()
633+
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
634+
.result(parentInteractionId)
635+
.build()
636+
)
637+
)
638+
.build()
639+
);
606640
finalModelTensors
607641
.add(
608642
ModelTensors

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+43-3
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,46 @@ public void testParsingJsonBlockFromResponse3() {
268268
assertEquals("parsed final answer", modelTensor2.getResult());
269269
}
270270

271+
@Test
272+
public void testParsingJsonBlockFromResponse4() {
273+
// Prepare the response with JSON block
274+
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
275+
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
276+
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
277+
278+
// Mock LLM response to not contain "thought" but contain "response" with JSON block
279+
Map<String, String> llmResponse = new HashMap<>();
280+
llmResponse.put("response", responseWithJsonBlock);
281+
doAnswer(getLLMAnswer(llmResponse))
282+
.when(client)
283+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
284+
285+
// Create an MLAgent and run the MLChatAgentRunner
286+
MLAgent mlAgent = createMLAgentWithTools();
287+
Map<String, String> params = new HashMap<>();
288+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
289+
params.put("verbose", "false");
290+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
291+
292+
// Capture the response passed to the listener
293+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
294+
verify(agentActionListener).onResponse(responseCaptor.capture());
295+
296+
// Extract the captured response
297+
Object capturedResponse = responseCaptor.getValue();
298+
assertTrue(capturedResponse instanceof ModelTensorOutput);
299+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
300+
301+
ModelTensor memoryIdModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
302+
ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
303+
304+
// Verify that the parsed values from JSON block are correctly set
305+
assertEquals("memory_id", memoryIdModelTensor.getName());
306+
assertEquals("conversation_id", memoryIdModelTensor.getResult());
307+
assertEquals("parent_interaction_id", parentInteractionModelTensor.getName());
308+
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
309+
}
310+
271311
@Test
272312
public void testRunWithIncludeOutputNotSet() {
273313
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -293,7 +333,7 @@ public void testRunWithIncludeOutputNotSet() {
293333
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
294334
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
295335
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
296-
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
336+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
297337
assertEquals(1, agentOutput.size());
298338
// Respond with last tool output
299339
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
@@ -322,7 +362,7 @@ public void testRunWithIncludeOutputMLModel() {
322362
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
323363
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
324364
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
325-
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
365+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
326366
assertEquals(1, agentOutput.size());
327367
// Respond with last tool output
328368
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
@@ -356,7 +396,7 @@ public void testRunWithIncludeOutputSet() {
356396
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
357397
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
358398
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
359-
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
399+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
360400
assertEquals(1, agentOutput.size());
361401
// Respond with last tool output
362402
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));

0 commit comments

Comments
 (0)