Skip to content

Commit f82e148

Browse files
authored
add a flag to distinguish duplicate remote model auto deploy and tran… (opensearch-project#2410)
* add a flag to distinguish duplicate remote model auto deploy and transport deploy Signed-off-by: Xun Zhang <xunzh@amazon.com> * check for NPE for getIsAutoDeploying Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 21bf079 commit f82e148

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public class MLModelCache {
5454
@Setter
5555
private Boolean deployToAllNodes;
5656
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;
57+
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Boolean isAutoDeploying;
5758

5859
public MLModelCache() {
5960
targetWorkerNodes = ConcurrentHashMap.newKeySet();

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

+26-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.util.concurrent.ConcurrentHashMap;
1818
import java.util.stream.Collectors;
1919

20+
import org.apache.commons.lang3.BooleanUtils;
2021
import org.opensearch.cluster.service.ClusterService;
2122
import org.opensearch.common.settings.Settings;
2223
import org.opensearch.common.util.TokenBucket;
@@ -61,7 +62,7 @@ public synchronized void initModelState(
6162
List<String> targetWorkerNodes,
6263
boolean deployToAllNodes
6364
) {
64-
if (isModelRunningOnNode(modelId)) {
65+
if (isModelRunningOnNode(modelId) && !isAutoDeploying(modelId)) {
6566
throw new MLLimitExceededException("Duplicate deploy model task");
6667
}
6768
log.debug("init model state for model {}, state: {}", modelId, state);
@@ -74,7 +75,7 @@ public synchronized void initModelState(
7475
modelCaches.put(modelId, modelCache);
7576
}
7677

77-
public synchronized void initModelStateLocal(
78+
public synchronized void initModelStateAutoDeploy(
7879
String modelId,
7980
MLModelState state,
8081
FunctionName functionName,
@@ -92,6 +93,7 @@ public synchronized void initModelStateLocal(
9293
modelCache.setDeployToAllNodes(false);
9394
modelCache.setLastAccessTime(Instant.now());
9495
modelCaches.put(modelId, modelCache);
96+
setIsAutoDeploying(modelId, true);
9597
}
9698

9799
/**
@@ -279,6 +281,28 @@ public Boolean getIsModelEnabled(String modelId) {
279281
return modelCache.getIsModelEnabled();
280282
}
281283

284+
/**
285+
* Set a flag to show if model is in auto deploying status
286+
*
287+
* @param modelId model id
288+
* @param isModelAutoDeploying auto deploy flag
289+
*/
290+
public synchronized void setIsAutoDeploying(String modelId, Boolean isModelAutoDeploying) {
291+
log.debug("Setting the auto deploying flag for Model {}", modelId);
292+
getExistingModelCache(modelId).setIsAutoDeploying(isModelAutoDeploying);
293+
}
294+
295+
/**
296+
* Check if model is in auto deploying.
297+
*
298+
* @param modelId model id
299+
* @return true if model is auto deploying.
300+
*/
301+
public boolean isAutoDeploying(String modelId) {
302+
MLModelCache modelCache = modelCaches.get(modelId);
303+
return modelCache != null && BooleanUtils.isTrue(modelCache.getIsAutoDeploying());
304+
}
305+
282306
/**
283307
* Set memory size estimation CPU/GPU
284308
*

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -987,13 +987,14 @@ public void deployModel(
987987
if (!autoDeployModel) {
988988
modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes);
989989
} else {
990-
modelCacheHelper.initModelStateLocal(modelId, MLModelState.DEPLOYING, functionName, workerNodes);
990+
modelCacheHelper.initModelStateAutoDeploy(modelId, MLModelState.DEPLOYING, functionName, workerNodes);
991991
}
992992

993993
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
994994
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> {
995995
context.restore();
996996
modelCacheHelper.removeAutoDeployModel(modelId);
997+
modelCacheHelper.setIsAutoDeploying(modelId, false);
997998
});
998999
if (!autoDeployModel) {
9991000
checkAndAddRunningTask(mlTask, maxDeployTasksPerNode);

0 commit comments

Comments
 (0)