Skip to content

Commit 388fa33

Browse files
Add GetTool API and ListTools API (opensearch-project#1818) (opensearch-project#1850)
* Add GetTool API and ListTools API Signed-off-by: Jackie Han <jkhanjob@gmail.com> * rename externalTools parameter as toolMetadataList Signed-off-by: Jackie Han <jkhanjob@gmail.com> * spotless apply Signed-off-by: Jackie Han <jkhanjob@gmail.com> * add more unit tests Signed-off-by: Jackie Han <jkhanjob@gmail.com> * tweak unit test cases Signed-off-by: Jackie Han <jkhanjob@gmail.com> * increase test coverage Signed-off-by: Jackie Han <jkhanjob@gmail.com> * increase test coverage Signed-off-by: Jackie Han <jkhanjob@gmail.com> * add more tests Signed-off-by: Jackie Han <jkhanjob@gmail.com> * Include Type and Version in GetTool and ListTools API responses Signed-off-by: Jackie Han <jkhanjob@gmail.com> * tweak ListTools result format Signed-off-by: Jackie Han <jkhanjob@gmail.com> * change term no version found to undefined Signed-off-by: Jackie Han <jkhanjob@gmail.com> --------- Signed-off-by: Jackie Han <jkhanjob@gmail.com> (cherry picked from commit deb51f6) Co-authored-by: Jackie Han <jkhanjob@gmail.com>
1 parent e35e8b7 commit 388fa33

30 files changed

+1843
-2
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+38
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.client;
77

8+
import java.util.List;
89
import java.util.Map;
910

1011
import org.opensearch.action.delete.DeleteResponse;
@@ -17,6 +18,7 @@
1718
import org.opensearch.ml.common.FunctionName;
1819
import org.opensearch.ml.common.MLModel;
1920
import org.opensearch.ml.common.MLTask;
21+
import org.opensearch.ml.common.ToolMetadata;
2022
import org.opensearch.ml.common.agent.MLAgent;
2123
import org.opensearch.ml.common.input.Input;
2224
import org.opensearch.ml.common.input.MLInput;
@@ -390,4 +392,40 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
390392

391393
void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
392394

395+
/**
396+
* Get a list of ToolMetadata and return ActionFuture.
397+
* For more info on list tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools
398+
* @return ActionFuture of a list of tool metadata
399+
*/
400+
default ActionFuture<List<ToolMetadata>> listTools() {
401+
PlainActionFuture<List<ToolMetadata>> actionFuture = PlainActionFuture.newFuture();
402+
listTools(actionFuture);
403+
return actionFuture;
404+
}
405+
406+
/**
407+
* List ToolMetadata and return a list of ToolMetadata in listener
408+
* For more info on get tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools
409+
* @param listener action listener
410+
*/
411+
void listTools(ActionListener<List<ToolMetadata>> listener);
412+
413+
/**
414+
* Get ToolMetadata and return ActionFuture.
415+
* For more info on get tool, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-tool
416+
* @return ActionFuture of tool metadata
417+
*/
418+
default ActionFuture<ToolMetadata> getTool(String toolName) {
419+
PlainActionFuture<ToolMetadata> actionFuture = PlainActionFuture.newFuture();
420+
getTool(toolName, actionFuture);
421+
return actionFuture;
422+
}
423+
424+
/**
425+
* Get ToolMetadata and return ToolMetadata in listener
426+
* For more info on get tool, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-tool
427+
* @param listener action listener
428+
*/
429+
void getTool(String toolName, ActionListener<ToolMetadata> listener);
430+
393431
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+44
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.opensearch.ml.common.input.InputHelper.getAction;
1515
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;
1616

17+
import java.util.List;
1718
import java.util.Map;
1819
import java.util.function.Function;
1920

@@ -26,6 +27,7 @@
2627
import org.opensearch.ml.common.FunctionName;
2728
import org.opensearch.ml.common.MLModel;
2829
import org.opensearch.ml.common.MLTask;
30+
import org.opensearch.ml.common.ToolMetadata;
2931
import org.opensearch.ml.common.agent.MLAgent;
3032
import org.opensearch.ml.common.input.Input;
3133
import org.opensearch.ml.common.input.MLInput;
@@ -71,6 +73,12 @@
7173
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
7274
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
7375
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
76+
import org.opensearch.ml.common.transport.tools.MLGetToolAction;
77+
import org.opensearch.ml.common.transport.tools.MLListToolsAction;
78+
import org.opensearch.ml.common.transport.tools.MLToolGetRequest;
79+
import org.opensearch.ml.common.transport.tools.MLToolGetResponse;
80+
import org.opensearch.ml.common.transport.tools.MLToolsListRequest;
81+
import org.opensearch.ml.common.transport.tools.MLToolsListResponse;
7482
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
7583
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
7684
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
@@ -287,6 +295,42 @@ public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener)
287295
}, listener::onFailure));
288296
}
289297

