Skip to content

Commit 851d49e

Browse files
zhanghg08jackiehanyang
authored andcommitted
Add MLTaskrunner and prediction/training transportaction (opensearch-project#37)
* Add MLTaskRunner * Add MLPredictionTaskRemoteExecutionTransportAction * Add MLPredictionTaskRemoteExecutionAction * Add missing HEADER * Add Training and Prediction Transportaction * Register transportaction in plugin
1 parent 6d6ae29 commit 851d49e

9 files changed

+674
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.action.prediction;
14+
15+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse;
16+
import org.opensearch.action.ActionType;
17+
18+
public class MLPredictionTaskExecutionAction extends ActionType<MLPredictionTaskResponse> {
19+
public static MLPredictionTaskExecutionAction INSTANCE = new MLPredictionTaskExecutionAction();
20+
public static final String NAME = "cluster:admin/opensearch-ml/prediction/execution";
21+
22+
private MLPredictionTaskExecutionAction() {
23+
super(NAME, MLPredictionTaskResponse::new);
24+
}
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.action.prediction;
14+
15+
import org.opensearch.ml.task.MLTaskRunner;
16+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
17+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse;
18+
import org.opensearch.action.ActionListener;
19+
import org.opensearch.action.support.ActionFilters;
20+
import org.opensearch.action.support.HandledTransportAction;
21+
import org.opensearch.common.inject.Inject;
22+
import org.opensearch.tasks.Task;
23+
import org.opensearch.transport.TransportService;
24+
25+
public class MLPredictionTaskExecutionTransportAction extends HandledTransportAction<MLPredictionTaskRequest, MLPredictionTaskResponse> {
26+
private final MLTaskRunner mlTaskRunner;
27+
private final TransportService transportService;
28+
29+
@Inject
30+
public MLPredictionTaskExecutionTransportAction(
31+
ActionFilters actionFilters,
32+
TransportService transportService,
33+
MLTaskRunner mlTaskRunner
34+
) {
35+
super(MLPredictionTaskExecutionAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
36+
this.mlTaskRunner = mlTaskRunner;
37+
this.transportService = transportService;
38+
}
39+
40+
@Override
41+
protected void doExecute(Task task, MLPredictionTaskRequest request, ActionListener<MLPredictionTaskResponse> listener) {
42+
mlTaskRunner.startPredictionTask(request, listener);
43+
}
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.action.prediction;
14+
15+
import lombok.AccessLevel;
16+
import lombok.experimental.FieldDefaults;
17+
import lombok.extern.log4j.Log4j2;
18+
import org.opensearch.action.ActionListener;
19+
import org.opensearch.action.ActionRequest;
20+
import org.opensearch.action.support.ActionFilters;
21+
import org.opensearch.action.support.HandledTransportAction;
22+
import org.opensearch.common.inject.Inject;
23+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
24+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
25+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse;
26+
import org.opensearch.ml.task.MLTaskRunner;
27+
import org.opensearch.tasks.Task;
28+
import org.opensearch.transport.TransportService;
29+
30+
@Log4j2
31+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
32+
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLPredictionTaskResponse> {
33+
MLTaskRunner mlTaskRunner;
34+
TransportService transportService;
35+
36+
@Inject
37+
public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLTaskRunner mlTaskRunner) {
38+
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
39+
this.mlTaskRunner = mlTaskRunner;
40+
this.transportService = transportService;
41+
}
42+
43+
@Override
44+
protected void doExecute(Task task, ActionRequest request,
45+
ActionListener<MLPredictionTaskResponse> listener) {
46+
MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request);
47+
mlTaskRunner.runPrediction(mlPredictionTaskRequest, transportService, listener);
48+
}
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.action.training;
14+
15+
import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse;
16+
import org.opensearch.action.ActionType;
17+
18+
public class MLTrainingTaskExecutionAction extends ActionType<MLTrainingTaskResponse> {
19+
20+
public static final MLTrainingTaskExecutionAction INSTANCE = new MLTrainingTaskExecutionAction();
21+
public static final String NAME = "cluster:admin/opensearch-ml/training/execution";
22+
23+
public MLTrainingTaskExecutionAction() {
24+
super(NAME, MLTrainingTaskResponse::new);
25+
}
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.action.training;
14+
15+
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
16+
import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse;
17+
import lombok.AccessLevel;
18+
import lombok.experimental.FieldDefaults;
19+
import lombok.extern.log4j.Log4j2;
20+
import org.opensearch.action.ActionListener;
21+
import org.opensearch.action.support.ActionFilters;
22+
import org.opensearch.action.support.HandledTransportAction;
23+
import org.opensearch.common.inject.Inject;
24+
import org.opensearch.ml.task.MLTaskRunner;
25+
import org.opensearch.tasks.Task;
26+
import org.opensearch.transport.TransportService;
27+
28+
@Log4j2
29+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
30+
public class MLTrainingTaskExecutionTransportAction extends HandledTransportAction<MLTrainingTaskRequest, MLTrainingTaskResponse> {
31+
MLTaskRunner mlTaskRunner;
32+
33+
@Inject
34+
public MLTrainingTaskExecutionTransportAction(TransportService transportService, ActionFilters actionFilters,
35+
MLTaskRunner mlTaskRunner) {
36+
super(MLTrainingTaskExecutionAction.NAME, transportService, actionFilters, MLTrainingTaskRequest::new);
37+
this.mlTaskRunner = mlTaskRunner;
38+
}
39+
40+
@Override
41+
protected void doExecute(Task task, MLTrainingTaskRequest request, ActionListener<MLTrainingTaskResponse> listener) {
42+
mlTaskRunner.startTrainingTask(request, listener);
43+
}
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.action.training;
14+
15+
import lombok.extern.log4j.Log4j2;
16+
import org.opensearch.action.ActionListener;
17+
import org.opensearch.action.ActionRequest;
18+
import org.opensearch.action.support.ActionFilters;
19+
import org.opensearch.action.support.HandledTransportAction;
20+
import org.opensearch.common.inject.Inject;
21+
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
22+
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
23+
import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse;
24+
import org.opensearch.ml.task.MLTaskRunner;
25+
import org.opensearch.tasks.Task;
26+
import org.opensearch.transport.TransportService;
27+
28+
@Log4j2
29+
public class TransportTrainingTaskAction extends HandledTransportAction<ActionRequest, MLTrainingTaskResponse> {
30+
31+
MLTaskRunner mlTaskRunner;
32+
TransportService transportService;
33+
34+
@Inject
35+
public TransportTrainingTaskAction(TransportService transportService, ActionFilters actionFilters,
36+
MLTaskRunner mlTaskRunner) {
37+
super(MLTrainingTaskAction.NAME, transportService, actionFilters, MLTrainingTaskRequest::new);
38+
this.mlTaskRunner = mlTaskRunner;
39+
this.transportService = transportService;
40+
}
41+
42+
@Override
43+
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTrainingTaskResponse> listener) {
44+
MLTrainingTaskRequest trainingRequest = MLTrainingTaskRequest.fromActionRequest(request);
45+
mlTaskRunner.runTraining(trainingRequest, transportService, listener);
46+
}
47+
}

plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import lombok.RequiredArgsConstructor;
1717
import lombok.experimental.FieldDefaults;
1818
import lombok.extern.log4j.Log4j2;
19-
import lombok.val;
20-
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
2119
import org.opensearch.client.Client;
2220
import org.opensearch.cluster.service.ClusterService;
2321
import org.opensearch.common.xcontent.XContentType;
@@ -31,7 +29,9 @@ public class MLIndicesHandler {
3129
" \"properties\": {\n" +
3230
" \"taskId\": { \"type\": \"keyword\" },\n" +
3331
" \"algorithm\": {\"type\": \"keyword\"},\n" +
34-
" \"model\" : { \"type\": \"binary\"}\n" +
32+
" \"modelName\" : { \"type\": \"keyword\"},\n" +
33+
" \"modelVersion\" : { \"type\": \"keyword\"},\n" +
34+
" \"modelContent\" : { \"type\": \"binary\"}\n" +
3535
" }\n" +
3636
"}";
3737

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+32-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,16 @@
2828
import org.opensearch.common.xcontent.NamedXContentRegistry;
2929
import org.opensearch.env.Environment;
3030
import org.opensearch.env.NodeEnvironment;
31+
import org.opensearch.ml.action.prediction.MLPredictionTaskExecutionAction;
32+
import org.opensearch.ml.action.prediction.MLPredictionTaskExecutionTransportAction;
33+
import org.opensearch.ml.action.prediction.TransportPredictionTaskAction;
3134
import org.opensearch.ml.action.stats.MLStatsNodesAction;
3235
import org.opensearch.ml.action.stats.MLStatsNodesTransportAction;
36+
import org.opensearch.ml.action.training.MLTrainingTaskExecutionAction;
37+
import org.opensearch.ml.action.training.MLTrainingTaskExecutionTransportAction;
38+
import org.opensearch.ml.action.training.TransportTrainingTaskAction;
39+
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
40+
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
3341
import org.opensearch.ml.rest.RestStatsMLAction;
3442
import org.opensearch.ml.stats.MLStat;
3543
import org.opensearch.ml.stats.MLStats;
@@ -44,15 +52,19 @@
4452
import org.opensearch.rest.RestController;
4553
import org.opensearch.rest.RestHandler;
4654
import org.opensearch.script.ScriptService;
55+
import org.opensearch.threadpool.ExecutorBuilder;
56+
import org.opensearch.threadpool.FixedExecutorBuilder;
4757
import org.opensearch.threadpool.ThreadPool;
4858
import org.opensearch.watcher.ResourceWatcherService;
4959

5060
import java.util.Collection;
61+
import java.util.Collections;
5162
import java.util.List;
5263
import java.util.Map;
5364
import java.util.function.Supplier;
5465

5566
public class MachineLearningPlugin extends Plugin implements ActionPlugin {
67+
public static final String TASK_THREAD_POOL = "OPENSEARCH_ML_TASK_THREAD_POOL";
5668
public static final String ML_BASE_URI = "/_opensearch/_ml";
5769

5870
private MLStats mlStats;
@@ -72,7 +84,12 @@ public Setting<Boolean> legacySetting() {
7284
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
7385
return ImmutableList.of(
7486
new ActionHandler<>(MLStatsNodesAction.INSTANCE,
75-
MLStatsNodesTransportAction.class)
87+
MLStatsNodesTransportAction.class),
88+
new ActionHandler<>(MLPredictionTaskAction.INSTANCE, TransportPredictionTaskAction.class),
89+
new ActionHandler<>(MLTrainingTaskAction.INSTANCE, TransportTrainingTaskAction.class),
90+
new ActionHandler<>(MLPredictionTaskExecutionAction.INSTANCE,
91+
MLPredictionTaskExecutionTransportAction.class),
92+
new ActionHandler<>(MLTrainingTaskExecutionAction.INSTANCE, MLTrainingTaskExecutionTransportAction.class)
7693
);
7794
}
7895

@@ -109,4 +126,18 @@ public List<RestHandler> getRestHandlers(
109126
restStatsMLAction
110127
);
111128
}
129+
130+
@Override
131+
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
132+
FixedExecutorBuilder ml = new FixedExecutorBuilder(
133+
settings,
134+
TASK_THREAD_POOL,
135+
4,
136+
4,
137+
"ml.task_thread_pool",
138+
false
139+
);
140+
141+
return Collections.singletonList(ml);
142+
}
112143
}

0 commit comments

Comments
 (0)