Skip to content

Commit 9d4670d

Browse files
committed
applying multi-tenancy in search [model, model group, agent, connector] (opensearch-project#3433)
* applying multi-tenancy in search Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent a852152 commit 9d4670d

32 files changed

+1196
-172
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package org.opensearch.ml.common.transport.search;
2+
3+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
4+
5+
import java.io.ByteArrayInputStream;
6+
import java.io.ByteArrayOutputStream;
7+
import java.io.IOException;
8+
import java.io.UncheckedIOException;
9+
10+
import org.opensearch.Version;
11+
import org.opensearch.action.ActionRequest;
12+
import org.opensearch.action.search.SearchRequest;
13+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
14+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
15+
import org.opensearch.core.common.io.stream.StreamInput;
16+
import org.opensearch.core.common.io.stream.StreamOutput;
17+
18+
import lombok.Builder;
19+
import lombok.Getter;
20+
21+
/**
22+
* Represents an extended search action request that includes a tenant ID.
23+
* This class allows OpenSearch to include a tenant ID in search requests,
24+
* which is not natively supported in the standard {@link SearchRequest}.
25+
*/
26+
@Getter
27+
public class MLSearchActionRequest extends SearchRequest {
28+
SearchRequest searchRequest;
29+
String tenantId;
30+
31+
/**
32+
* Constructor for building an MLSearchActionRequest.
33+
*
34+
* @param searchRequest The original {@link SearchRequest} to be wrapped.
35+
* @param tenantId The tenant ID associated with the request.
36+
*/
37+
@Builder
38+
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
39+
this.searchRequest = searchRequest;
40+
this.tenantId = tenantId;
41+
}
42+
43+
/**
44+
* Deserializes an {@link MLSearchActionRequest} from a {@link StreamInput}.
45+
*
46+
* @param input The stream input to read from.
47+
* @throws IOException If an I/O error occurs during deserialization.
48+
*/
49+
public MLSearchActionRequest(StreamInput input) throws IOException {
50+
super(input);
51+
Version streamInputVersion = input.getVersion();
52+
if (input.readBoolean()) {
53+
searchRequest = new SearchRequest(input);
54+
}
55+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
56+
}
57+
58+
/**
59+
* Serializes this {@link MLSearchActionRequest} to a {@link StreamOutput}.
60+
*
61+
* @param output The stream output to write to.
62+
* @throws IOException If an I/O error occurs during serialization.
63+
*/
64+
@Override
65+
public void writeTo(StreamOutput output) throws IOException {
66+
super.writeTo(output);
67+
Version streamOutputVersion = output.getVersion();
68+
if (searchRequest != null) {
69+
output.writeBoolean(true); // user exists
70+
searchRequest.writeTo(output);
71+
} else {
72+
output.writeBoolean(false); // user does not exist
73+
}
74+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
75+
output.writeOptionalString(tenantId);
76+
}
77+
}
78+
79+
/**
80+
* Converts a generic {@link ActionRequest} into an {@link MLSearchActionRequest}.
81+
* This is useful when handling requests that may need to be converted for compatibility.
82+
*
83+
* @param actionRequest The original {@link ActionRequest}.
84+
* @return The converted {@link MLSearchActionRequest}.
85+
* @throws UncheckedIOException If the conversion fails due to an I/O error.
86+
*/
87+
public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) {
88+
if (actionRequest instanceof MLSearchActionRequest) {
89+
return (MLSearchActionRequest) actionRequest;
90+
}
91+
92+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
93+
actionRequest.writeTo(osso);
94+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
95+
return new MLSearchActionRequest(input);
96+
}
97+
} catch (IOException e) {
98+
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e);
99+
}
100+
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package org.opensearch.ml.common.transport.search;
2+
3+
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertNotSame;
5+
import static org.junit.Assert.assertNull;
6+
import static org.junit.Assert.assertSame;
7+
8+
import java.io.IOException;
9+
import java.io.UncheckedIOException;
10+
11+
import org.junit.Before;
12+
import org.junit.Test;
13+
import org.opensearch.Version;
14+
import org.opensearch.action.ActionRequest;
15+
import org.opensearch.action.ActionRequestValidationException;
16+
import org.opensearch.action.search.SearchRequest;
17+
import org.opensearch.common.io.stream.BytesStreamOutput;
18+
import org.opensearch.core.common.io.stream.StreamInput;
19+
import org.opensearch.core.common.io.stream.StreamOutput;
20+
21+
public class MLSearchActionRequestTest {
22+
23+
private SearchRequest searchRequest;
24+
25+
@Before
26+
public void setUp() {
27+
searchRequest = new SearchRequest("test-index");
28+
}
29+
30+
@Test
31+
public void testConstructorAndGetters() {
32+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
33+
assertEquals("test-index", request.getSearchRequest().indices()[0]);
34+
assertEquals("test-tenant", request.getTenantId());
35+
}
36+
37+
@Test
38+
public void testStreamConstructorAndWriteTo() throws IOException {
39+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
40+
BytesStreamOutput out = new BytesStreamOutput();
41+
request.writeTo(out);
42+
43+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
44+
assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]);
45+
assertEquals("test-tenant", deserializedRequest.getTenantId());
46+
}
47+
48+
@Test
49+
public void testWriteToWithNullSearchRequest() throws IOException {
50+
MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build();
51+
BytesStreamOutput out = new BytesStreamOutput();
52+
request.writeTo(out);
53+
54+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
55+
assertNull(deserializedRequest.getSearchRequest());
56+
assertEquals("test-tenant", deserializedRequest.getTenantId());
57+
}
58+
59+
@Test
60+
public void testFromActionRequestWithMLSearchActionRequest() {
61+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
62+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
63+
assertSame(result, request);
64+
}
65+
66+
@Test
67+
public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException {
68+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
69+
ActionRequest actionRequest = new ActionRequest() {
70+
@Override
71+
public ActionRequestValidationException validate() {
72+
return null;
73+
}
74+
75+
@Override
76+
public void writeTo(StreamOutput out) throws IOException {
77+
request.writeTo(out);
78+
}
79+
};
80+
81+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest);
82+
assertNotSame(result, request);
83+
assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]);
84+
assertEquals(request.getTenantId(), result.getTenantId());
85+
}
86+
87+
@Test(expected = UncheckedIOException.class)
88+
public void testFromActionRequestIOException() {
89+
ActionRequest actionRequest = new ActionRequest() {
90+
@Override
91+
public ActionRequestValidationException validate() {
92+
return null;
93+
}
94+
95+
@Override
96+
public void writeTo(StreamOutput out) throws IOException {
97+
throw new IOException("test");
98+
}
99+
};
100+
MLSearchActionRequest.fromActionRequest(actionRequest);
101+
}
102+
103+
@Test
104+
public void testBackwardCompatibility() throws IOException {
105+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
106+
107+
BytesStreamOutput out = new BytesStreamOutput();
108+
out.setVersion(Version.V_2_18_0); // Older version
109+
request.writeTo(out);
110+
111+
StreamInput in = out.bytes().streamInput();
112+
in.setVersion(Version.V_2_18_0);
113+
114+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
115+
assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored
116+
}
117+
118+
@Test
119+
public void testFromActionRequestWithValidRequest() {
120+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
121+
122+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
123+
assertSame(request, result);
124+
}
125+
126+
@Test
127+
public void testMixedVersionCompatibility() throws IOException {
128+
MLSearchActionRequest originalRequest = MLSearchActionRequest
129+
.builder()
130+
.searchRequest(searchRequest)
131+
.tenantId("test-tenant")
132+
.build();
133+
134+
// Serialize with a newer version
135+
BytesStreamOutput out = new BytesStreamOutput();
136+
out.setVersion(Version.V_2_19_0);
137+
originalRequest.writeTo(out);
138+
139+
// Deserialize with an older version
140+
StreamInput in = out.bytes().streamInput();
141+
in.setVersion(Version.V_2_18_0);
142+
143+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
144+
assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions
145+
}
146+
147+
}

memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.common.util.concurrent.ThreadContext;
3131
import org.opensearch.core.action.ActionListener;
3232
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
33+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
3334
import org.opensearch.ml.memory.ConversationalMemoryHandler;
3435
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
3536
import org.opensearch.tasks.Task;
@@ -38,7 +39,7 @@
3839
import lombok.extern.log4j.Log4j2;
3940

4041
@Log4j2
41-
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
42+
public class SearchConversationsTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
4243

4344
private ConversationalMemoryHandler cmHandler;
4445
private Client client;
@@ -61,7 +62,7 @@ public SearchConversationsTransportAction(
6162
Client client,
6263
ClusterService clusterService
6364
) {
64-
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
65+
super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
6566
this.cmHandler = cmHandler;
6667
this.client = client;
6768
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
@@ -71,13 +72,14 @@ public SearchConversationsTransportAction(
7172
}
7273

7374
@Override
74-
public void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
75+
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
76+
SearchRequest request = mlSearchActionRequest.getSearchRequest();
7577
if (!featureIsEnabled) {
7678
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
7779
return;
7880
} else {
7981
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
80-
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
82+
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
8183
cmHandler.searchConversations(request, internalListener);
8284
} catch (Exception e) {
8385
log.error("Failed to search memories", e);

memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.opensearch.core.action.ActionListener;
4545
import org.opensearch.core.xcontent.NamedXContentRegistry;
4646
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
47+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
4748
import org.opensearch.ml.memory.MemoryTestUtil;
4849
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
4950
import org.opensearch.test.OpenSearchTestCase;
@@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase
7980
@Mock
8081
SearchRequest request;
8182

83+
MLSearchActionRequest mlSearchActionRequest;
84+
8285
SearchConversationsTransportAction action;
8386
ThreadContext threadContext;
8487

8588
@Before
8689
public void setup() throws IOException {
8790
MockitoAnnotations.openMocks(this);
91+
mlSearchActionRequest = new MLSearchActionRequest(request, null);
8892

8993
Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
9094
this.threadContext = new ThreadContext(settings);
@@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() {
104108
listener.onResponse(response);
105109
return null;
106110
}).when(cmHandler).searchConversations(any(), any());
107-
action.doExecute(null, request, actionListener);
111+
action.doExecute(null, mlSearchActionRequest, actionListener);
108112
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);
109113
verify(actionListener, times(1)).onResponse(argCaptor.capture());
110114
assert (argCaptor.getValue().equals(response));
@@ -114,7 +118,7 @@ public void testDisabled_ThenFail() {
114118
clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled();
115119
this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService));
116120

117-
action.doExecute(null, request, actionListener);
121+
action.doExecute(null, mlSearchActionRequest, actionListener);
118122
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
119123
verify(actionListener).onFailure(argCaptor.capture());
120124
assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE);

0 commit comments

Comments
 (0)