Skip to content

Commit 4123cf3

Browse files
remainig sdk client changes for search (opensearch-project#3522) (opensearch-project#3525)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit 5432f25) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 8a1f64f commit 4123cf3

File tree

5 files changed

+170
-47
lines changed

5 files changed

+170
-47
lines changed

plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java

+30-7
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
package org.opensearch.ml.action.agents;
77

88
import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
9-
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
109

10+
import org.opensearch.OpenSearchStatusException;
1111
import org.opensearch.action.search.SearchRequest;
1212
import org.opensearch.action.search.SearchResponse;
1313
import org.opensearch.action.support.ActionFilters;
@@ -16,6 +16,7 @@
1616
import org.opensearch.common.inject.Inject;
1717
import org.opensearch.common.util.concurrent.ThreadContext;
1818
import org.opensearch.core.action.ActionListener;
19+
import org.opensearch.core.rest.RestStatus;
1920
import org.opensearch.index.query.BoolQueryBuilder;
2021
import org.opensearch.index.query.QueryBuilders;
2122
import org.opensearch.ml.common.CommonValue;
@@ -25,6 +26,8 @@
2526
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
2627
import org.opensearch.ml.utils.TenantAwareHelper;
2728
import org.opensearch.remote.metadata.client.SdkClient;
29+
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
30+
import org.opensearch.remote.metadata.common.SdkClientUtils;
2831
import org.opensearch.tasks.Task;
2932
import org.opensearch.transport.TransportService;
3033

@@ -76,11 +79,6 @@ private void search(SearchRequest request, String tenantId, ActionListener<Searc
7679
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
7780
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));
7881

79-
// For multi-tenancy
80-
if (tenantId != null) {
81-
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
82-
}
83-
8482
// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
8583
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
8684

@@ -91,7 +89,32 @@ private void search(SearchRequest request, String tenantId, ActionListener<Searc
9189
queryBuilder.filter(shouldQuery);
9290

9391
request.source().query(queryBuilder);
94-
client.search(request, wrappedListener);
92+
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
93+
.builder()
94+
.indices(request.indices())
95+
.searchSourceBuilder(request.source())
96+
.tenantId(tenantId)
97+
.build();
98+
99+
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
100+
if (throwable != null) {
101+
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
102+
log.error("Failed to search agent", cause);
103+
wrappedListener.onFailure(cause);
104+
} else {
105+
try {
106+
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
107+
log.info("Agent search complete: {}", searchResponse.getHits().getTotalHits());
108+
wrappedListener.onResponse(searchResponse);
109+
} catch (Exception e) {
110+
log.error("Failed to parse model search response", e);
111+
wrappedListener
112+
.onFailure(
113+
new OpenSearchStatusException("Failed to parse model search response", RestStatus.INTERNAL_SERVER_ERROR)
114+
);
115+
}
116+
}
117+
});
95118
} catch (Exception e) {
96119
log.error("failed to search the agent index", e);
97120
actionListener.onFailure(e);

plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java

+68-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
99
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
10-
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1110
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;
1211

1312
import java.util.ArrayList;
@@ -43,6 +42,9 @@
4342
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
4443
import org.opensearch.ml.helper.ModelAccessControlHelper;
4544
import org.opensearch.ml.utils.RestActionUtils;
45+
import org.opensearch.remote.metadata.client.SdkClient;
46+
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
47+
import org.opensearch.remote.metadata.common.SdkClientUtils;
4648
import org.opensearch.search.SearchHits;
4749
import org.opensearch.search.builder.SearchSourceBuilder;
4850
import org.opensearch.search.fetch.subphase.FetchSourceContext;
@@ -77,10 +79,11 @@ public MLSearchHandler(
7779

7880
/**
7981
* Fetch all the models from the model group index, and then create a combined query to model version index.
82+
* @param sdkClient sdkclient a wrapper of the client
8083
* @param request
8184
* @param actionListener
8285
*/
83-
public void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
86+
public void search(SdkClient sdkClient, SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
8487
User user = RestActionUtils.getUserContext(client);
8588
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search model version");
8689
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
@@ -114,11 +117,6 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
114117
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
115118
shouldQuery.should(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, false));
116119

117-
// For multi-tenancy
118-
if (tenantId != null) {
119-
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
120-
}
121-
122120
// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
123121
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLModel.IS_HIDDEN_FIELD)));
124122

@@ -132,10 +130,29 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
132130
request.source().fetchSource(rebuiltFetchSourceContext);
133131
final ActionListener<SearchResponse> doubleWrapperListener = ActionListener
134132
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
135-
if (modelAccessControlHelper.skipModelAccessControl(user)) {
136-
client.search(request, doubleWrapperListener);
137-
} else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
138-
client.search(request, doubleWrapperListener);
133+
if (modelAccessControlHelper.skipModelAccessControl(user)
134+
|| !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
135+
136+
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
137+
.builder()
138+
.indices(request.indices())
139+
.searchSourceBuilder(request.source())
140+
.tenantId(tenantId)
141+
.build();
142+
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
143+
if (throwable == null) {
144+
try {
145+
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
146+
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
147+
doubleWrapperListener.onResponse(searchResponse);
148+
} catch (Exception e) {
149+
doubleWrapperListener.onFailure(e);
150+
}
151+
} else {
152+
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
153+
doubleWrapperListener.onFailure(e);
154+
}
155+
});
139156
} else {
140157
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
141158
SearchRequest modelGroupSearchRequest = new SearchRequest();
@@ -154,17 +171,54 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
154171
Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); });
155172

156173
request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds));
157-
client.search(request, doubleWrapperListener);
158174
} else {
159175
log.debug("No model group found");
160176
request.source().query(rewriteQueryBuilder(request.source().query(), null));
161-
client.search(request, doubleWrapperListener);
162177
}
178+
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
179+
.builder()
180+
.indices(request.indices())
181+
.searchSourceBuilder(request.source())
182+
.tenantId(tenantId)
183+
.build();
184+
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((sr, throwable) -> {
185+
if (throwable == null) {
186+
try {
187+
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
188+
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
189+
doubleWrapperListener.onResponse(searchResponse);
190+
} catch (Exception e) {
191+
doubleWrapperListener.onFailure(e);
192+
}
193+
} else {
194+
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
195+
doubleWrapperListener.onFailure(e);
196+
}
197+
});
163198
}, e -> {
164199
log.error("Fail to search model groups!", e);
165200
wrappedListener.onFailure(e);
166201
});
167-
client.search(modelGroupSearchRequest, modelGroupSearchActionListener);
202+
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
203+
.builder()
204+
.indices(modelGroupSearchRequest.indices())
205+
.searchSourceBuilder(modelGroupSearchRequest.source())
206+
.tenantId(tenantId)
207+
.build();
208+
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
209+
if (throwable == null) {
210+
try {
211+
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
212+
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
213+
modelGroupSearchActionListener.onResponse(searchResponse);
214+
} catch (Exception e) {
215+
modelGroupSearchActionListener.onFailure(e);
216+
}
217+
} else {
218+
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
219+
modelGroupSearchActionListener.onFailure(e);
220+
}
221+
});
168222
}
169223
} catch (Exception e) {
170224
log.error(e.getMessage(), e);

plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,6 @@ protected void doExecute(Task task, MLSearchActionRequest request, ActionListene
4848
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
4949
return;
5050
}
51-
mlSearchHandler.search(request, tenantId, actionListener);
51+
mlSearchHandler.search(sdkClient, request, tenantId, actionListener);
5252
}
5353
}

0 commit comments

Comments
 (0)