Skip to content

Commit 570edaf

Browse files
authored
Check before delete (#3209)
* add logic to detect agent before deleting Signed-off-by: xinyual <xinyual@amazon.com> * add logic to detect agent before deleting Signed-off-by: xinyual <xinyual@amazon.com> * add logic to detect pipelines before delete model Signed-off-by: xinyual <xinyual@amazon.com> * check pipeline before deleting Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * remove useless file Signed-off-by: xinyual <xinyual@amazon.com> * rename functions Signed-off-by: xinyual <xinyual@amazon.com> * fix failure test Signed-off-by: xinyual <xinyual@amazon.com> * add UT Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * renam Signed-off-by: xinyual <xinyual@amazon.com> * refactor to parallel check Signed-off-by: xinyual <xinyual@amazon.com> * concate error message Signed-off-by: xinyual <xinyual@amazon.com> * move logic after user access check Signed-off-by: xinyual <xinyual@amazon.com> * change agent model searcher map to set Signed-off-by: xinyual <xinyual@amazon.com> * rename and remove useless method Signed-off-by: xinyual <xinyual@amazon.com> * fix bug to fetch all pipelines Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * remove and add comment Signed-off-by: xinyual <xinyual@amazon.com> * rename and add more UTs Signed-off-by: xinyual <xinyual@amazon.com> * use correct key Signed-off-by: xinyual <xinyual@amazon.com> * simplify function Signed-off-by: xinyual <xinyual@amazon.com> * change to a better class Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * change compareAndSet to set Signed-off-by: xinyual <xinyual@amazon.com> * apply comment Signed-off-by: xinyual <xinyual@amazon.com> * change name and reformat logic Signed-off-by: xinyual <xinyual@amazon.com> * change name Signed-off-by: xinyual <xinyual@amazon.com> * remove useless line Signed-off-by: xinyual <xinyual@amazon.com> * change to a better method Signed-off-by: xinyual <xinyual@amazon.com> * change name Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * add java doc for function Signed-off-by: xinyual <xinyual@amazon.com> * add another interface Signed-off-by: xinyual <xinyual@amazon.com> * apply java spotless Signed-off-by: xinyual <xinyual@amazon.com> * change interface to with model Signed-off-by: xinyual <xinyual@amazon.com> * apply spot less Signed-off-by: xinyual <xinyual@amazon.com> * add settings Signed-off-by: xinyual <xinyual@amazon.com> * apply spot less Signed-off-by: xinyual <xinyual@amazon.com> * add test for cluster setting Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * recover useless change Signed-off-by: xinyual <xinyual@amazon.com> * change default value of cluster setting Signed-off-by: xinyual <xinyual@amazon.com> * rename setting and add comment Signed-off-by: xinyual <xinyual@amazon.com> * apply spot Signed-off-by: xinyual <xinyual@amazon.com> * remove logic for hidden model Signed-off-by: xinyual <xinyual@amazon.com> * reorder code Signed-off-by: xinyual <xinyual@amazon.com> * reorder code Signed-off-by: xinyual <xinyual@amazon.com> * reorder code Signed-off-by: xinyual <xinyual@amazon.com> * apply spot Signed-off-by: xinyual <xinyual@amazon.com> * add UT Signed-off-by: xinyual <xinyual@amazon.com> * add more UT Signed-off-by: xinyual <xinyual@amazon.com> * remove search for hidden agent Signed-off-by: xinyual <xinyual@amazon.com> * fix logic and apply spot Signed-off-by: xinyual <xinyual@amazon.com> * add exist for UT Signed-off-by: xinyual <xinyual@amazon.com> * change dsl to query index Signed-off-by: xinyual <xinyual@amazon.com> * change query logic Signed-off-by: xinyual <xinyual@amazon.com> * remove useless ut Signed-off-by: xinyual <xinyual@amazon.com> * rebert Signed-off-by: xinyual <xinyual@amazon.com> * apply spot Signed-off-by: xinyual <xinyual@amazon.com> * rechange code Signed-off-by: xinyual <xinyual@amazon.com> * apply spot Signed-off-by: xinyual <xinyual@amazon.com> * remove useless should Signed-off-by: xinyual <xinyual@amazon.com> * apply spot Signed-off-by: xinyual <xinyual@amazon.com> * fix final dsl logic and ut Signed-off-by: xinyual <xinyual@amazon.com> --------- Signed-off-by: xinyual <xinyual@amazon.com>
1 parent af96fe0 commit 570edaf

File tree

10 files changed

+825
-47
lines changed

10 files changed

+825
-47
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public class CommonValue {
4545
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
4646
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
4747
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
48+
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";
4849

4950
// Index mapping paths
5051
public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml_model_group.json";

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import org.opensearch.ml.common.output.model.ModelTensorOutput;
1818
import org.opensearch.ml.common.output.model.ModelTensors;
1919
import org.opensearch.ml.common.spi.tools.Parser;
20-
import org.opensearch.ml.common.spi.tools.Tool;
2120
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
21+
import org.opensearch.ml.common.spi.tools.WithModelTool;
2222
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
2323
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
2424
import org.opensearch.ml.common.utils.StringUtils;
@@ -33,7 +33,7 @@
3333
*/
3434
@Log4j2
3535
@ToolAnnotation(MLModelTool.TYPE)
36-
public class MLModelTool implements Tool {
36+
public class MLModelTool implements WithModelTool {
3737
public static final String TYPE = "MLModelTool";
3838
public static final String RESPONSE_FIELD = "response_field";
3939
public static final String MODEL_ID_FIELD = "model_id";
@@ -127,7 +127,7 @@ public boolean validate(Map<String, String> parameters) {
127127
return true;
128128
}
129129

130-
public static class Factory implements Tool.Factory<MLModelTool> {
130+
public static class Factory implements WithModelTool.Factory<MLModelTool> {
131131
private Client client;
132132

133133
private static Factory INSTANCE;
@@ -172,5 +172,10 @@ public String getDefaultType() {
172172
public String getDefaultVersion() {
173173
return null;
174174
}
175+
176+
@Override
177+
public List<String> getAllModelKeys() {
178+
return List.of(MODEL_ID_FIELD);
179+
}
175180
}
176181
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package org.opensearch.ml.engine.utils;
2+
3+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
4+
import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX;
5+
6+
import java.util.HashSet;
7+
import java.util.Map;
8+
import java.util.Set;
9+
10+
import org.opensearch.action.search.SearchRequest;
11+
import org.opensearch.index.query.BoolQueryBuilder;
12+
import org.opensearch.index.query.QueryBuilders;
13+
import org.opensearch.ml.common.agent.MLAgent;
14+
import org.opensearch.ml.common.spi.tools.Tool;
15+
import org.opensearch.ml.common.spi.tools.WithModelTool;
16+
import org.opensearch.search.builder.SearchSourceBuilder;
17+
18+
public class AgentModelsSearcher {
19+
private final Set<String> relatedModelIdSet;
20+
21+
public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) {
22+
relatedModelIdSet = new HashSet<>();
23+
for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) {
24+
Tool.Factory toolFactory = entry.getValue();
25+
if (toolFactory instanceof WithModelTool.Factory) {
26+
WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory;
27+
relatedModelIdSet.addAll(withModelTool.getAllModelKeys());
28+
}
29+
}
30+
}
31+
32+
/**
33+
* Construct a should query to search all agent which containing candidate model Id
34+
35+
@param candidateModelId the candidate model Id
36+
@return a should search request towards agent index.
37+
*/
38+
public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) {
39+
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
40+
// Two conditions here
41+
// 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and
42+
// 2. Any model field contains candidate ID
43+
BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery();
44+
45+
BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery();
46+
// not exist hidden
47+
hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
48+
// exist but equal to false
49+
BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery();
50+
existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false));
51+
existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD));
52+
hiddenFieldQuery.should(existHiddenFieldQuery);
53+
54+
//
55+
BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery();
56+
for (String keyField : relatedModelIdSet) {
57+
modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId));
58+
}
59+
60+
searchAgentQuery.must(hiddenFieldQuery);
61+
searchAgentQuery.must(modelIdQuery);
62+
searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery));
63+
return searchRequest;
64+
}
65+
66+
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.mockito.Mockito.doAnswer;
1515
import static org.mockito.Mockito.verify;
1616
import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION;
17+
import static org.opensearch.ml.engine.tools.MLModelTool.MODEL_ID_FIELD;
1718

