Skip to content

Commit c7439f4

Browse files
committed
add more uts
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent a9a20ed commit c7439f4

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+40
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,46 @@ public void testParsingJsonBlockFromResponse2() {
228228
assertEquals("parsed final answer", modelTensor2.getResult());
229229
}
230230

231+
@Test
232+
public void testParsingJsonBlockFromResponse3() {
233+
// Prepare the response with JSON block
234+
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
235+
+ "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}";
236+
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
237+
238+
// Mock LLM response to not contain "thought" but contain "response" with JSON block
239+
Map<String, String> llmResponse = new HashMap<>();
240+
llmResponse.put("response", responseWithJsonBlock);
241+
doAnswer(getLLMAnswer(llmResponse))
242+
.when(client)
243+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
244+
245+
// Create an MLAgent and run the MLChatAgentRunner
246+
MLAgent mlAgent = createMLAgentWithTools();
247+
Map<String, String> params = new HashMap<>();
248+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
249+
params.put("verbose", "true");
250+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
251+
252+
// Capture the response passed to the listener
253+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
254+
verify(agentActionListener).onResponse(responseCaptor.capture());
255+
256+
// Extract the captured response
257+
Object capturedResponse = responseCaptor.getValue();
258+
assertTrue(capturedResponse instanceof ModelTensorOutput);
259+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
260+
261+
ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
262+
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
263+
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
264+
265+
// Verify that the parsed values from JSON block are correctly set
266+
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
267+
assertEquals("Thought: parsed thought", modelTensor1.getResult());
268+
assertEquals("parsed final answer", modelTensor2.getResult());
269+
}
270+
231271
@Test
232272
public void testRunWithIncludeOutputNotSet() {
233273
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,15 @@ public void testAgenttestRunMethod() {
7373
public void testAgentWithChatAgentInput() {
7474
Map<String, String> parameters = new HashMap<>();
7575
parameters.put("testKey", "testValue");
76-
Map<String, String> chatAgentInput = ImmutableMap.of("input", gson.toJson(parameters));
76+
Map<String, String> chatAgentInput = new HashMap<>();
77+
chatAgentInput.put("input", gson.toJson(parameters));
78+
doTestRunMethod(chatAgentInput);
79+
}
80+
81+
@Test
82+
public void testAgentWithChatAgentInputWrongFormat() {
83+
Map<String, String> chatAgentInput = new HashMap<>();
84+
chatAgentInput.put("input", "wrong format");
7785
doTestRunMethod(chatAgentInput);
7886
}
7987

0 commit comments

Comments
 (0)