Skip to content

Commit a18061c

Browse files
add action input as parameters for tool execution in conversational agent (opensearch-project#3200) (opensearch-project#3314)
* add llm generated action input as parameters for tool execution in conversational agent Signed-off-by: Jing Zhang <jngz@amazon.com> * add UT for null action input Signed-off-by: Jing Zhang <jngz@amazon.com> * change llm_generated_action_input to llm_generated_input Signed-off-by: Jing Zhang <jngz@amazon.com> --------- Signed-off-by: Jing Zhang <jngz@amazon.com> (cherry picked from commit c850eef) Co-authored-by: Jing Zhang <jngz@amazon.com>
1 parent 58a0c3e commit a18061c

File tree

3 files changed

+79
-10
lines changed

3 files changed

+79
-10
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public class AgentUtils {
6060
public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
6161
public static final String DISABLE_TRACE = "disable_trace";
6262
public static final String VERBOSE = "verbose";
63+
public static final String LLM_GEN_INPUT = "llm_generated_input";
6364

6465
public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
6566
Map<String, String> examplesMap = new HashMap<>();
@@ -472,6 +473,11 @@ public static Map<String, String> constructToolParams(
472473
if (toolSpecConfigMap != null) {
473474
toolParams.putAll(toolSpecConfigMap);
474475
}
476+
toolParams.put(LLM_GEN_INPUT, actionInput);
477+
if (isJson(actionInput)) {
478+
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
479+
toolParams.putAll(params);
480+
}
475481
if (tools.get(action).useOriginalInput()) {
476482
toolParams.put("input", question);
477483
lastActionInput.set(question);
@@ -486,10 +492,6 @@ public static Map<String, String> constructToolParams(
486492
}
487493
} else {
488494
toolParams.put("input", actionInput);
489-
if (isJson(actionInput)) {
490-
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
491-
toolParams.putAll(params);
492-
}
493495
}
494496
return toolParams;
495497
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

+69-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertThrows;
1010
import static org.mockito.Mockito.when;
11+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
1112
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
1213
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
1314
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
@@ -603,11 +604,24 @@ public void testConstructToolParams() {
603604
String question = "dummy question";
604605
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
605606
verifyConstructToolParams(question, actionInput, (toolParams) -> {
606-
Assert.assertEquals(4, toolParams.size());
607+
Assert.assertEquals(5, toolParams.size());
607608
Assert.assertEquals(actionInput, toolParams.get("input"));
608609
Assert.assertEquals("abc", toolParams.get("detectorName"));
609610
Assert.assertEquals("sample-data", toolParams.get("indices"));
610611
Assert.assertEquals("value1", toolParams.get("key1"));
612+
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
613+
});
614+
}
615+
616+
@Test
617+
public void testConstructToolParamsNullActionInput() {
618+
String question = "dummy question";
619+
String actionInput = null;
620+
verifyConstructToolParams(question, actionInput, (toolParams) -> {
621+
Assert.assertEquals(3, toolParams.size());
622+
Assert.assertEquals("value1", toolParams.get("key1"));
623+
Assert.assertNull(toolParams.get(LLM_GEN_INPUT));
624+
Assert.assertNull(toolParams.get("input"));
611625
});
612626
}
613627

@@ -617,12 +631,65 @@ public void testConstructToolParams_UseOriginalInput() {
617631
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
618632
when(tool1.useOriginalInput()).thenReturn(true);
619633
verifyConstructToolParams(question, actionInput, (toolParams) -> {
620-
Assert.assertEquals(2, toolParams.size());
634+
Assert.assertEquals(5, toolParams.size());
621635
Assert.assertEquals(question, toolParams.get("input"));
622636
Assert.assertEquals("value1", toolParams.get("key1"));
637+
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
638+
Assert.assertEquals("sample-data", toolParams.get("indices"));
639+
Assert.assertEquals("abc", toolParams.get("detectorName"));
623640
});
624641
}
625642

643+
@Test
644+
public void testConstructToolParams_PlaceholderConfigInput() {
645+
String question = "dummy question";
646+
String actionInput = "action input";
647+
String preConfigInputStr = "Config Input: ";
648+
Map<String, Tool> tools = Map.of("tool1", tool1);
649+
Map<String, MLToolSpec> toolSpecMap = Map
650+
.of(
651+
"tool1",
652+
MLToolSpec
653+
.builder()
654+
.type("tool1")
655+
.parameters(Map.of("key1", "value1"))
656+
.configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_input}"))
657+
.build()
658+
);
659+
AtomicReference<String> lastActionInput = new AtomicReference<>();
660+
String action = "tool1";
661+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
662+
Assert.assertEquals(3, toolParams.size());
663+
Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input"));
664+
Assert.assertEquals("value1", toolParams.get("key1"));
665+
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
666+
}
667+
668+
@Test
669+
public void testConstructToolParams_PlaceholderConfigInputJson() {
670+
String question = "dummy question";
671+
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
672+
String preConfigInputStr = "Config Input: ";
673+
Map<String, Tool> tools = Map.of("tool1", tool1);
674+
Map<String, MLToolSpec> toolSpecMap = Map
675+
.of(
676+
"tool1",
677+
MLToolSpec
678+
.builder()
679+
.type("tool1")
680+
.parameters(Map.of("key1", "value1"))
681+
.configMap(Map.of("input", preConfigInputStr + "${parameters.detectorName}"))
682+
.build()
683+
);
684+
AtomicReference<String> lastActionInput = new AtomicReference<>();
685+
String action = "tool1";
686+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
687+
Assert.assertEquals(5, toolParams.size());
688+
Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input"));
689+
Assert.assertEquals("value1", toolParams.get("key1"));
690+
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
691+
}
692+
626693
private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
627694
Map<String, Tool> tools = Map.of("tool1", tool1);
628695
Map<String, MLToolSpec> toolSpecMap = Map

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ public void testToolParameters() {
706706
// Verify the size of parameters passed in the tool run method.
707707
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
708708
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
709-
assertEquals(14, ((Map) argumentCaptor.getValue()).size());
709+
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
710710

711711
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
712712
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
@@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
734734
// Verify the size of parameters passed in the tool run method.
735735
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
736736
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
737-
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
737+
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
738738
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
739739

740740
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -763,7 +763,7 @@ public void testToolConfig() {
763763
// Verify the size of parameters passed in the tool run method.
764764
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
765765
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
766-
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
766+
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
767767
// The value of input should be "config_value".
768768
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
769769

@@ -793,7 +793,7 @@ public void testToolConfigWithInputPlaceholder() {
793793
// Verify the size of parameters passed in the tool run method.
794794
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
795795
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
796-
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
796+
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
797797
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
798798
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
799799

0 commit comments

Comments
 (0)