Skip to content

Commit 99e75aa

Browse files
authored
add IT flow agent with search index tool (opensearch-project#2448)
Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent 7add721 commit 99e75aa

File tree

1 file changed

+71
-5
lines changed

1 file changed

+71
-5
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLFlowAgentIT.java

+71-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,26 @@
99
import java.util.List;
1010
import java.util.Map;
1111

12+
import org.apache.hc.core5.http.ParseException;
13+
import org.junit.After;
14+
import org.junit.Before;
1215
import org.opensearch.client.Response;
1316
import org.opensearch.ml.utils.TestHelper;
1417

1518
public class RestMLFlowAgentIT extends MLCommonsRestTestCase {
1619

20+
private String irisIndex = "iris_data";
21+
22+
@Before
23+
public void setup() throws IOException, ParseException {
24+
ingestIrisData(irisIndex);
25+
}
26+
27+
@After
28+
public void deleteIndices() throws IOException {
29+
deleteIndexWithAdminClient(irisIndex);
30+
}
31+
1732
public void testAgentCatIndexTool() throws IOException {
1833
// Register agent with CatIndexTool.
1934
Response response = registerAgentWithCatIndexTool();
@@ -35,6 +50,27 @@ public void testAgentCatIndexTool() throws IOException {
3550
assertTrue(result.contains(".plugins-ml-agent"));
3651
}
3752

53+
public void testAgentSearchIndexTool() throws IOException {
54+
// Register agent with SearchIndexTool.
55+
Response response = registerAgentWithSearchIndexTool();
56+
Map responseMap = parseResponseToMap(response);
57+
String agentId = (String) responseMap.get("agent_id");
58+
assertNotNull(agentId);
59+
assertEquals(20, agentId.length());
60+
61+
// Execute agent.
62+
response = executeAgentSearchIndexTool(agentId);
63+
responseMap = parseResponseToMap(response);
64+
List responseList = (List) responseMap.get("inference_results");
65+
responseMap = (Map) responseList.get(0);
66+
responseList = (List) responseMap.get("output");
67+
responseMap = (Map) responseList.get(0);
68+
assertEquals("response", responseMap.get("name"));
69+
String result = (String) responseMap.get("result");
70+
assertNotNull(result);
71+
assertTrue(result.contains("\"_source\":{\"petal_length_in_cm\""));
72+
}
73+
3874
public static Response registerAgentWithCatIndexTool() throws IOException {
3975
String registerAgentEntity = "{\n"
4076
+ " \"name\": \"Test_Agent_For_CatIndex_tool\",\n"
@@ -54,20 +90,50 @@ public static Response registerAgentWithCatIndexTool() throws IOException {
5490
.makeRequest(client(), "POST", "/_plugins/_ml/agents/_register", null, TestHelper.toHttpEntity(registerAgentEntity), null);
5591
}
5692

93+
public static Response registerAgentWithSearchIndexTool() throws IOException {
94+
String registerAgentEntity = "{\n"
95+
+ " \"name\": \"Test_Agent_For_SearchIndex_tool\",\n"
96+
+ " \"type\": \"flow\",\n"
97+
+ " \"description\": \"this is a test agent for the SearchIndexTool\",\n"
98+
+ " \"tools\": [\n"
99+
+ " {\n"
100+
+ " \"type\": \"SearchIndexTool\""
101+
+ " }\n"
102+
+ " ]\n"
103+
+ "}";
104+
return TestHelper
105+
.makeRequest(client(), "POST", "/_plugins/_ml/agents/_register", null, TestHelper.toHttpEntity(registerAgentEntity), null);
106+
}
107+
57108
public static Response executeAgentCatIndexTool(String agentId) throws IOException {
58-
String question = "How many indices do I have?";
59-
return executeAgent(agentId, question);
109+
String question = "\"How many indices do I have?\"";
110+
return executeAgent(agentId, Map.of("question", question));
60111
}
61112

62-
public static Response executeAgent(String agentId, String question) throws IOException {
63-
String executeAgentEntity = "{\n" + " \"parameters\": {\n" + " \"question\": \"" + question + " \"\n" + " }\n" + "}";
113+
public static Response executeAgentSearchIndexTool(String agentId) throws IOException {
114+
String input = "{\"index\": \"iris_data\", \"query\": {\"size\": 2, \"_source\": \"petal_length_in_cm\"}}";
115+
return executeAgent(agentId, Map.of("input", input));
116+
}
117+
118+
public static Response executeAgent(String agentId, Map<String, String> args) throws IOException {
119+
if (args == null || args.isEmpty()) {
120+
return null;
121+
}
122+
123+
// Construct parameters.
124+
StringBuilder entityBuilder = new StringBuilder("{\"parameters\":{");
125+
for (Map.Entry entry : args.entrySet()) {
126+
entityBuilder.append('"').append(entry.getKey()).append("\":").append(entry.getValue()).append(',');
127+
}
128+
entityBuilder.replace(entityBuilder.length() - 1, entityBuilder.length(), "}}");
129+
64130
return TestHelper
65131
.makeRequest(
66132
client(),
67133
"POST",
68134
String.format("/_plugins/_ml/agents/%s/_execute", agentId),
69135
null,
70-
TestHelper.toHttpEntity(executeAgentEntity),
136+
TestHelper.toHttpEntity(entityBuilder.toString()),
71137
null
72138
);
73139
}

0 commit comments

Comments
 (0)