Skip to content

Commit 1659a60

Browse files
authored
applying multi-tenancy to task apis, deploy, predict apis (opensearch-project#3416)
* applying multi-tenancy to task, deploy, predict Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent f63b961 commit 1659a60

File tree

54 files changed

+1169
-465
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1169
-465
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+36-3
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,18 @@ default ActionFuture<MLTask> getTask(String taskId) {
173173
* @param taskId id of the model
174174
* @param listener action listener
175175
*/
176-
void getTask(String taskId, ActionListener<MLTask> listener);
176+
default void getTask(String taskId, ActionListener<MLTask> listener) {
177+
getTask(taskId, null, listener);
178+
}
179+
180+
/**
181+
* Get MLTask and return task in listener
182+
* For more info on get task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-task-information
183+
* @param taskId id of the model
184+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
185+
* @param listener action listener
186+
*/
187+
void getTask(String taskId, String tenantId, ActionListener<MLTask> listener);
177188

178189
/**
179190
* Delete the model with modelId.
@@ -224,7 +235,18 @@ default ActionFuture<DeleteResponse> deleteTask(String taskId) {
224235
* @param taskId id of the task
225236
* @param listener action listener
226237
*/
227-
void deleteTask(String taskId, ActionListener<DeleteResponse> listener);
238+
default void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
239+
deleteTask(taskId, null, listener);
240+
}
241+
242+
/**
243+
* Delete MLTask
244+
* For more info on delete task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#delete-task
245+
* @param taskId id of the task
246+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
247+
* @param listener action listener
248+
*/
249+
void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener);
228250

229251
/**
230252
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
@@ -298,7 +320,18 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
298320
* @param modelId the model id
299321
* @param listener a listener to be notified of the result
300322
*/
301-
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);
323+
default void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
324+
deploy(modelId, null, listener);
325+
}
326+
327+
/**
328+
* Deploy model
329+
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
330+
* @param modelId the model id
331+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
332+
* @param listener a listener to be notified of the result
333+
*/
334+
void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener);
302335

