Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.19] [BACKPORT 2.x] applying multi-tenancy in search [model, model group, agent, connector] (#3433) #3469

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package org.opensearch.ml.common.transport.search;

import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.Builder;
import lombok.Getter;

/**
* Represents an extended search action request that includes a tenant ID.
* This class allows OpenSearch to include a tenant ID in search requests,
* which is not natively supported in the standard {@link SearchRequest}.
*/
@Getter
public class MLSearchActionRequest extends SearchRequest {
String tenantId;

/**
* Constructor for building an MLSearchActionRequest.
*
* @param searchRequest The original {@link SearchRequest} to be wrapped.
* @param tenantId The tenant ID associated with the request.
*/
@Builder
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
super(searchRequest);
this.tenantId = tenantId;
}

/**
* Deserializes an {@link MLSearchActionRequest} from a {@link StreamInput}.
*
* @param input The stream input to read from.
* @throws IOException If an I/O error occurs during deserialization.
*/
public MLSearchActionRequest(StreamInput input) throws IOException {
super(input);
Version streamInputVersion = input.getVersion();
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;

}

/**
* Serializes this {@link MLSearchActionRequest} to a {@link StreamOutput}.
*
* @param output The stream output to write to.
* @throws IOException If an I/O error occurs during serialization.
*/
@Override
public void writeTo(StreamOutput output) throws IOException {
super.writeTo(output);
Version streamOutputVersion = output.getVersion();
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
output.writeOptionalString(tenantId);
}
}

/**
* Converts a generic {@link ActionRequest} into an {@link MLSearchActionRequest}.
* This is useful when handling requests that may need to be converted for compatibility.
*
* @param actionRequest The original {@link ActionRequest}.
* @return The converted {@link MLSearchActionRequest}.
* @throws UncheckedIOException If the conversion fails due to an I/O error.
*/
public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLSearchActionRequest) {
return (MLSearchActionRequest) actionRequest;
}

if (actionRequest instanceof SearchRequest) {
return MLSearchActionRequest
.builder()
.searchRequest((SearchRequest) actionRequest)
.tenantId(null) // No tenant ID in the original request
.build();
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLSearchActionRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package org.opensearch.ml.common.transport.search;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

import java.io.IOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;

public class MLSearchActionRequestTest {

private SearchRequest searchRequest;

@Before
public void setUp() {
searchRequest = new SearchRequest("test-index");
}

@Test
public void testSerializationDeserialization_Version_2_19_0() throws IOException {
// Set up a valid SearchRequest
SearchRequest searchRequest = new SearchRequest("test-index");

// Create the MLSearchActionRequest
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_19_0);
originalRequest.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_19_0);
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);

assertEquals("test-tenant", deserializedRequest.getTenantId());
}

@Test
public void testSerializationDeserialization_Version_2_18_0() throws IOException {

// Create the MLSearchActionRequest
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_18_0);
originalRequest.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);

assertNull(deserializedRequest.getTenantId());
}

@Test
public void testFromActionRequest_WithMLSearchActionRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);

assertSame(request, result);
}

@Test
public void testFromActionRequest_WithSearchRequest() throws IOException {
SearchRequest simpleRequest = new SearchRequest("test-index");

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(simpleRequest);

assertNotNull(result);
assertNull(result.getTenantId()); // Since tenantId wasn't in original request
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE;

import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand All @@ -30,6 +29,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -38,7 +38,7 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
public class SearchConversationsTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {

private ConversationalMemoryHandler cmHandler;
private Client client;
Expand All @@ -61,7 +61,7 @@ public SearchConversationsTransportAction(
Client client,
ClusterService clusterService
) {
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
Expand All @@ -71,14 +71,14 @@ public SearchConversationsTransportAction(
}

@Override
public void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
if (!featureIsEnabled) {
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
return;
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
cmHandler.searchConversations(request, internalListener);
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
cmHandler.searchConversations(mlSearchActionRequest, internalListener);
} catch (Exception e) {
log.error("Failed to search memories", e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.memory.MemoryTestUtil;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase
@Mock
SearchRequest request;

MLSearchActionRequest mlSearchActionRequest;

SearchConversationsTransportAction action;
ThreadContext threadContext;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
mlSearchActionRequest = new MLSearchActionRequest(request, null);

Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
Expand All @@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() {
listener.onResponse(response);
return null;
}).when(cmHandler).searchConversations(any(), any());
action.doExecute(null, request, actionListener);
action.doExecute(null, mlSearchActionRequest, actionListener);
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);
verify(actionListener, times(1)).onResponse(argCaptor.capture());
assert (argCaptor.getValue().equals(response));
Expand All @@ -114,7 +118,7 @@ public void testDisabled_ThenFail() {
clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled();
this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

action.doExecute(null, request, actionListener);
action.doExecute(null, mlSearchActionRequest, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.agents;

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -20,28 +21,46 @@
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class TransportSearchAgentAction extends HandledTransportAction<SearchRequest, SearchResponse> {
public class TransportSearchAgentAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
private final Client client;
private final SdkClient sdkClient;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new);
public TransportSearchAgentAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
SdkClient sdkClient,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.client = client;
this.sdkClient = sdkClient;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(CommonValue.ML_AGENT_INDEX);
search(request, actionListener);
String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
search(request, tenantId, actionListener);
}

private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search agent");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
Expand All @@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));

// For multi-tenancy
if (tenantId != null) {
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
}

// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));

Expand Down
Loading
Loading