1819
import java.util.Arrays;
1920
import java.util.Collections;
@@ -218,5 +219,6 @@ public void testTool() {
218219
assertTrue(tool.validate(otherParams));
219220
assertFalse(tool.validate(emptyParams));
220221
assertEquals(DEFAULT_DESCRIPTION, tool.getDescription());
222+
assertEquals(List.of(MODEL_ID_FIELD), MLModelTool.Factory.getInstance().getAllModelKeys());
221223
}
222224
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.utils;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertTrue;
10+
import static org.mockito.Mockito.mock;
11+
import static org.mockito.Mockito.when;
12+
13+
import java.util.Arrays;
14+
import java.util.Collections;
15+
import java.util.HashMap;
16+
import java.util.Map;
17+
18+
import org.junit.Test;
19+
import org.opensearch.action.search.SearchRequest;
20+
import org.opensearch.index.query.BoolQueryBuilder;
21+
import org.opensearch.index.query.ExistsQueryBuilder;
22+
import org.opensearch.index.query.QueryBuilder;
23+
import org.opensearch.index.query.TermsQueryBuilder;
24+
import org.opensearch.ml.common.agent.MLAgent;
25+
import org.opensearch.ml.common.spi.tools.Tool;
26+
import org.opensearch.ml.common.spi.tools.WithModelTool;
27+
28+
public class AgentModelSearcherTests {
29+
30+
@Test
31+
public void testConstructor_CollectsModelIds() {
32+
// Arrange
33+
WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class);
34+
when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2"));
35+
36+
WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class);
37+
when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey"));
38+
39+
// This tool factory does not implement WithModelTool.Factory
40+
Tool.Factory regularToolFactory = mock(Tool.Factory.class);
41+
42+
Map<String, Tool.Factory> toolFactories = new HashMap<>();
43+
toolFactories.put("withModelTool1", withModelToolFactory1);
44+
toolFactories.put("withModelTool2", withModelToolFactory2);
45+
toolFactories.put("regularTool", regularToolFactory);
46+
47+
// Act
48+
AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories);
49+
50+
// (Optional) We can't directly access relatedModelIdSet,
51+
// but we can test the behavior indirectly using the search call:
52+
SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId");
53+
54+
// Assert
55+
// Verify the searchRequest uses all keys from the WithModelTool factories
56+
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query();
57+
// We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses
58+
assertEquals(2, boolQueryBuilder.must().size());
59+
for (QueryBuilder query : boolQueryBuilder.must()) {
60+
BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query;
61+
assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3);
62+
if (subBoolQueryBuilder.should().size() == 3) {
63+
boolQueryBuilder.should().forEach(subQuery -> {
64+
assertTrue(subQuery instanceof TermsQueryBuilder);
65+
TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery;
66+
// Each TermsQueryBuilder should contain candidateModelId
67+
assertTrue(termsQuery.values().contains("candidateId"));
68+
});
69+
} else {
70+
boolQueryBuilder.should().forEach(subQuery -> {
71+
assertTrue(subQuery instanceof BoolQueryBuilder);
72+
BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery;
73+
assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1);
74+
if (boolQuery.must().size() == 2) {
75+
boolQuery.must().forEach(existSubQuery -> {
76+
assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder);
77+
if (existSubQuery instanceof TermsQueryBuilder) {
78+
TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery;
79+
assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
80+
assertTrue(termsQuery.values().contains(false));
81+
} else {
82+
ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery;
83+
assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
84+
}
85+
});
86+
} else {
87+
QueryBuilder mustNotQuery = boolQuery.mustNot().get(0);
88+
assertTrue(mustNotQuery instanceof ExistsQueryBuilder);
89+
assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName());
90+
}
91+
});
92+
}
93+
}
94+
95+
}
96+
}

0 commit comments

Comments
 (0)