9
9
import java .util .List ;
10
10
import java .util .Map ;
11
11
12
+ import org .apache .hc .core5 .http .ParseException ;
13
+ import org .junit .After ;
14
+ import org .junit .Before ;
12
15
import org .opensearch .client .Response ;
13
16
import org .opensearch .ml .utils .TestHelper ;
14
17
15
18
public class RestMLFlowAgentIT extends MLCommonsRestTestCase {
16
19
20
+ private String irisIndex = "iris_data" ;
21
+
22
+ @ Before
23
+ public void setup () throws IOException , ParseException {
24
+ ingestIrisData (irisIndex );
25
+ }
26
+
27
+ @ After
28
+ public void deleteIndices () throws IOException {
29
+ deleteIndexWithAdminClient (irisIndex );
30
+ }
31
+
17
32
public void testAgentCatIndexTool () throws IOException {
18
33
// Register agent with CatIndexTool.
19
34
Response response = registerAgentWithCatIndexTool ();
@@ -35,6 +50,27 @@ public void testAgentCatIndexTool() throws IOException {
35
50
assertTrue (result .contains (".plugins-ml-agent" ));
36
51
}
37
52
53
+ public void testAgentSearchIndexTool () throws IOException {
54
+ // Register agent with SearchIndexTool.
55
+ Response response = registerAgentWithSearchIndexTool ();
56
+ Map responseMap = parseResponseToMap (response );
57
+ String agentId = (String ) responseMap .get ("agent_id" );
58
+ assertNotNull (agentId );
59
+ assertEquals (20 , agentId .length ());
60
+
61
+ // Execute agent.
62
+ response = executeAgentSearchIndexTool (agentId );
63
+ responseMap = parseResponseToMap (response );
64
+ List responseList = (List ) responseMap .get ("inference_results" );
65
+ responseMap = (Map ) responseList .get (0 );
66
+ responseList = (List ) responseMap .get ("output" );
67
+ responseMap = (Map ) responseList .get (0 );
68
+ assertEquals ("response" , responseMap .get ("name" ));
69
+ String result = (String ) responseMap .get ("result" );
70
+ assertNotNull (result );
71
+ assertTrue (result .contains ("\" _source\" :{\" petal_length_in_cm\" " ));
72
+ }
73
+
38
74
public static Response registerAgentWithCatIndexTool () throws IOException {
39
75
String registerAgentEntity = "{\n "
40
76
+ " \" name\" : \" Test_Agent_For_CatIndex_tool\" ,\n "
@@ -54,20 +90,50 @@ public static Response registerAgentWithCatIndexTool() throws IOException {
54
90
.makeRequest (client (), "POST" , "/_plugins/_ml/agents/_register" , null , TestHelper .toHttpEntity (registerAgentEntity ), null );
55
91
}
56
92
93
+ public static Response registerAgentWithSearchIndexTool () throws IOException {
94
+ String registerAgentEntity = "{\n "
95
+ + " \" name\" : \" Test_Agent_For_SearchIndex_tool\" ,\n "
96
+ + " \" type\" : \" flow\" ,\n "
97
+ + " \" description\" : \" this is a test agent for the SearchIndexTool\" ,\n "
98
+ + " \" tools\" : [\n "
99
+ + " {\n "
100
+ + " \" type\" : \" SearchIndexTool\" "
101
+ + " }\n "
102
+ + " ]\n "
103
+ + "}" ;
104
+ return TestHelper
105
+ .makeRequest (client (), "POST" , "/_plugins/_ml/agents/_register" , null , TestHelper .toHttpEntity (registerAgentEntity ), null );
106
+ }
107
+
57
108
public static Response executeAgentCatIndexTool (String agentId ) throws IOException {
58
- String question = "How many indices do I have?" ;
59
- return executeAgent (agentId , question );
109
+ String question = "\" How many indices do I have?\" " ;
110
+ return executeAgent (agentId , Map . of ( " question" , question ) );
60
111
}
61
112
62
- public static Response executeAgent (String agentId , String question ) throws IOException {
63
- String executeAgentEntity = "{\n " + " \" parameters\" : {\n " + " \" question\" : \" " + question + " \" \n " + " }\n " + "}" ;
113
+ public static Response executeAgentSearchIndexTool (String agentId ) throws IOException {
114
+ String input = "{\" index\" : \" iris_data\" , \" query\" : {\" size\" : 2, \" _source\" : \" petal_length_in_cm\" }}" ;
115
+ return executeAgent (agentId , Map .of ("input" , input ));
116
+ }
117
+
118
+ public static Response executeAgent (String agentId , Map <String , String > args ) throws IOException {
119
+ if (args == null || args .isEmpty ()) {
120
+ return null ;
121
+ }
122
+
123
+ // Construct parameters.
124
+ StringBuilder entityBuilder = new StringBuilder ("{\" parameters\" :{" );
125
+ for (Map .Entry entry : args .entrySet ()) {
126
+ entityBuilder .append ('"' ).append (entry .getKey ()).append ("\" :" ).append (entry .getValue ()).append (',' );
127
+ }
128
+ entityBuilder .replace (entityBuilder .length () - 1 , entityBuilder .length (), "}}" );
129
+
64
130
return TestHelper
65
131
.makeRequest (
66
132
client (),
67
133
"POST" ,
68
134
String .format ("/_plugins/_ml/agents/%s/_execute" , agentId ),
69
135
null ,
70
- TestHelper .toHttpEntity (executeAgentEntity ),
136
+ TestHelper .toHttpEntity (entityBuilder . toString () ),
71
137
null
72
138
);
73
139
}
0 commit comments