Skip to content

Commit a5c500c

Browse files
authored
bug fix - tool parameters missing (opensearch-project#1911)
* bug fix - tool parameters missing Signed-off-by: Jing Zhang <jngz@amazon.com> * address comments Signed-off-by: Jing Zhang <jngz@amazon.com> --------- Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent 991193c commit a5c500c

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ private void runReAct(
465465
action = toolName;
466466

467467
if (tools.containsKey(action) && inputTools.contains(action)) {
468-
Map<String, String> toolParams = new HashMap<>();
468+
Map<String, String> toolParams = new HashMap<>(toolSpecMap.get(action).getParameters());
469469
toolParams.put("input", actionInput);
470470
if (tools.get(action).validate(toolParams)) {
471471
try {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+58-5
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,18 @@ public void testParsingJsonBlockFromResponse() {
191191
@Test
192192
public void testRunWithIncludeOutputNotSet() {
193193
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
194-
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
195-
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
194+
MLToolSpec firstToolSpec = MLToolSpec
195+
.builder()
196+
.name(FIRST_TOOL)
197+
.type(FIRST_TOOL)
198+
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
199+
.build();
200+
MLToolSpec secondToolSpec = MLToolSpec
201+
.builder()
202+
.name(SECOND_TOOL)
203+
.type(SECOND_TOOL)
204+
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
205+
.build();
196206
final MLAgent mlAgent = MLAgent
197207
.builder()
198208
.name("TestAgent")
@@ -212,8 +222,20 @@ public void testRunWithIncludeOutputNotSet() {
212222
@Test
213223
public void testRunWithIncludeOutputSet() {
214224
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
215-
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build();
216-
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build();
225+
MLToolSpec firstToolSpec = MLToolSpec
226+
.builder()
227+
.name(FIRST_TOOL)
228+
.type(FIRST_TOOL)
229+
.includeOutputInAgentResponse(false)
230+
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
231+
.build();
232+
MLToolSpec secondToolSpec = MLToolSpec
233+
.builder()
234+
.name(SECOND_TOOL)
235+
.type(SECOND_TOOL)
236+
.includeOutputInAgentResponse(true)
237+
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
238+
.build();
217239
final MLAgent mlAgent = MLAgent
218240
.builder()
219241
.name("TestAgent")
@@ -471,10 +493,41 @@ public void testToolThrowException() {
471493
assertNotNull(modelTensorOutput);
472494
}
473495

496+
@Test
497+
public void testToolParameters() {
498+
// Mock tool validation to return false.
499+
when(firstTool.validate(any())).thenReturn(true);
500+
501+
// Create an MLAgent with a tool including two parameters.
502+
MLAgent mlAgent = createMLAgentWithTools();
503+
504+
// Create parameters for the agent.
505+
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");
506+
507+
// Run the MLChatAgentRunner.
508+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
509+
510+
// Verify that the tool's run method was called.
511+
verify(firstTool).run(any(), any());
512+
// Verify the size of parameters passed in the tool run method.
513+
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
514+
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
515+
assertEquals(3, ((Map) argumentCaptor.getValue()).size());
516+
517+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
518+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
519+
assertNotNull(modelTensorOutput);
520+
}
521+
474522
// Helper methods to create MLAgent and parameters
475523
private MLAgent createMLAgentWithTools() {
476524
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
477-
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
525+
MLToolSpec firstToolSpec = MLToolSpec
526+
.builder()
527+
.name(FIRST_TOOL)
528+
.type(FIRST_TOOL)
529+
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
530+
.build();
478531
return MLAgent.builder().name("TestAgent").tools(Arrays.asList(firstToolSpec)).memory(mlMemorySpec).llm(llmSpec).build();
479532
}
480533

0 commit comments

Comments
 (0)