Skip to content

Commit 3b29152

Browse files
committed
add UTs
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent d289673 commit 3b29152

File tree

2 files changed

+94
-6
lines changed

2 files changed

+94
-6
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+81
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,46 @@ public void testParsingJsonBlockFromResponse() {
188188
assertEquals("parsed final answer", modelTensor2.getResult());
189189
}
190190

191+
@Test
192+
public void testParsingJsonBlockFromResponse2() {
193+
// Prepare the response with JSON block
194+
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
195+
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
196+
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
197+
198+
// Mock LLM response to not contain "thought" but contain "response" with JSON block
199+
Map<String, String> llmResponse = new HashMap<>();
200+
llmResponse.put("response", responseWithJsonBlock);
201+
doAnswer(getLLMAnswer(llmResponse))
202+
.when(client)
203+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
204+
205+
// Create an MLAgent and run the MLChatAgentRunner
206+
MLAgent mlAgent = createMLAgentWithTools();
207+
Map<String, String> params = new HashMap<>();
208+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
209+
params.put("verbose", "true");
210+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
211+
212+
// Capture the response passed to the listener
213+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
214+
verify(agentActionListener).onResponse(responseCaptor.capture());
215+
216+
// Extract the captured response
217+
Object capturedResponse = responseCaptor.getValue();
218+
assertTrue(capturedResponse instanceof ModelTensorOutput);
219+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
220+
221+
ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
222+
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
223+
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
224+
225+
// Verify that the parsed values from JSON block are correctly set
226+
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
227+
assertEquals("Thought: parsed thought", modelTensor1.getResult());
228+
assertEquals("parsed final answer", modelTensor2.getResult());
229+
}
230+
191231
@Test
192232
public void testRunWithIncludeOutputNotSet() {
193233
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -209,6 +249,29 @@ public void testRunWithIncludeOutputNotSet() {
209249
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
210250
}
211251

252+
@Test
253+
public void testRunWithIncludeOutputMLModel() {
254+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
255+
Mockito.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)).when(firstTool).run(Mockito.anyMap(), toolListenerCaptor.capture());
256+
Mockito.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)).when(secondTool).run(Mockito.anyMap(), toolListenerCaptor.capture());
257+
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
258+
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
259+
final MLAgent mlAgent = MLAgent
260+
.builder()
261+
.name("TestAgent")
262+
.llm(llmSpec)
263+
.memory(mlMemorySpec)
264+
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
265+
.build();
266+
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
267+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
268+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
269+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
270+
assertEquals(1, agentOutput.size());
271+
// Respond with last tool output
272+
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
273+
}
274+
212275
@Test
213276
public void testRunWithIncludeOutputSet() {
214277
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -512,6 +575,24 @@ private Answer generateToolResponse(String response) {
512575
};
513576
}
514577

578+
private Answer generateToolResponseAsMLModelResult(String response, int type) {
579+
ModelTensor modelTensor;
580+
if (type == 1) {
581+
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build();
582+
}
583+
else {
584+
modelTensor = ModelTensor.builder().result(response).build();
585+
}
586+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
587+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
588+
589+
return invocation -> {
590+
ActionListener<Object> listener = invocation.getArgument(1);
591+
listener.onResponse(mlModelTensorOutput);
592+
return null;
593+
};
594+
}
595+
515596
private Answer generateToolFailure(Exception e) {
516597
return invocation -> {
517598
ActionListener<Object> listener = invocation.getArgument(1);

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java

+13-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.mockito.ArgumentMatchers.eq;
1111
import static org.mockito.Mockito.doAnswer;
1212
import static org.mockito.Mockito.verify;
13+
import static org.opensearch.ml.common.utils.StringUtils.gson;
1314
import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION;
1415

1516
import java.util.Arrays;
@@ -67,13 +68,19 @@ public void setup() {
6768
public void testAgenttestRunMethod() {
6869
Map<String, String> parameters = new HashMap<>();
6970
parameters.put("testKey", "testValue");
70-
AgentMLInput agentMLInput = AgentMLInput
71-
.AgentMLInputBuilder()
72-
.agentId("agentId")
73-
.functionName(FunctionName.AGENT)
74-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
75-
.build();
71+
doTestRunMethod(parameters);
72+
}
73+
74+
@Test
75+
public void testAgentWithChatAgentInput() {
76+
Map<String, String> parameters = new HashMap<>();
77+
parameters.put("testKey", "testValue");
78+
Map<String, String> chatAgentInput = ImmutableMap.of("input", gson.toJson(parameters));
79+
doTestRunMethod(chatAgentInput);
80+
}
7681

82+
private void doTestRunMethod(Map<String, String> parameters)
83+
{
7784
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
7885
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
7986
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

0 commit comments

Comments
 (0)