Skip to content

Commit d0695a0

Browse files
set the mlModel info in the deploy stage (opensearch-project#2389) (opensearch-project#2390)
Signed-off-by: Xun Zhang <xunzh@amazon.com> (cherry picked from commit ff6048f) Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent 77a644c commit d0695a0

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,11 @@ public String[] getLocalDeployedModels() {
386386
*/
387387
public String[] getExpiredModels() {
388388
return modelCaches.entrySet().stream().filter(entry -> {
389-
MLModel mlModel = entry.getValue().getCachedModelInfo();
390-
if (mlModel.getDeploySetting() == null) {
391-
return false; // no TTL, never expire
389+
MLModelCache modelCache = entry.getValue();
390+
MLModel mlModel = modelCache.getCachedModelInfo();
391+
MLModelState modelState = modelCache.getModelState();
392+
if (mlModel == null || mlModel.getDeploySetting() == null) {
393+
return false; // no TTL, never expire
392394
}
393395
Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now());
394396
Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
@@ -397,8 +399,7 @@ public String[] getExpiredModels() {
397399
}
398400
Duration ttl = Duration.ofMinutes(ttlInMinutes);
399401
boolean isModelExpired = liveDuration.getSeconds() >= ttl.getSeconds();
400-
return isModelExpired
401-
&& (mlModel.getModelState() == MLModelState.DEPLOYED || mlModel.getModelState() == MLModelState.PARTIALLY_DEPLOYED);
402+
return isModelExpired && (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED);
402403
}).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]);
403404
}
404405

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+1
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,7 @@ public void deployModel(
996996
}
997997
this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> {
998998
modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled());
999+
modelCacheHelper.setModelInfo(modelId, mlModel);
9991000
if (FunctionName.REMOTE == mlModel.getAlgorithm()
10001001
|| (!FunctionName.isDLModel(mlModel.getAlgorithm()) && mlModel.getAlgorithm() != FunctionName.METRICS_CORRELATION)) {
10011002
// deploy remote model or model trained by built-in algorithm like kmeans

0 commit comments

Comments
 (0)