Skip to content

Commit 68ecd9a

Browse files
committed
multi-tenancy + sdk client related changes in agents
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent af96fe0 commit 68ecd9a

File tree

50 files changed

+1318
-565
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1318
-565
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,23 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
474474
return actionFuture;
475475
}
476476

477-
void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
477+
/**
478+
* Delete agent
479+
* @param agentId The id of the agent to delete
480+
* @param listener a listener to be notified of the result
481+
*/
482+
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
483+
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
484+
deleteAgent(agentId, null, actionFuture);
485+
}
486+
487+
/**
488+
* Delete agent
489+
* @param agentId The id of the agent to delete
490+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
491+
* @param listener a listener to be notified of the result
492+
*/
493+
void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener);
478494

479495
/**
480496
* Get a list of ToolMetadata and return ActionFuture.

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
292292
}
293293

294294
@Override
295-
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
296-
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
295+
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
296+
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId, tenantId);
297297
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
298298
}
299299

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+5
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,11 @@ public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener)
291291
listener.onResponse(deleteResponse);
292292
}
293293

294+
@Override
295+
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
296+
listener.onResponse(deleteResponse);
297+
}
298+
294299
@Override
295300
public void getConfig(String configId, ActionListener<MLConfig> listener) {
296301
listener.onResponse(mlConfig);

common/src/main/java/org/opensearch/ml/common/MLModel.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,7 @@ public MLModel(StreamInput input) throws IOException {
315315
if (input.readBoolean()) {
316316
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
317317
}
318-
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
319-
tenantId = input.readOptionalString();
320-
}
318+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
321319
}
322320
}
323321

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

+21-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package org.opensearch.ml.common.agent;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
911
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1012

1113
import java.io.IOException;
@@ -63,6 +65,7 @@ public class MLAgent implements ToXContentObject, Writeable {
6365
private Instant lastUpdateTime;
6466
private String appType;
6567
private Boolean isHidden;
68+
private final String tenantId;
6669

6770
@Builder(toBuilder = true)
6871
public MLAgent(
@@ -76,7 +79,8 @@ public MLAgent(
7679
Instant createdTime,
7780
Instant lastUpdateTime,
7881
String appType,
79-
Boolean isHidden
82+
Boolean isHidden,
83+
String tenantId
8084
) {
8185
this.name = name;
8286
this.type = type;
@@ -90,6 +94,7 @@ public MLAgent(
9094
this.appType = appType;
9195
// is_hidden field isn't going to be set by user. It will be set by the code.
9296
this.isHidden = isHidden;
97+
this.tenantId = tenantId;
9398
validate();
9499
}
95100

@@ -155,6 +160,7 @@ public MLAgent(StreamInput input) throws IOException {
155160
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
156161
isHidden = input.readOptionalBoolean();
157162
}
163+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
158164
validate();
159165
}
160166

@@ -169,7 +175,7 @@ public void writeTo(StreamOutput out) throws IOException {
169175
} else {
170176
out.writeBoolean(false);
171177
}
172-
if (tools != null && tools.size() > 0) {
178+
if (tools != null && !tools.isEmpty()) {
173179
out.writeBoolean(true);
174180
out.writeInt(tools.size());
175181
for (MLToolSpec tool : tools) {
@@ -178,7 +184,7 @@ public void writeTo(StreamOutput out) throws IOException {
178184
} else {
179185
out.writeBoolean(false);
180186
}
181-
if (parameters != null && parameters.size() > 0) {
187+
if (parameters != null && !parameters.isEmpty()) {
182188
out.writeBoolean(true);
183189
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
184190
} else {
@@ -197,6 +203,9 @@ public void writeTo(StreamOutput out) throws IOException {
197203
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
198204
out.writeOptionalBoolean(isHidden);
199205
}
206+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
207+
out.writeOptionalString(tenantId);
208+
}
200209
}
201210

202211
@Override
@@ -236,6 +245,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
236245
if (isHidden != null) {
237246
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
238247
}
248+
if (tenantId != null) {
249+
builder.field(TENANT_ID_FIELD, tenantId);
250+
}
239251
builder.endObject();
240252
return builder;
241253
}
@@ -260,6 +272,7 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
260272
Instant lastUpdateTime = null;
261273
String appType = null;
262274
boolean isHidden = false;
275+
String tenantId = null;
263276

264277
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
265278
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -305,6 +318,9 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
305318
if (parseHidden)
306319
isHidden = parser.booleanValue();
307320
break;
321+
case TENANT_ID_FIELD:
322+
tenantId = parser.textOrNull();
323+
break;
308324
default:
309325
parser.skipChildren();
310326
break;
@@ -324,11 +340,11 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
324340
.lastUpdateTime(lastUpdateTime)
325341
.appType(appType)
326342
.isHidden(isHidden)
343+
.tenantId(tenantId)
327344
.build();
328345
}
329346

330347
public static MLAgent fromStream(StreamInput in) throws IOException {
331-
MLAgent agent = new MLAgent(in);
332-
return agent;
348+
return new MLAgent(in);
333349
}
334350
}

common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java

+25-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package org.opensearch.ml.common.agent;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
911
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1012

1113
import java.io.IOException;
@@ -22,6 +24,7 @@
2224
import lombok.Builder;
2325
import lombok.EqualsAndHashCode;
2426
import lombok.Getter;
27+
import lombok.Setter;
2528

2629
@EqualsAndHashCode
2730
@Getter
@@ -41,6 +44,8 @@ public class MLToolSpec implements ToXContentObject {
4144
private Map<String, String> parameters;
4245
private boolean includeOutputInAgentResponse;
4346
private Map<String, String> configMap;
47+
@Setter
48+
private String tenantId;
4449

4550
@Builder(toBuilder = true)
4651
public MLToolSpec(
@@ -49,7 +54,8 @@ public MLToolSpec(
4954
String description,
5055
Map<String, String> parameters,
5156
boolean includeOutputInAgentResponse,
52-
Map<String, String> configMap
57+
Map<String, String> configMap,
58+
String tenantId
5359
) {
5460
if (type == null) {
5561
throw new IllegalArgumentException("tool type is null");
@@ -60,9 +66,11 @@ public MLToolSpec(
6066
this.parameters = parameters;
6167
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
6268
this.configMap = configMap;
69+
this.tenantId = tenantId;
6370
}
6471

6572
public MLToolSpec(StreamInput input) throws IOException {
73+
Version streamInputVersion = input.getVersion();
6674
type = input.readString();
6775
name = input.readOptionalString();
6876
description = input.readOptionalString();
@@ -73,13 +81,15 @@ public MLToolSpec(StreamInput input) throws IOException {
7381
if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) {
7482
configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
7583
}
84+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
7685
}
7786

7887
public void writeTo(StreamOutput out) throws IOException {
88+
Version streamOutputVersion = out.getVersion();
7989
out.writeString(type);
8090
out.writeOptionalString(name);
8191
out.writeOptionalString(description);
82-
if (parameters != null && parameters.size() > 0) {
92+
if (parameters != null && !parameters.isEmpty()) {
8393
out.writeBoolean(true);
8494
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
8595
} else {
@@ -94,6 +104,9 @@ public void writeTo(StreamOutput out) throws IOException {
94104
out.writeBoolean(false);
95105
}
96106
}
107+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
108+
out.writeOptionalString(tenantId);
109+
}
97110
}
98111

99112
@Override
@@ -108,13 +121,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
108121
if (description != null) {
109122
builder.field(DESCRIPTION_FIELD, description);
110123
}
111-
if (parameters != null && parameters.size() > 0) {
124+
if (parameters != null && !parameters.isEmpty()) {
112125
builder.field(PARAMETERS_FIELD, parameters);
113126
}
114127
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
115128
if (configMap != null && !configMap.isEmpty()) {
116129
builder.field(CONFIG_FIELD, configMap);
117130
}
131+
if (tenantId != null) {
132+
builder.field(TENANT_ID_FIELD, tenantId);
133+
}
118134
builder.endObject();
119135
return builder;
120136
}
@@ -126,6 +142,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
126142
Map<String, String> parameters = null;
127143
boolean includeOutputInAgentResponse = false;
128144
Map<String, String> configMap = null;
145+
String tenantId = null;
129146

130147
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
131148
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -151,6 +168,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
151168
case CONFIG_FIELD:
152169
configMap = getParameterMap(parser.map());
153170
break;
171+
case TENANT_ID_FIELD:
172+
tenantId = parser.textOrNull();
173+
break;
154174
default:
155175
parser.skipChildren();
156176
break;
@@ -164,11 +184,11 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
164184
.parameters(parameters)
165185
.includeOutputInAgentResponse(includeOutputInAgentResponse)
166186
.configMap(configMap)
187+
.tenantId(tenantId)
167188
.build();
168189
}
169190

170191
public static MLToolSpec fromStream(StreamInput in) throws IOException {
171-
MLToolSpec toolSpec = new MLToolSpec(in);
172-
return toolSpec;
192+
return new MLToolSpec(in);
173193
}
174194
}

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ protected Map<String, String> createDecryptedHeaders(Map<String, String> headers
8383
for (String key : headers.keySet()) {
8484
decryptedHeaders.put(key, substitutor.replace(headers.get(key)));
8585
}
86-
if (parameters != null && parameters.size() > 0) {
86+
if (parameters != null && !parameters.isEmpty()) {
8787
substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
8888
for (String key : decryptedHeaders.keySet()) {
8989
decryptedHeaders.put(key, substitutor.replace(decryptedHeaders.get(key)));
@@ -142,11 +142,11 @@ public void removeCredential() {
142142
@Override
143143
public String getActionEndpoint(String action, Map<String, String> parameters) {
144144
Optional<ConnectorAction> actionEndpoint = findAction(action);
145-
if (!actionEndpoint.isPresent()) {
145+
if (actionEndpoint.isEmpty()) {
146146
return null;
147147
}
148148
String predictEndpoint = actionEndpoint.get().getUrl();
149-
if (parameters != null && parameters.size() > 0) {
149+
if (parameters != null && !parameters.isEmpty()) {
150150
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
151151
predictEndpoint = substitutor.replace(predictEndpoint);
152152
}

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,7 @@ private void parseFromStream(StreamInput input) throws IOException {
242242
if (input.readBoolean()) {
243243
this.connectorClientConfig = new ConnectorClientConfig(input);
244244
}
245-
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
246-
this.tenantId = input.readOptionalString();
247-
}
245+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
248246
}
249247

250248
@Override

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import lombok.Getter;
2222
import lombok.Setter;
2323

24+
@Setter
2425
@Getter
2526
@InputDataSet(MLInputDataType.REMOTE)
2627
public class RemoteInferenceInputDataSet extends MLInputDataset {
@@ -45,7 +46,7 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
4546
super(MLInputDataType.REMOTE);
4647
Version streamInputVersion = streamInput.getVersion();
4748
if (streamInput.readBoolean()) {
48-
parameters = streamInput.readMap(s -> s.readString(), s -> s.readString());
49+
parameters = streamInput.readMap(StreamInput::readString, StreamInput::readString);
4950
}
5051
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
5152
if (streamInput.readBoolean()) {

common/src/main/java/org/opensearch/ml/common/input/MLInput.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -164,18 +164,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
164164
TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset;
165165
List<String> docs = textInputDataSet.getDocs();
166166
ModelResultFilter resultFilter = textInputDataSet.getResultFilter();
167-
if (docs != null && docs.size() > 0) {
167+
if (docs != null && !docs.isEmpty()) {
168168
builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0]));
169169
}
170170
if (resultFilter != null) {
171171
builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
172172
builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
173173
List<String> targetResponse = resultFilter.getTargetResponse();
174-
if (targetResponse != null && targetResponse.size() > 0) {
174+
if (targetResponse != null && !targetResponse.isEmpty()) {
175175
builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0]));
176176
}
177177
List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
178-
if (targetPositions != null && targetPositions.size() > 0) {
178+
if (targetPositions != null && !targetPositions.isEmpty()) {
179179
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
180180
}
181181
}

0 commit comments

Comments
 (0)