forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMLTaskRunner.java
138 lines (121 loc) · 5.69 KB
/
MLTaskRunner.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.task;
import static org.opensearch.ml.utils.MLNodeUtils.checkOpenCircuitBreaker;
import java.util.HashMap;
import java.util.Map;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.MLTaskRequest;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;
import com.google.common.collect.ImmutableMap;
import lombok.extern.log4j.Log4j2;
/**
 * MLTaskRunner has common code for dispatching and running predict/training tasks.
 *
 * @param <Request> ML task request
 * @param <Response> ML task response
 */
@Log4j2
public abstract class MLTaskRunner<Request extends MLTaskRequest, Response extends TransportResponse> {
    /** Timeout in milliseconds passed to {@code MLTaskManager.updateMLTask} when persisting a terminal task state. */
    public static final int TIMEOUT_IN_MILLIS = 2000;
    protected final MLTaskManager mlTaskManager;
    protected final MLStats mlStats;
    protected final DiscoveryNodeHelper nodeHelper;
    protected final MLTaskDispatcher mlTaskDispatcher;
    protected final MLCircuitBreakerService mlCircuitBreakerService;
    private final ClusterService clusterService;

    public MLTaskRunner(
        MLTaskManager mlTaskManager,
        MLStats mlStats,
        DiscoveryNodeHelper nodeHelper,
        MLTaskDispatcher mlTaskDispatcher,
        MLCircuitBreakerService mlCircuitBreakerService,
        ClusterService clusterService
    ) {
        this.mlTaskManager = mlTaskManager;
        this.mlStats = mlStats;
        this.nodeHelper = nodeHelper;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.clusterService = clusterService;
    }

    /**
     * For async tasks, records the task as {@link MLTaskState#FAILED} together with an error description.
     * Synchronous tasks are left untouched.
     *
     * @param mlTask the task that failed
     * @param e the failure cause; its message (or a fallback when it has none) is stored as the task error
     */
    protected void handleAsyncMLTaskFailure(MLTask mlTask, Exception e) {
        if (mlTask.isAsync()) {
            // ImmutableMap.of rejects null values, so a message-less exception must not be passed through raw;
            // otherwise the failure handler itself would throw and the task would never be marked FAILED.
            Map<String, Object> updatedFields = ImmutableMap
                .of(MLTask.STATE_FIELD, MLTaskState.FAILED.name(), MLTask.ERROR_FIELD, describeFailure(e));
            // the original intent: wait (up to TIMEOUT_IN_MILLIS) so the failed state is persisted
            mlTaskManager.updateMLTask(mlTask.getTaskId(), updatedFields, TIMEOUT_IN_MILLIS, true);
        }
    }

    /**
     * Returns a non-null, human-readable description of a failure suitable for the task's error field.
     * Uses the exception message when present, otherwise the exception class name.
     */
    private static String describeFailure(Exception e) {
        if (e == null) {
            return "Unknown error";
        }
        String message = e.getMessage();
        return message != null ? message : e.getClass().getName();
    }

    /**
     * For async tasks, records the task as {@link MLTaskState#COMPLETED}, including the model id when the
     * task has one. Synchronous tasks are left untouched.
     *
     * @param mlTask the task that completed
     */
    protected void handleAsyncMLTaskComplete(MLTask mlTask) {
        if (mlTask.isAsync()) {
            Map<String, Object> updatedFields = new HashMap<>();
            // NOTE(review): this stores the enum itself while the failure path stores MLTaskState.FAILED.name();
            // confirm which form MLTaskManager.updateMLTask expects before unifying them.
            updatedFields.put(MLTask.STATE_FIELD, MLTaskState.COMPLETED);
            if (mlTask.getModelId() != null) {
                updatedFields.put(MLTask.MODEL_ID_FIELD, mlTask.getModelId());
            }
            // the original intent: wait (up to TIMEOUT_IN_MILLIS) so the completed state is persisted
            mlTaskManager.updateMLTask(mlTask.getTaskId(), updatedFields, TIMEOUT_IN_MILLIS, true);
        }
    }

    /**
     * Runs the request locally when it is not marked for dispatch; otherwise dispatches it to a node
     * chosen by {@link MLTaskDispatcher}.
     *
     * @param functionName the ML function; {@link FunctionName#REMOTE} skips the circuit-breaker check
     * @param request the task request
     * @param transportService used to forward the request when the chosen node is not the local node
     * @param listener receives the task response or failure
     */
    public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
        if (!request.isDispatchTask()) {
            log.debug("Run ML request {} locally", request.getRequestID());
            checkCBAndExecute(functionName, request, listener);
            return;
        }
        dispatchTask(functionName, request, transportService, listener);
    }

    /**
     * Wraps a listener so that, after it is notified (success or failure), the executing-task counter is
     * decremented and the task is removed from the task manager.
     *
     * @param listener the listener to wrap
     * @param taskId id of the task to remove once the listener has run
     * @return the wrapped listener
     */
    protected ActionListener<MLTaskResponse> wrappedCleanupListener(ActionListener<MLTaskResponse> listener, String taskId) {
        return ActionListener.runAfter(listener, () -> {
            mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
            mlTaskManager.remove(taskId);
        });
    }

    /**
     * Asks the dispatcher for a target node and executes the request there: in-process when the target is
     * the local node, otherwise over the transport layer.
     *
     * @param functionName the ML function used for node selection and circuit-breaker policy
     * @param request the task request; its dispatch flag is cleared before forwarding to a remote node
     * @param transportService used to forward the request to a remote node
     * @param listener receives the task response or failure (including dispatch failures)
     */
    public void dispatchTask(
        FunctionName functionName,
        Request request,
        TransportService transportService,
        ActionListener<Response> listener
    ) {
        mlTaskDispatcher.dispatch(functionName, ActionListener.wrap(node -> {
            String nodeId = node.getId();
            if (clusterService.localNode().getId().equals(nodeId)) {
                // Execute ML task locally; go through checkCBAndExecute so REMOTE functions get the same
                // circuit-breaker exemption here as they do in run().
                log.debug("Execute ML request {} locally on node {}", request.getRequestID(), nodeId);
                checkCBAndExecute(functionName, request, listener);
            } else {
                // Execute ML task remotely; clear the dispatch flag so the receiving node runs it instead of re-dispatching.
                log.debug("Execute ML request {} remotely on node {}", request.getRequestID(), nodeId);
                request.setDispatchTask(false);
                transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener));
            }
        }, listener::onFailure));
    }

    /** Returns the transport action name used to forward requests to a remote node. */
    protected abstract String getTransportActionName();

    /** Returns the handler that relays a remote node's response to {@code listener}. */
    protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);

    /** Executes the task on the current node and reports the outcome to {@code listener}. */
    protected abstract void executeTask(Request request, ActionListener<Response> listener);

    /**
     * Checks the open circuit breaker (except for {@link FunctionName#REMOTE}, which is exempt) and then
     * executes the task on the current node.
     *
     * @param functionName the ML function
     * @param request the task request
     * @param listener receives the task response or failure
     */
    protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
        if (functionName != FunctionName.REMOTE) {
            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
        }
        executeTask(request, listener);
    }
}