Skip to content

Commit f04507a

Browse files
committed
applying multi-tenancy in search [model, model group, agent, connector] (opensearch-project#3433)
* applying multi-tenancy in search Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 4e8afcd commit f04507a

32 files changed

+1187
-172
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package org.opensearch.ml.common.transport.search;
2+
3+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
4+
5+
import java.io.ByteArrayInputStream;
6+
import java.io.ByteArrayOutputStream;
7+
import java.io.IOException;
8+
import java.io.UncheckedIOException;
9+
10+
import org.opensearch.Version;
11+
import org.opensearch.action.ActionRequest;
12+
import org.opensearch.action.search.SearchRequest;
13+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
14+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
15+
import org.opensearch.core.common.io.stream.StreamInput;
16+
import org.opensearch.core.common.io.stream.StreamOutput;
17+
18+
import lombok.Builder;
19+
import lombok.Getter;
20+
21+
/**
22+
* Represents an extended search action request that includes a tenant ID.
23+
* This class allows OpenSearch to include a tenant ID in search requests,
24+
* which is not natively supported in the standard {@link SearchRequest}.
25+
*/
26+
@Getter
27+
public class MLSearchActionRequest extends SearchRequest {
28+
SearchRequest searchRequest;
29+
String tenantId;
30+
31+
/**
32+
* Constructor for building an MLSearchActionRequest.
33+
*
34+
* @param searchRequest The original {@link SearchRequest} to be wrapped.
35+
* @param tenantId The tenant ID associated with the request.
36+
*/
37+
@Builder
38+
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
39+
this.searchRequest = searchRequest;
40+
this.tenantId = tenantId;
41+
}
42+
43+
/**
44+
* Deserializes an {@link MLSearchActionRequest} from a {@link StreamInput}.
45+
*
46+
* @param input The stream input to read from.
47+
* @throws IOException If an I/O error occurs during deserialization.
48+
*/
49+
public MLSearchActionRequest(StreamInput input) throws IOException {
50+
super(input);
51+
Version streamInputVersion = input.getVersion();
52+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
53+
}
54+
55+
/**
56+
* Serializes this {@link MLSearchActionRequest} to a {@link StreamOutput}.
57+
*
58+
* @param output The stream output to write to.
59+
* @throws IOException If an I/O error occurs during serialization.
60+
*/
61+
@Override
62+
public void writeTo(StreamOutput output) throws IOException {
63+
super.writeTo(output);
64+
Version streamOutputVersion = output.getVersion();
65+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
66+
output.writeOptionalString(tenantId);
67+
}
68+
}
69+
70+
/**
71+
* Converts a generic {@link ActionRequest} into an {@link MLSearchActionRequest}.
72+
* This is useful when handling requests that may need to be converted for compatibility.
73+
*
74+
* @param actionRequest The original {@link ActionRequest}.
75+
* @return The converted {@link MLSearchActionRequest}.
76+
* @throws UncheckedIOException If the conversion fails due to an I/O error.
77+
*/
78+
public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) {
79+
if (actionRequest instanceof MLSearchActionRequest) {
80+
return (MLSearchActionRequest) actionRequest;
81+
}
82+
83+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
84+
actionRequest.writeTo(osso);
85+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
86+
return new MLSearchActionRequest(input);
87+
}
88+
} catch (IOException e) {
89+
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e);
90+
}
91+
}
92+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package org.opensearch.ml.common.transport.search;
2+
3+
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertNotSame;
5+
import static org.junit.Assert.assertNull;
6+
import static org.junit.Assert.assertSame;
7+
8+
import java.io.IOException;
9+
import java.io.UncheckedIOException;
10+
11+
import org.junit.Before;
12+
import org.junit.Test;
13+
import org.opensearch.Version;
14+
import org.opensearch.action.ActionRequest;
15+
import org.opensearch.action.ActionRequestValidationException;
16+
import org.opensearch.action.search.SearchRequest;
17+
import org.opensearch.common.io.stream.BytesStreamOutput;
18+
import org.opensearch.core.common.io.stream.StreamInput;
19+
import org.opensearch.core.common.io.stream.StreamOutput;
20+
21+
public class MLSearchActionRequestTest {
22+
23+
private SearchRequest searchRequest;
24+
25+
@Before
26+
public void setUp() {
27+
searchRequest = new SearchRequest("test-index");
28+
}
29+
30+
@Test
31+
public void testConstructorAndGetters() {
32+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
33+
assertEquals("test-index", request.getSearchRequest().indices()[0]);
34+
assertEquals("test-tenant", request.getTenantId());
35+
}
36+
37+
@Test
38+
public void testStreamConstructorAndWriteTo() throws IOException {
39+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
40+
BytesStreamOutput out = new BytesStreamOutput();
41+
request.writeTo(out);
42+
43+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
44+
assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]);
45+
assertEquals("test-tenant", deserializedRequest.getTenantId());
46+
}
47+
48+
@Test
49+
public void testWriteToWithNullSearchRequest() throws IOException {
50+
MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build();
51+
BytesStreamOutput out = new BytesStreamOutput();
52+
request.writeTo(out);
53+
54+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
55+
assertNull(deserializedRequest.getSearchRequest());
56+
assertEquals("test-tenant", deserializedRequest.getTenantId());
57+
}
58+
59+
@Test
60+
public void testFromActionRequestWithMLSearchActionRequest() {
61+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
62+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
63+
assertSame(result, request);
64+
}
65+
66+
@Test
67+
public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException {
68+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
69+
ActionRequest actionRequest = new ActionRequest() {
70+
@Override
71+
public ActionRequestValidationException validate() {
72+
return null;
73+
}
74+
75+
@Override
76+
public void writeTo(StreamOutput out) throws IOException {
77+
request.writeTo(out);
78+
}
79+
};
80+
81+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest);
82+
assertNotSame(result, request);
83+
assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]);
84+
assertEquals(request.getTenantId(), result.getTenantId());
85+
}
86+
87+
@Test(expected = UncheckedIOException.class)
88+
public void testFromActionRequestIOException() {
89+
ActionRequest actionRequest = new ActionRequest() {
90+
@Override
91+
public ActionRequestValidationException validate() {
92+
return null;
93+
}
94+
95+
@Override
96+
public void writeTo(StreamOutput out) throws IOException {
97+
throw new IOException("test");
98+
}
99+
};
100+
MLSearchActionRequest.fromActionRequest(actionRequest);
101+
}
102+
103+
@Test
104+
public void testBackwardCompatibility() throws IOException {
105+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
106+
107+
BytesStreamOutput out = new BytesStreamOutput();
108+
out.setVersion(Version.V_2_18_0); // Older version
109+
request.writeTo(out);
110+
111+
StreamInput in = out.bytes().streamInput();
112+
in.setVersion(Version.V_2_18_0);
113+
114+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
115+
assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored
116+
}
117+
118+
@Test
119+
public void testFromActionRequestWithValidRequest() {
120+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
121+
122+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
123+
assertSame(request, result);
124+
}
125+
126+
@Test
127+
public void testMixedVersionCompatibility() throws IOException {
128+
MLSearchActionRequest originalRequest = MLSearchActionRequest
129+
.builder()
130+
.searchRequest(searchRequest)
131+
.tenantId("test-tenant")
132+
.build();
133+
134+
// Serialize with a newer version
135+
BytesStreamOutput out = new BytesStreamOutput();
136+
out.setVersion(Version.V_2_19_0);
137+
originalRequest.writeTo(out);
138+
139+
// Deserialize with an older version
140+
StreamInput in = out.bytes().streamInput();
141+
in.setVersion(Version.V_2_18_0);
142+
143+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
144+
assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions
145+
}
146+
147+
}

memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.common.util.concurrent.ThreadContext;
3131
import org.opensearch.core.action.ActionListener;
3232
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
33+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
3334
import org.opensearch.ml.memory.ConversationalMemoryHandler;
3435
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
3536
import org.opensearch.tasks.Task;
@@ -38,7 +39,7 @@
3839
import lombok.extern.log4j.Log4j2;
3940

4041
@Log4j2
41-
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
42+
public class SearchConversationsTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
4243

4344
private ConversationalMemoryHandler cmHandler;
4445
private Client client;
@@ -61,7 +62,7 @@ public SearchConversationsTransportAction(
6162
Client client,
6263
ClusterService clusterService
6364
) {
64-
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
65+
super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
6566
this.cmHandler = cmHandler;
6667
this.client = client;
6768
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
@@ -71,13 +72,14 @@ public SearchConversationsTransportAction(
7172
}
7273

7374
@Override
74-
public void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
75+
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
76+
SearchRequest request = mlSearchActionRequest.getSearchRequest();
7577
if (!featureIsEnabled) {
7678
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
7779
return;
7880
} else {
7981
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
80-
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
82+
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
8183
cmHandler.searchConversations(request, internalListener);
8284
} catch (Exception e) {
8385
log.error("Failed to search memories", e);

memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.opensearch.core.action.ActionListener;
4545
import org.opensearch.core.xcontent.NamedXContentRegistry;
4646
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
47+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
4748
import org.opensearch.ml.memory.MemoryTestUtil;
4849
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
4950
import org.opensearch.test.OpenSearchTestCase;
@@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase
7980
@Mock
8081
SearchRequest request;
8182

83+
MLSearchActionRequest mlSearchActionRequest;
84+
8285
SearchConversationsTransportAction action;
8386
ThreadContext threadContext;
8487

8588
@Before
8689
public void setup() throws IOException {
8790
MockitoAnnotations.openMocks(this);
91+
mlSearchActionRequest = new MLSearchActionRequest(request, null);
8892

8993
Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
9094
this.threadContext = new ThreadContext(settings);
@@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() {
104108
listener.onResponse(response);
105109
return null;
106110
}).when(cmHandler).searchConversations(any(), any());
107-
action.doExecute(null, request, actionListener);
111+
action.doExecute(null, mlSearchActionRequest, actionListener);
108112
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);
109113
verify(actionListener, times(1)).onResponse(argCaptor.capture());
110114
assert (argCaptor.getValue().equals(response));
@@ -114,7 +118,7 @@ public void testDisabled_ThenFail() {
114118
clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled();
115119
this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService));
116120

