Skip to content

Commit 3b142c3

Browse files
committed
change interface to with model
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 273af49 commit 3b142c3

File tree

12 files changed

+56
-82
lines changed

12 files changed

+56
-82
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1818
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
1919
import org.opensearch.ml.common.output.model.ModelTensorOutput;
20+
import org.opensearch.ml.common.spi.tools.Tool;
2021
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
21-
import org.opensearch.ml.common.spi.tools.WithoutModelTool;
2222
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
2323
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
2424
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
@@ -32,7 +32,7 @@
3232
*/
3333
@Log4j2
3434
@ToolAnnotation(AgentTool.TYPE)
35-
public class AgentTool implements WithoutModelTool {
35+
public class AgentTool implements Tool {
3636
public static final String TYPE = "AgentTool";
3737
private final Client client;
3838

@@ -97,7 +97,7 @@ public boolean validate(Map<String, String> parameters) {
9797
return true;
9898
}
9999

100-
public static class Factory implements WithoutModelTool.Factory<AgentTool> {
100+
public static class Factory implements Tool.Factory<AgentTool> {
101101
private Client client;
102102

103103
private static Factory INSTANCE;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@
4545
import org.opensearch.index.IndexSettings;
4646
import org.opensearch.ml.common.output.model.ModelTensors;
4747
import org.opensearch.ml.common.spi.tools.Parser;
48+
import org.opensearch.ml.common.spi.tools.Tool;
4849
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
49-
import org.opensearch.ml.common.spi.tools.WithoutModelTool;
5050

5151
import lombok.Getter;
5252
import lombok.Setter;
5353

5454
@ToolAnnotation(CatIndexTool.TYPE)
55-
public class CatIndexTool implements WithoutModelTool {
55+
public class CatIndexTool implements Tool {
5656
public static final String TYPE = "CatIndexTool";
5757
private static final String DEFAULT_DESCRIPTION = String
5858
.join(
@@ -309,7 +309,7 @@ public boolean validate(Map<String, String> parameters) {
309309
/**
310310
* Factory for the {@link CatIndexTool}
311311
*/
312-
public static class Factory implements WithoutModelTool.Factory<CatIndexTool> {
312+
public static class Factory implements Tool.Factory<CatIndexTool> {
313313
private Client client;
314314
private ClusterService clusterService;
315315

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import org.opensearch.ml.common.output.model.ModelTensorOutput;
1919
import org.opensearch.ml.common.output.model.ModelTensors;
2020
import org.opensearch.ml.common.spi.tools.Parser;
21+
import org.opensearch.ml.common.spi.tools.Tool;
2122
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
22-
import org.opensearch.ml.common.spi.tools.WithoutModelTool;
2323
import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction;
2424
import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest;
2525

@@ -32,7 +32,7 @@
3232
*/
3333
@Log4j2
3434
@ToolAnnotation(ConnectorTool.TYPE)
35-
public class ConnectorTool implements WithoutModelTool {
35+
public class ConnectorTool implements Tool {
3636
public static final String TYPE = "ConnectorTool";
3737
public static final String CONNECTOR_ID = "connector_id";
3838
public static final String CONNECTOR_ACTION = "connector_action";
@@ -102,7 +102,7 @@ public boolean validate(Map<String, String> parameters) {
102102
return true;
103103
}
104104

105-
public static class Factory implements WithoutModelTool.Factory<ConnectorTool> {
105+
public static class Factory implements Tool.Factory<ConnectorTool> {
106106
public static final String TYPE = "ConnectorTool";
107107
public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service.";
108108
private Client client;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@
2424
import org.opensearch.core.action.ActionListener;
2525
import org.opensearch.ml.common.output.model.ModelTensors;
2626
import org.opensearch.ml.common.spi.tools.Parser;
27+
import org.opensearch.ml.common.spi.tools.Tool;
2728
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
28-
import org.opensearch.ml.common.spi.tools.WithoutModelTool;
2929

3030
import lombok.Getter;
3131
import lombok.Setter;
3232

3333
@ToolAnnotation(IndexMappingTool.TYPE)
34-
public class IndexMappingTool implements WithoutModelTool {
34+
public class IndexMappingTool implements Tool {
3535
public static final String TYPE = "IndexMappingTool";
3636
private static final String DEFAULT_DESCRIPTION = String
3737
.join(
@@ -158,7 +158,7 @@ public boolean validate(Map<String, String> parameters) {
158158
/**
159159
* Factory for the {@link IndexMappingTool}
160160
*/
161-
public static class Factory implements WithoutModelTool.Factory<IndexMappingTool> {
161+
public static class Factory implements Tool.Factory<IndexMappingTool> {
162162
private Client client;
163163

164164
private static Factory INSTANCE;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import org.opensearch.ml.common.output.model.ModelTensorOutput;
1818
import org.opensearch.ml.common.output.model.ModelTensors;
1919
import org.opensearch.ml.common.spi.tools.Parser;
20-
import org.opensearch.ml.common.spi.tools.Tool;
2120
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
21+
import org.opensearch.ml.common.spi.tools.WithModelTool;
2222
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
2323
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
2424
import org.opensearch.ml.common.utils.StringUtils;
@@ -33,7 +33,7 @@
3333
*/
3434
@Log4j2
3535
@ToolAnnotation(MLModelTool.TYPE)
36-
public class MLModelTool implements Tool {
36+
public class MLModelTool implements WithModelTool {
3737
public static final String TYPE = "MLModelTool";
3838
public static final String RESPONSE_FIELD = "response_field";
3939
public static final String MODEL_ID_FIELD = "model_id";
@@ -127,7 +127,7 @@ public boolean validate(Map<String, String> parameters) {
127127
return true;
128128
}
129129

130-
public static class Factory implements Tool.Factory<MLModelTool> {
130+
public static class Factory implements WithModelTool.Factory<MLModelTool> {
131131
private Client client;
132132

133133
private static Factory INSTANCE;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import org.opensearch.core.action.ActionListener;
2424
import org.opensearch.core.xcontent.NamedXContentRegistry;
2525
import org.opensearch.core.xcontent.XContentParser;
26+
import org.opensearch.ml.common.spi.tools.Tool;
2627
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
27-
import org.opensearch.ml.common.spi.tools.WithoutModelTool;
2828
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
2929
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
3030
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
@@ -43,7 +43,7 @@
4343
@Setter
4444
@Log4j2
4545
@ToolAnnotation(SearchIndexTool.TYPE)
46-
public class SearchIndexTool implements WithoutModelTool {
46+
public class SearchIndexTool implements Tool {
4747

4848
public static final String INPUT_FIELD = "input";
4949
public static final String INDEX_FIELD = "index";
@@ -148,7 +148,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
148148
}
149149
}
150150

151-
public static class Factory implements WithoutModelTool.Factory<SearchIndexTool> {
151+
public static class Factory implements Tool.Factory<SearchIndexTool> {
152152

153153
private Client client;
154154
private static Factory INSTANCE;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import org.opensearch.index.IndexNotFoundException;
2121
import org.opensearch.index.query.BoolQueryBuilder;
2222
import org.opensearch.index.query.QueryBuilders;
23+
import org.opensearch.ml.common.spi.tools.Tool;
2324
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
24-
import org.opensearch.ml.common.spi.tools.WithoutModelTool;
2525
import org.opensearch.search.SearchHits;
2626
import org.opensearch.search.builder.SearchSourceBuilder;
2727

@@ -32,7 +32,7 @@
3232

3333
@Log4j2
3434
@ToolAnnotation(VisualizationsTool.TYPE)
35-
public class VisualizationsTool implements WithoutModelTool {
35+
public class VisualizationsTool implements Tool {
3636
public static final String NAME = "FindVisualizations";
3737
public static final String TYPE = "VisualizationTool";
3838
public static final String VERSION = "v1.0";
@@ -125,7 +125,7 @@ public boolean validate(Map<String, String> parameters) {
125125
return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input"));
126126
}
127127

128-
public static class Factory implements WithoutModelTool.Factory<VisualizationsTool> {
128+
public static class Factory implements Tool.Factory<VisualizationsTool> {
129129
private Client client;
130130

131131
private static Factory INSTANCE;

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/AgentModelsSearcher.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.opensearch.index.query.BoolQueryBuilder;
1212
import org.opensearch.index.query.QueryBuilders;
1313
import org.opensearch.ml.common.spi.tools.Tool;
14+
import org.opensearch.ml.common.spi.tools.WithModelTool;
1415
import org.opensearch.search.builder.SearchSourceBuilder;
1516

1617
public class AgentModelsSearcher {
@@ -20,7 +21,9 @@ public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) {
2021
relatedModelIdSet = new HashSet<>();
2122
for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) {
2223
Tool.Factory toolFactory = entry.getValue();
23-
relatedModelIdSet.addAll(toolFactory.getAllModelKeys());
24+
if (toolFactory instanceof WithModelTool.Factory withModelTool) {
25+
relatedModelIdSet.addAll(withModelTool.getAllModelKeys());
26+
}
2427
}
2528
}
2629

plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java

+28-44
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,34 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
242242
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
243243
}
244244

245+
private void checkDownstreamTaskBeforeDeleteModel(String modelId, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
246+
// Now checks 3 resources associated with the model id 1. Agent 2. Search pipeline 3. ingest pipeline
247+
CountDownLatch countDownLatch = new CountDownLatch(3);
248+
AtomicBoolean noneBlocked = new AtomicBoolean(true);
249+
ConcurrentLinkedQueue<String> errorMessages = new ConcurrentLinkedQueue<>();
250+
ActionListener<Boolean> countDownActionListener = ActionListener.wrap(b -> {
251+
countDownLatch.countDown();
252+
noneBlocked.compareAndSet(true, b);
253+
if (countDownLatch.getCount() == 0) {
254+
if (noneBlocked.get()) {
255+
deleteModel(modelId, isHidden, actionListener);
256+
} else {
257+
actionListener.onFailure(new OpenSearchStatusException(String.join(". ", errorMessages), RestStatus.CONFLICT));
258+
}
259+
}
260+
}, e -> {
261+
countDownLatch.countDown();
262+
noneBlocked.set(false);
263+
errorMessages.add(e.getMessage());
264+
actionListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.CONFLICT));
265+
266+
});
267+
checkAgentBeforeDeleteModel(modelId, countDownActionListener);
268+
checkIngestPipelineBeforeDeleteModel(modelId, countDownActionListener);
269+
checkSearchPipelineBeforeDeleteModel(modelId, countDownActionListener);
270+
}
271+
272+
245273
private void deleteModel(String modelId, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
246274
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
247275
client.delete(deleteRequest, new ActionListener<>() {
@@ -271,7 +299,6 @@ private void checkAgentBeforeDeleteModel(String modelId, ActionListener<Boolean>
271299
actionListener.onResponse(true);
272300
} else {
273301
String errorMessage = formatAgentErrorMessage(searchHits);
274-
275302
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.CONFLICT));
276303
}
277304

@@ -333,34 +360,6 @@ private void checkPipelineBeforeDeleteModel(
333360

334361
}
335362

336-
private void checkDownstreamTaskBeforeDeleteModel(String modelId, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
337-
// Now checks 3 resources associated with with the model id 1. Agent 2. Search pipeline 3. ingest pipeline
338-
CountDownLatch countDownLatch = new CountDownLatch(3);
339-
AtomicBoolean noneBlocked = new AtomicBoolean(true);
340-
ConcurrentLinkedQueue<String> errorMessages = new ConcurrentLinkedQueue<>();
341-
ActionListener<Boolean> countDownActionListener = ActionListener.wrap(b -> {
342-
countDownLatch.countDown();
343-
noneBlocked.compareAndSet(true, b);
344-
if (countDownLatch.getCount() == 0) {
345-
if (noneBlocked.get()) {
346-
deleteModel(modelId, isHidden, actionListener);
347-
} else {
348-
actionListener.onFailure(new OpenSearchStatusException(String.join(". ", errorMessages), RestStatus.CONFLICT));
349-
}
350-
}
351-
}, e -> {
352-
countDownLatch.countDown();
353-
noneBlocked.set(false);
354-
errorMessages.add(e.getMessage());
355-
if (countDownLatch.getCount() == 0) {
356-
actionListener.onFailure(new OpenSearchStatusException(String.join(". ", errorMessages), RestStatus.CONFLICT));
357-
}
358-
359-
});
360-
checkAgentBeforeDeleteModel(modelId, countDownActionListener);
361-
checkIngestPipelineBeforeDeleteModel(modelId, countDownActionListener);
362-
checkSearchPipelineBeforeDeleteModel(modelId, countDownActionListener);
363-
}
364363

365364
private void deleteModelChunksAndController(
366365
ActionListener<DeleteResponse> actionListener,
@@ -473,21 +472,6 @@ private List<String> findDependentPipelinesEasy(Map<String, Object> allConfigMap
473472
return dependentPipelineConfigurations;
474473
}
475474

476-
private <T> List<String> findDependentPipelines(
477-
List<T> pipelineConfigurations,
478-
String candidateModelId,
479-
Function<T, Map<String, Object>> getConfigFunction,
480-
Function<T, String> getIdFunction
481-
) {
482-
List<String> dependentPipelineConfigurations = new ArrayList<>();
483-
for (T pipelineConfiguration : pipelineConfigurations) {
484-
Map<String, Object> config = getConfigFunction.apply(pipelineConfiguration);
485-
if (searchThroughConfig(config, candidateModelId)) {
486-
dependentPipelineConfigurations.add(getIdFunction.apply(pipelineConfiguration));
487-
}
488-
}
489-
return dependentPipelineConfigurations;
490-
}
491475

492476
// This method is to go through the pipeline configs and the configuration is a map of string to objects.
493477
// Objects can be a list or a map. we will search exhaustively through the configuration for any match of the candidateId.

plugin/src/test/java/org/opensearch/ml/plugin/DummyWrongTool.java

-5
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,5 @@ public String getDefaultVersion() {
106106
return null;
107107
}
108108

109-
@Override
110-
public List<String> getAllModelKeys() {
111-
return List.of();
112-
}
113-
114109
}
115110
}

spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java

-6
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,5 @@ interface Factory<T extends Tool> {
129129
* @return the default tool version
130130
*/
131131
String getDefaultVersion();
132-
133-
/**
134-
* Get model id related field names
135-
* @return the list of all model id related field names
136-
*/
137-
List<String> getAllModelKeys();
138132
}
139133
}

spi/src/main/java/org/opensearch/ml/common/spi/tools/WithoutModelTool.java spi/src/main/java/org/opensearch/ml/common/spi/tools/WithModelTool.java

+3-5
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,16 @@
1111
/**
1212
* General tool interface.
1313
*/
14-
public interface WithoutModelTool extends Tool {
14+
public interface WithModelTool extends Tool {
1515
/**
1616
* Tool factory which can create instance of {@link Tool}.
1717
* @param <T> The subclass this factory produces
1818
*/
19-
interface Factory<T extends WithoutModelTool> extends Tool.Factory<T> {
19+
interface Factory<T extends WithModelTool> extends Tool.Factory<T> {
2020
/**
2121
* Get model id related field names
2222
* @return the list of all model id related field names
2323
*/
24-
default List<String> getAllModelKeys(){
25-
return List.of();
26-
}
24+
List<String> getAllModelKeys();
2725
}
2826
}

0 commit comments

Comments
 (0)