8
8
import static org .junit .Assert .assertEquals ;
9
9
import static org .junit .Assert .assertThrows ;
10
10
import static org .mockito .Mockito .when ;
11
+ import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .LLM_GEN_INPUT ;
11
12
import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .PROMPT_PREFIX ;
12
13
import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .PROMPT_SUFFIX ;
13
14
import static org .opensearch .ml .engine .algorithms .agent .MLChatAgentRunner .ACTION ;
@@ -603,11 +604,24 @@ public void testConstructToolParams() {
603
604
String question = "dummy question" ;
604
605
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }" ;
605
606
verifyConstructToolParams (question , actionInput , (toolParams ) -> {
606
- Assert .assertEquals (4 , toolParams .size ());
607
+ Assert .assertEquals (5 , toolParams .size ());
607
608
Assert .assertEquals (actionInput , toolParams .get ("input" ));
608
609
Assert .assertEquals ("abc" , toolParams .get ("detectorName" ));
609
610
Assert .assertEquals ("sample-data" , toolParams .get ("indices" ));
610
611
Assert .assertEquals ("value1" , toolParams .get ("key1" ));
612
+ Assert .assertEquals (actionInput , toolParams .get (LLM_GEN_INPUT ));
613
+ });
614
+ }
615
+
616
+ @ Test
617
+ public void testConstructToolParamsNullActionInput () {
618
+ String question = "dummy question" ;
619
+ String actionInput = null ;
620
+ verifyConstructToolParams (question , actionInput , (toolParams ) -> {
621
+ Assert .assertEquals (3 , toolParams .size ());
622
+ Assert .assertEquals ("value1" , toolParams .get ("key1" ));
623
+ Assert .assertNull (toolParams .get (LLM_GEN_INPUT ));
624
+ Assert .assertNull (toolParams .get ("input" ));
611
625
});
612
626
}
613
627
@@ -617,12 +631,65 @@ public void testConstructToolParams_UseOriginalInput() {
617
631
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }" ;
618
632
when (tool1 .useOriginalInput ()).thenReturn (true );
619
633
verifyConstructToolParams (question , actionInput , (toolParams ) -> {
620
- Assert .assertEquals (2 , toolParams .size ());
634
+ Assert .assertEquals (5 , toolParams .size ());
621
635
Assert .assertEquals (question , toolParams .get ("input" ));
622
636
Assert .assertEquals ("value1" , toolParams .get ("key1" ));
637
+ Assert .assertEquals (actionInput , toolParams .get (LLM_GEN_INPUT ));
638
+ Assert .assertEquals ("sample-data" , toolParams .get ("indices" ));
639
+ Assert .assertEquals ("abc" , toolParams .get ("detectorName" ));
623
640
});
624
641
}
625
642
643
+ @ Test
644
+ public void testConstructToolParams_PlaceholderConfigInput () {
645
+ String question = "dummy question" ;
646
+ String actionInput = "action input" ;
647
+ String preConfigInputStr = "Config Input: " ;
648
+ Map <String , Tool > tools = Map .of ("tool1" , tool1 );
649
+ Map <String , MLToolSpec > toolSpecMap = Map
650
+ .of (
651
+ "tool1" ,
652
+ MLToolSpec
653
+ .builder ()
654
+ .type ("tool1" )
655
+ .parameters (Map .of ("key1" , "value1" ))
656
+ .configMap (Map .of ("input" , preConfigInputStr + "${parameters.llm_generated_input}" ))
657
+ .build ()
658
+ );
659
+ AtomicReference <String > lastActionInput = new AtomicReference <>();
660
+ String action = "tool1" ;
661
+ Map <String , String > toolParams = AgentUtils .constructToolParams (tools , toolSpecMap , question , lastActionInput , action , actionInput );
662
+ Assert .assertEquals (3 , toolParams .size ());
663
+ Assert .assertEquals (preConfigInputStr + actionInput , toolParams .get ("input" ));
664
+ Assert .assertEquals ("value1" , toolParams .get ("key1" ));
665
+ Assert .assertEquals (actionInput , toolParams .get (LLM_GEN_INPUT ));
666
+ }
667
+
668
+ @ Test
669
+ public void testConstructToolParams_PlaceholderConfigInputJson () {
670
+ String question = "dummy question" ;
671
+ String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }" ;
672
+ String preConfigInputStr = "Config Input: " ;
673
+ Map <String , Tool > tools = Map .of ("tool1" , tool1 );
674
+ Map <String , MLToolSpec > toolSpecMap = Map
675
+ .of (
676
+ "tool1" ,
677
+ MLToolSpec
678
+ .builder ()
679
+ .type ("tool1" )
680
+ .parameters (Map .of ("key1" , "value1" ))
681
+ .configMap (Map .of ("input" , preConfigInputStr + "${parameters.detectorName}" ))
682
+ .build ()
683
+ );
684
+ AtomicReference <String > lastActionInput = new AtomicReference <>();
685
+ String action = "tool1" ;
686
+ Map <String , String > toolParams = AgentUtils .constructToolParams (tools , toolSpecMap , question , lastActionInput , action , actionInput );
687
+ Assert .assertEquals (5 , toolParams .size ());
688
+ Assert .assertEquals (preConfigInputStr + "abc" , toolParams .get ("input" ));
689
+ Assert .assertEquals ("value1" , toolParams .get ("key1" ));
690
+ Assert .assertEquals (actionInput , toolParams .get (LLM_GEN_INPUT ));
691
+ }
692
+
626
693
private void verifyConstructToolParams (String question , String actionInput , Consumer <Map <String , String >> verify ) {
627
694
Map <String , Tool > tools = Map .of ("tool1" , tool1 );
628
695
Map <String , MLToolSpec > toolSpecMap = Map
0 commit comments