117-
action.doExecute(null, request, actionListener);
121+
action.doExecute(null, mlSearchActionRequest, actionListener);
118122
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
119123
verify(actionListener).onFailure(argCaptor.capture());
120124
assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE);

plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java

+31-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.action.agents;
77

88
import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
910

1011
import org.opensearch.action.search.SearchRequest;
1112
import org.opensearch.action.search.SearchResponse;
@@ -20,28 +21,46 @@
2021
import org.opensearch.ml.common.CommonValue;
2122
import org.opensearch.ml.common.agent.MLAgent;
2223
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
24+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
25+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
26+
import org.opensearch.ml.utils.TenantAwareHelper;
27+
import org.opensearch.remote.metadata.client.SdkClient;
2328
import org.opensearch.tasks.Task;
2429
import org.opensearch.transport.TransportService;
2530

2631
import lombok.extern.log4j.Log4j2;
2732

2833
@Log4j2
29-
public class TransportSearchAgentAction extends HandledTransportAction<SearchRequest, SearchResponse> {
34+
public class TransportSearchAgentAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
3035
private final Client client;
36+
private final SdkClient sdkClient;
37+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
3138

3239
@Inject
33-
public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) {
34-
super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new);
40+
public TransportSearchAgentAction(
41+
TransportService transportService,
42+
ActionFilters actionFilters,
43+
Client client,
44+
SdkClient sdkClient,
45+
MLFeatureEnabledSetting mlFeatureEnabledSetting
46+
) {
47+
super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
3548
this.client = client;
49+
this.sdkClient = sdkClient;
50+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
3651
}
3752

3853
@Override
39-
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
40-
request.indices(CommonValue.ML_AGENT_INDEX);
41-
search(request, actionListener);
54+
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
55+
request.getSearchRequest().indices(CommonValue.ML_AGENT_INDEX);
56+
String tenantId = request.getTenantId();
57+
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
58+
return;
59+
}
60+
search(request.getSearchRequest(), tenantId, actionListener);
4261
}
4362

44-
private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
63+
private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
4564
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search agent");
4665
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
4766
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
@@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
5776
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
5877
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));
5978

79+
// For multi-tenancy
80+
if (tenantId != null) {
81+
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
82+
}
83+
6084
// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
6185
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
6286

0 commit comments

Comments
 (0)