Skip to content

Commit b4adbc7

Browse files
committed
remove and add comment
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 7cc0a08 commit b4adbc7

File tree

3 files changed

+52
-28
lines changed

3 files changed

+52
-28
lines changed

plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java

+47-27
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
1717
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;
1818

19+
import java.util.ArrayDeque;
1920
import java.util.ArrayList;
2021
import java.util.Arrays;
22+
import java.util.Deque;
2123
import java.util.List;
2224
import java.util.Map;
2325
import java.util.Objects;
2426
import java.util.concurrent.CountDownLatch;
2527
import java.util.concurrent.atomic.AtomicBoolean;
2628
import java.util.function.Function;
2729

30+
import org.apache.commons.lang3.tuple.Pair;
2831
import org.opensearch.OpenSearchStatusException;
2932
import org.opensearch.ResourceNotFoundException;
3033
import org.opensearch.action.ActionRequest;
@@ -290,21 +293,21 @@ private void checkAgentBeforeDeleteModel(String modelId, ActionListener<Boolean>
290293
private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
291294
GetPipelineRequest getPipelineRequest = new GetPipelineRequest();
292295
client.execute(GetPipelineAction.INSTANCE, getPipelineRequest, ActionListener.wrap(ingestPipelineResponse -> {
293-
List<String> allRelevantPipelineIds = findRelevantPipelines(
296+
List<String> allDependentPipelineIds = findDependentPipelines(
294297
ingestPipelineResponse.pipelines(),
295298
modelId,
296299
org.opensearch.ingest.PipelineConfiguration::getConfigAsMap,
297300
org.opensearch.ingest.PipelineConfiguration::getId
298301
);
299-
if (allRelevantPipelineIds.isEmpty()) {
302+
if (allDependentPipelineIds.isEmpty()) {
300303
actionListener.onResponse(true);
301304
} else {
302305
actionListener
303306
.onFailure(
304307
new OpenSearchStatusException(
305-
allRelevantPipelineIds.size()
308+
allDependentPipelineIds.size()
306309
+ " ingest pipelines are still using this model, please delete or update the pipelines first: "
307-
+ Arrays.toString(allRelevantPipelineIds.toArray(new String[0])),
310+
+ Arrays.toString(allDependentPipelineIds.toArray(new String[0])),
308311
RestStatus.CONFLICT
309312
)
310313
);
@@ -320,21 +323,21 @@ private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener
320323
private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
321324
GetSearchPipelineRequest getSearchPipelineRequest = new GetSearchPipelineRequest();
322325
client.execute(GetSearchPipelineAction.INSTANCE, getSearchPipelineRequest, ActionListener.wrap(searchPipelineResponse -> {
323-
List<String> allRelevantPipelineIds = findRelevantPipelines(
326+
List<String> allDependentPipelineIds = findDependentPipelines(
324327
searchPipelineResponse.pipelines(),
325328
modelId,
326329
org.opensearch.search.pipeline.PipelineConfiguration::getConfigAsMap,
327330
org.opensearch.search.pipeline.PipelineConfiguration::getId
328331
);
329-
if (allRelevantPipelineIds.isEmpty()) {
332+
if (allDependentPipelineIds.isEmpty()) {
330333
actionListener.onResponse(true);
331334
} else {
332335
actionListener
333336
.onFailure(
334337
new OpenSearchStatusException(
335-
allRelevantPipelineIds.size()
338+
allDependentPipelineIds.size()
336339
+ " search pipelines are still using this model, please delete or update the pipelines first: "
337-
+ Arrays.toString(allRelevantPipelineIds.toArray(new String[0])),
340+
+ Arrays.toString(allDependentPipelineIds.toArray(new String[0])),
338341
RestStatus.CONFLICT
339342
)
340343
);
@@ -475,40 +478,57 @@ private Boolean isModelNotDeployed(MLModelState mlModelState) {
475478
&& !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED);
476479
}
477480

478-
private <T> List<String> findRelevantPipelines(
481+
private <T> List<String> findDependentPipelines(
479482
List<T> pipelineConfigurations,
480483
String candidateModelId,
481484
Function<T, Map<String, Object>> getConfigFunction,
482485
Function<T, String> getIdFunction
483486
) {
484-
List<String> relevantPipelineConfigurations = new ArrayList<>();
487+
List<String> dependentPipelineConfigurations = new ArrayList<>();
485488
for (T pipelineConfiguration : pipelineConfigurations) {
486489
Map<String, Object> config = getConfigFunction.apply(pipelineConfiguration);
487490
if (searchThroughConfig(config, candidateModelId, "")) {
488-
relevantPipelineConfigurations.add(getIdFunction.apply(pipelineConfiguration));
491+
dependentPipelineConfigurations.add(getIdFunction.apply(pipelineConfiguration));
489492
}
490493
}
491-
return relevantPipelineConfigurations;
494+
return dependentPipelineConfigurations;
492495
}
493496

497+
// This method is to go through the pipeline configs and only when the key is model id and value is
498+
// 1. String and equal to candidate id 2. A list of String containing candidate id We will return True. Otherwise False
494499
private Boolean searchThroughConfig(Object searchCandidate, String candidateId, String targetModelKey) {
495-
boolean flag = false;
496-
if (searchCandidate instanceof String
497-
&& Objects.equals(targetModelKey, PIPELINE_TARGET_MODEL_KEY)
498-
&& Objects.equals(candidateId, searchCandidate)) {
499-
return true;
500-
} else if (searchCandidate instanceof List<?>) {
501-
for (Object v : (List<?>) searchCandidate) {
502-
flag = flag || searchThroughConfig(v, candidateId, targetModelKey);
503-
}
504-
} else if (searchCandidate instanceof Map<?, ?>) {
505-
for (Map.Entry<String, Object> entry : ((Map<String, Object>) searchCandidate).entrySet()) {
506-
String key = entry.getKey();
507-
Object value = entry.getValue();
508-
flag = flag || searchThroughConfig(value, candidateId, key);
500+
// Use a stack to store the elements to be processed
501+
Deque<Pair<String, Object>> stack = new ArrayDeque<>();
502+
stack.push(Pair.of(targetModelKey, searchCandidate));
503+
504+
while (!stack.isEmpty()) {
505+
// Pop an item from the stack
506+
Pair<String, Object> current = stack.pop();
507+
String currentKey = current.getLeft();
508+
Object currentCandidate = current.getRight();
509+
510+
if (currentCandidate instanceof String) {
511+
// Check for a match
512+
if (Objects.equals(currentKey, PIPELINE_TARGET_MODEL_KEY) && Objects.equals(candidateId, currentCandidate)) {
513+
return true;
514+
}
515+
} else if (currentCandidate instanceof List<?>) {
516+
// Push all elements in the list onto the stack
517+
for (Object v : (List<?>) currentCandidate) {
518+
stack.push(Pair.of(currentKey, v));
519+
}
520+
} else if (currentCandidate instanceof Map<?, ?>) {
521+
// Push all values in the map onto the stack
522+
for (Map.Entry<?, ?> entry : ((Map<?, ?>) currentCandidate).entrySet()) {
523+
String key = (String) entry.getKey();
524+
Object value = entry.getValue();
525+
stack.push(Pair.of(key, value));
526+
}
509527
}
510528
}
511-
return flag;
529+
530+
// If no match is found
531+
return false;
512532
}
513533

514534
// this method is only to stub static method.

plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ private void prepare() throws IOException {
749749
return null;
750750
}).when(client).execute(eq(GetSearchPipelineAction.INSTANCE), any(), any());
751751
configDataMap = Map
752-
.of("model_id", "test_id", "list_model_id", List.of("test_list_id"), "test_map_id", Map.of("test_key", "test_map_id"));
752+
.of("single_model_id", "test_id", "list_model_id", List.of("test_id"), "test_map_id", Map.of("model_id", "test_id"));
753753
doAnswer(invocation -> new SearchRequest()).when(agentModelsSearcher).constructQueryRequest(any());
754754

755755
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);

spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java

+4
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ interface Factory<T extends Tool> {
130130
*/
131131
String getDefaultVersion();
132132

133+
/**
134+
* Get model id related field names
135+
* @return the list of all model id related field names
136+
*/
133137
List<String> getAllModelKeys();
134138
}
135139
}

0 commit comments

Comments
 (0)