298+
@Override
299+
public void listTools(ActionListener<List<ToolMetadata>> listener) {
300+
MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().build();
301+
302+
client.execute(MLListToolsAction.INSTANCE, mlToolsListRequest, getMlListToolsResponseActionListener(listener));
303+
}
304+
305+
@Override
306+
public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
307+
MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName(toolName).build();
308+
309+
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
310+
}
311+
312+
private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
313+
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
314+
listener.onResponse(mlModelListResponse.getToolMetadataList());
315+
}, listener::onFailure);
316+
ActionListener<MLToolsListResponse> actionListener = wrapActionListener(internalListener, res -> {
317+
MLToolsListResponse getResponse = MLToolsListResponse.fromActionResponse(res);
318+
return getResponse;
319+
});
320+
return actionListener;
321+
}
322+
323+
private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(ActionListener<ToolMetadata> listener) {
324+
ActionListener<MLToolGetResponse> internalListener = ActionListener.wrap(mlModelGetResponse -> {
325+
listener.onResponse(mlModelGetResponse.getToolMetadata());
326+
}, listener::onFailure);
327+
ActionListener<MLToolGetResponse> actionListener = wrapActionListener(internalListener, res -> {
328+
MLToolGetResponse getResponse = MLToolGetResponse.fromActionResponse(res);
329+
return getResponse;
330+
});
331+
return actionListener;
332+
}
333+
290334
private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
291335
ActionListener<MLRegisterAgentResponse> listener
292336
) {

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+32
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.ml.common.FunctionName;
3232
import org.opensearch.ml.common.MLModel;
3333
import org.opensearch.ml.common.MLTask;
34+
import org.opensearch.ml.common.ToolMetadata;
3435
import org.opensearch.ml.common.agent.MLAgent;
3536
import org.opensearch.ml.common.dataframe.DataFrame;
3637
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
@@ -100,6 +101,8 @@ public class MachineLearningClientTest {
100101
private String modekId = "test_model_id";
101102
private MLModel mlModel;
102103
private MLTask mlTask;
104+
private ToolMetadata toolMetadata;
105+
private List<ToolMetadata> toolsList = new ArrayList<>();
103106

104107
@Before
105108
public void setUp() {
@@ -111,6 +114,15 @@ public void setUp() {
111114
String modelContent = "test content";
112115
mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build();
113116

117+
toolMetadata = ToolMetadata
118+
.builder()
119+
.name("MathTool")
120+
.description("Use this tool to calculate any math problem.")
121+
.type("MathTool")
122+
.version(null)
123+
.build();
124+
toolsList.add(toolMetadata);
125+
114126
machineLearningClient = new MachineLearningClient() {
115127
@Override
116128
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
@@ -192,6 +204,16 @@ public void deleteConnector(String connectorId, ActionListener<DeleteResponse> l
192204
listener.onResponse(deleteResponse);
193205
}
194206

207+
@Override
208+
public void listTools(ActionListener<List<ToolMetadata>> listener) {
209+
listener.onResponse(toolsList);
210+
}
211+
212+
@Override
213+
public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
214+
listener.onResponse(toolMetadata);
215+
}
216+
195217
public void registerModelGroup(
196218
MLRegisterModelGroupInput mlRegisterModelGroupInput,
197219
ActionListener<MLRegisterModelGroupResponse> listener
@@ -470,4 +492,14 @@ public void testRegisterAgent() {
470492
public void deleteAgent() {
471493
assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet());
472494
}
495+
496+
@Test
497+
public void getTool() {
498+
assertEquals(toolMetadata, machineLearningClient.getTool("MathTool").actionGet());
499+
}
500+
501+
@Test
502+
public void listTools() {
503+
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
504+
}
473505
}

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+63
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import org.opensearch.ml.common.MLTask;
6161
import org.opensearch.ml.common.MLTaskState;
6262
import org.opensearch.ml.common.MLTaskType;
63+
import org.opensearch.ml.common.ToolMetadata;
6364
import org.opensearch.ml.common.agent.MLAgent;
6465
import org.opensearch.ml.common.dataframe.DataFrame;
6566
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -116,6 +117,12 @@
116117
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
117118
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
118119
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
120+
import org.opensearch.ml.common.transport.tools.MLGetToolAction;
121+
import org.opensearch.ml.common.transport.tools.MLListToolsAction;
122+
import org.opensearch.ml.common.transport.tools.MLToolGetRequest;
123+
import org.opensearch.ml.common.transport.tools.MLToolGetResponse;
124+
import org.opensearch.ml.common.transport.tools.MLToolsListRequest;
125+
import org.opensearch.ml.common.transport.tools.MLToolsListResponse;
119126
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
120127
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
121128
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
@@ -192,6 +199,12 @@ public class MachineLearningNodeClientTest {
192199
@Mock
193200
ActionListener<DeleteResponse> deleteAgentActionListener;
194201

202+
@Mock
203+
ActionListener<List<ToolMetadata>> listToolsActionListener;
204+
205+
@Mock
206+
ActionListener<ToolMetadata> getToolActionListener;
207+
195208
@InjectMocks
196209
MachineLearningNodeClient machineLearningNodeClient;
197210

@@ -887,6 +900,56 @@ public void deleteAgent() {
887900
assertEquals(agentId, (argumentCaptor.getValue()).getId());
888901
}
889902

903+
@Test
904+
public void getTool() {
905+
ToolMetadata toolMetadata = ToolMetadata
906+
.builder()
907+
.name("MathTool")
908+
.description("Use this tool to calculate any math problem.")
909+
.build();
910+
911+
doAnswer(invocation -> {
912+
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
913+
MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build();
914+
actionListener.onResponse(output);
915+
return null;
916+
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());
917+
918+
ArgumentCaptor<ToolMetadata> argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class);
919+
machineLearningNodeClient.getTool("MathTool", getToolActionListener);
920+
921+
verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
922+
verify(getToolActionListener).onResponse(argumentCaptor.capture());
923+
assertEquals("MathTool", argumentCaptor.getValue().getName());
924+
assertEquals("Use this tool to calculate any math problem.", argumentCaptor.getValue().getDescription());
925+
}
926+
927+
@Test
928+
public void listTools() {
929+
List<ToolMetadata> toolMetadataList = new ArrayList<>();
930+
ToolMetadata wikipediaTool = ToolMetadata
931+
.builder()
932+
.name("WikipediaTool")
933+
.description("Use this tool to search general knowledge on wikipedia.")
934+
.build();
935+
toolMetadataList.add(wikipediaTool);
936+
937+
doAnswer(invocation -> {
938+
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
939+
MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(toolMetadataList).build();
940+
actionListener.onResponse(output);
941+
return null;
942+
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());
943+
944+
ArgumentCaptor<List<ToolMetadata>> argumentCaptor = ArgumentCaptor.forClass(List.class);
945+
machineLearningNodeClient.listTools(listToolsActionListener);
946+
947+
verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
948+
verify(listToolsActionListener).onResponse(argumentCaptor.capture());
949+
assertEquals("WikipediaTool", argumentCaptor.getValue().get(0).getName());
950+
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
951+
}
952+
890953
private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
891954
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
892955

0 commit comments

Comments
 (0)