303336
/**
304337
* Undeploy models

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+16-2
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,27 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
223223
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, getMLTaskResponseActionListener(listener));
224224
}
225225

226+
@Override
227+
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
228+
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).tenantId(tenantId).build();
229+
230+
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, getMLTaskResponseActionListener(listener));
231+
}
232+
226233
@Override
227234
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
228235
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();
229236

230237
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
231238
}
232239

240+
@Override
241+
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
242+
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).tenantId(tenantId).build();
243+
244+
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
245+
}
246+
233247
@Override
234248
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
235249
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
@@ -242,8 +256,8 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
242256
}
243257

244258
@Override
245-
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
246-
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
259+
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
260+
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, tenantId, false);
247261
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
248262
}
249263

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+19-4
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ public class MachineLearningClientTest {
109109
@Mock
110110
MLConfigGetResponse configGetResponse;
111111

112-
private String modekId = "test_model_id";
112+
private final String modekId = "test_model_id";
113113
private MLModel mlModel;
114114
private MLTask mlTask;
115115
private MLConfig mlConfig;
116116
private ToolMetadata toolMetadata;
117-
private List<ToolMetadata> toolsList = new ArrayList<>();
117+
private final List<ToolMetadata> toolsList = new ArrayList<>();
118118

119119
@Before
120120
public void setUp() {
@@ -194,11 +194,21 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
194194
listener.onResponse(mlTask);
195195
}
196196

197+
@Override
198+
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
199+
listener.onResponse(mlTask);
200+
}
201+
197202
@Override
198203
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
199204
listener.onResponse(deleteResponse);
200205
}
201206

207+
@Override
208+
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
209+
listener.onResponse(deleteResponse);
210+
}
211+
202212
@Override
203213
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
204214
listener.onResponse(searchResponse);
@@ -214,6 +224,11 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
214224
listener.onResponse(deployModelResponse);
215225
}
216226

227+
@Override
228+
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
229+
listener.onResponse(deployModelResponse);
230+
}
231+
217232
@Override
218233
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
219234
listener.onResponse(undeployModelsResponse);
@@ -487,8 +502,8 @@ public void createConnector() {
487502
@Test
488503
public void executeMetricsCorrelation() {
489504
List<float[]> inputData = new ArrayList<>(
490-
Arrays
491-
.asList(
505+
List
506+
.of(
492507
new float[] {
493508
0.89451003f,
494509
4.2006273f,

common/src/main/java/org/opensearch/ml/common/MLModel.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -744,8 +744,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
744744
}
745745

746746
public static MLModel fromStream(StreamInput in) throws IOException {
747-
MLModel mlModel = new MLModel(in);
748-
return mlModel;
747+
return new MLModel(in);
749748
}
750749

751750
}

common/src/main/java/org/opensearch/ml/common/MLTask.java

+20-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package org.opensearch.ml.common;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
910
import static org.opensearch.ml.common.CommonValue.USER;
11+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
1012

1113
import java.io.IOException;
1214
import java.time.Instant;
@@ -72,6 +74,7 @@ public class MLTask implements ToXContentObject, Writeable {
7274
private boolean async;
7375
@Setter
7476
private Map<String, Object> remoteJob;
77+
private String tenantId;
7578

7679
@Builder(toBuilder = true)
7780
public MLTask(
@@ -89,7 +92,8 @@ public MLTask(
8992
String error,
9093
User user,
9194
boolean async,
92-
Map<String, Object> remoteJob
95+
Map<String, Object> remoteJob,
96+
String tenantId
9397
) {
9498
this.taskId = taskId;
9599
this.modelId = modelId;
@@ -106,6 +110,7 @@ public MLTask(
106110
this.user = user;
107111
this.async = async;
108112
this.remoteJob = remoteJob;
113+
this.tenantId = tenantId;
109114
}
110115

111116
public MLTask(StreamInput input) throws IOException {
@@ -134,9 +139,10 @@ public MLTask(StreamInput input) throws IOException {
134139
this.async = input.readBoolean();
135140
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) {
136141
if (input.readBoolean()) {
137-
this.remoteJob = input.readMap(s -> s.readString(), s -> s.readGenericValue());
142+
this.remoteJob = input.readMap(StreamInput::readString, StreamInput::readGenericValue);
138143
}
139144
}
145+
tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
140146
}
141147

142148
@Override
@@ -173,6 +179,9 @@ public void writeTo(StreamOutput out) throws IOException {
173179
out.writeBoolean(false);
174180
}
175181
}
182+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
183+
out.writeOptionalString(tenantId);
184+
}
176185
}
177186

178187
@Override
@@ -221,12 +230,14 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
221230
if (remoteJob != null) {
222231
builder.field(REMOTE_JOB_FIELD, remoteJob);
223232
}
233+
if (tenantId != null) {
234+
builder.field(TENANT_ID_FIELD, tenantId);
235+
}
224236
return builder.endObject();
225237
}
226238

227239
public static MLTask fromStream(StreamInput in) throws IOException {
228-
MLTask mlTask = new MLTask(in);
229-
return mlTask;
240+
return new MLTask(in);
230241
}
231242

232243
public static MLTask parse(XContentParser parser) throws IOException {
@@ -245,6 +256,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
245256
User user = null;
246257
boolean async = false;
247258
Map<String, Object> remoteJob = null;
259+
String tenantId = null;
248260

249261
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
250262
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -305,6 +317,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
305317
case REMOTE_JOB_FIELD:
306318
remoteJob = parser.map();
307319
break;
320+
case TENANT_ID_FIELD:
321+
tenantId = parser.textOrNull();
322+
break;
308323
default:
309324
parser.skipChildren();
310325
break;
@@ -327,6 +342,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
327342
.user(user)
328343
.async(async)
329344
.remoteJob(remoteJob)
345+
.tenantId(tenantId)
330346
.build();
331347
}
332348
}

common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
package org.opensearch.ml.common.transport.deploy;
77

8+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
9+
810
import java.io.IOException;
911

12+
import org.opensearch.Version;
1013
import org.opensearch.core.common.io.stream.StreamInput;
1114
import org.opensearch.core.common.io.stream.StreamOutput;
1215
import org.opensearch.core.common.io.stream.Writeable;
@@ -18,6 +21,7 @@
1821
@Data
1922
public class MLDeployModelInput implements Writeable {
2023
private String modelId;
24+
private String tenantId;
2125
private String taskId;
2226
private String modelContentHash;
2327
private Integer nodeCount;
@@ -26,13 +30,15 @@ public class MLDeployModelInput implements Writeable {
2630
private MLTask mlTask;
2731

2832
public MLDeployModelInput(StreamInput in) throws IOException {
33+
Version streamInputVersion = in.getVersion();
2934
this.modelId = in.readString();
3035
this.taskId = in.readString();
3136
this.modelContentHash = in.readOptionalString();
3237
this.nodeCount = in.readInt();
3338
this.coordinatingNodeId = in.readString();
3439
this.isDeployToAllNodes = in.readOptionalBoolean();
3540
this.mlTask = new MLTask(in);
41+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
3642
}
3743

3844
@Builder
@@ -43,7 +49,8 @@ public MLDeployModelInput(
4349
Integer nodeCount,
4450
String coordinatingNodeId,
4551
Boolean isDeployToAllNodes,
46-
MLTask mlTask
52+
MLTask mlTask,
53+
String tenantId
4754
) {
4855
this.modelId = modelId;
4956
this.taskId = taskId;
@@ -52,19 +59,24 @@ public MLDeployModelInput(
5259
this.coordinatingNodeId = coordinatingNodeId;
5360
this.isDeployToAllNodes = isDeployToAllNodes;
5461
this.mlTask = mlTask;
62+
this.tenantId = tenantId;
5563
}
5664

5765
public MLDeployModelInput() {}
5866

5967
@Override
6068
public void writeTo(StreamOutput out) throws IOException {
69+
Version streamOutputVersion = out.getVersion();
6170
out.writeString(modelId);
6271
out.writeString(taskId);
6372
out.writeOptionalString(modelContentHash);
6473
out.writeInt(nodeCount);
6574
out.writeString(coordinatingNodeId);
6675
out.writeOptionalBoolean(isDeployToAllNodes);
6776
mlTask.writeTo(out);
77+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
78+
out.writeOptionalString(tenantId);
79+
}
6880
}
6981

7082
}

0 commit comments

Comments
 (0)