7
7
8
8
import static org .opensearch .core .rest .RestStatus .BAD_REQUEST ;
9
9
import static org .opensearch .core .rest .RestStatus .INTERNAL_SERVER_ERROR ;
10
- import static org .opensearch .ml .common .CommonValue .TENANT_ID_FIELD ;
11
10
import static org .opensearch .ml .utils .RestActionUtils .wrapListenerToHandleSearchIndexNotFound ;
12
11
13
12
import java .util .ArrayList ;
43
42
import org .opensearch .ml .common .exception .MLResourceNotFoundException ;
44
43
import org .opensearch .ml .helper .ModelAccessControlHelper ;
45
44
import org .opensearch .ml .utils .RestActionUtils ;
45
+ import org .opensearch .remote .metadata .client .SdkClient ;
46
+ import org .opensearch .remote .metadata .client .SearchDataObjectRequest ;
47
+ import org .opensearch .remote .metadata .common .SdkClientUtils ;
46
48
import org .opensearch .search .SearchHits ;
47
49
import org .opensearch .search .builder .SearchSourceBuilder ;
48
50
import org .opensearch .search .fetch .subphase .FetchSourceContext ;
@@ -77,10 +79,11 @@ public MLSearchHandler(
77
79
78
80
/**
79
81
* Fetch all the models from the model group index, and then create a combined query to model version index.
82
+ * @param sdkClient sdkclient a wrapper of the client
80
83
* @param request
81
84
* @param actionListener
82
85
*/
83
- public void search (SearchRequest request , String tenantId , ActionListener <SearchResponse > actionListener ) {
86
+ public void search (SdkClient sdkClient , SearchRequest request , String tenantId , ActionListener <SearchResponse > actionListener ) {
84
87
User user = RestActionUtils .getUserContext (client );
85
88
ActionListener <SearchResponse > listener = wrapRestActionListener (actionListener , "Fail to search model version" );
86
89
try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
@@ -114,11 +117,6 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
114
117
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
115
118
shouldQuery .should (QueryBuilders .termQuery (MLModel .IS_HIDDEN_FIELD , false ));
116
119
117
- // For multi-tenancy
118
- if (tenantId != null ) {
119
- shouldQuery .should (QueryBuilders .termQuery (TENANT_ID_FIELD , tenantId ));
120
- }
121
-
122
120
// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
123
121
shouldQuery .should (QueryBuilders .boolQuery ().mustNot (QueryBuilders .existsQuery (MLModel .IS_HIDDEN_FIELD )));
124
122
@@ -132,10 +130,29 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
132
130
request .source ().fetchSource (rebuiltFetchSourceContext );
133
131
final ActionListener <SearchResponse > doubleWrapperListener = ActionListener
134
132
.wrap (wrappedListener ::onResponse , e -> wrapListenerToHandleSearchIndexNotFound (e , wrappedListener ));
135
- if (modelAccessControlHelper .skipModelAccessControl (user )) {
136
- client .search (request , doubleWrapperListener );
137
- } else if (!clusterService .state ().metadata ().hasIndex (CommonValue .ML_MODEL_GROUP_INDEX )) {
138
- client .search (request , doubleWrapperListener );
133
+ if (modelAccessControlHelper .skipModelAccessControl (user )
134
+ || !clusterService .state ().metadata ().hasIndex (CommonValue .ML_MODEL_GROUP_INDEX )) {
135
+
136
+ SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
137
+ .builder ()
138
+ .indices (request .indices ())
139
+ .searchSourceBuilder (request .source ())
140
+ .tenantId (tenantId )
141
+ .build ();
142
+ sdkClient .searchDataObjectAsync (searchDataObjectRequest ).whenComplete ((r , throwable ) -> {
143
+ if (throwable == null ) {
144
+ try {
145
+ SearchResponse searchResponse = SearchResponse .fromXContent (r .parser ());
146
+ log .info ("Model search complete: {}" , searchResponse .getHits ().getTotalHits ());
147
+ doubleWrapperListener .onResponse (searchResponse );
148
+ } catch (Exception e ) {
149
+ doubleWrapperListener .onFailure (e );
150
+ }
151
+ } else {
152
+ Exception e = SdkClientUtils .unwrapAndConvertToException (throwable , OpenSearchStatusException .class );
153
+ doubleWrapperListener .onFailure (e );
154
+ }
155
+ });
139
156
} else {
140
157
SearchSourceBuilder sourceBuilder = modelAccessControlHelper .createSearchSourceBuilder (user );
141
158
SearchRequest modelGroupSearchRequest = new SearchRequest ();
@@ -154,17 +171,54 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
154
171
Arrays .stream (r .getHits ().getHits ()).forEach (hit -> { modelGroupIds .add (hit .getId ()); });
155
172
156
173
request .source ().query (rewriteQueryBuilder (request .source ().query (), modelGroupIds ));
157
- client .search (request , doubleWrapperListener );
158
174
} else {
159
175
log .debug ("No model group found" );
160
176
request .source ().query (rewriteQueryBuilder (request .source ().query (), null ));
161
- client .search (request , doubleWrapperListener );
162
177
}
178
+ SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
179
+ .builder ()
180
+ .indices (request .indices ())
181
+ .searchSourceBuilder (request .source ())
182
+ .tenantId (tenantId )
183
+ .build ();
184
+ sdkClient .searchDataObjectAsync (searchDataObjectRequest ).whenComplete ((sr , throwable ) -> {
185
+ if (throwable == null ) {
186
+ try {
187
+ SearchResponse searchResponse = SearchResponse .fromXContent (sr .parser ());
188
+ log .info ("Model search complete: {}" , searchResponse .getHits ().getTotalHits ());
189
+ doubleWrapperListener .onResponse (searchResponse );
190
+ } catch (Exception e ) {
191
+ doubleWrapperListener .onFailure (e );
192
+ }
193
+ } else {
194
+ Exception e = SdkClientUtils .unwrapAndConvertToException (throwable , OpenSearchStatusException .class );
195
+ doubleWrapperListener .onFailure (e );
196
+ }
197
+ });
163
198
}, e -> {
164
199
log .error ("Fail to search model groups!" , e );
165
200
wrappedListener .onFailure (e );
166
201
});
167
- client .search (modelGroupSearchRequest , modelGroupSearchActionListener );
202
+ SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
203
+ .builder ()
204
+ .indices (modelGroupSearchRequest .indices ())
205
+ .searchSourceBuilder (modelGroupSearchRequest .source ())
206
+ .tenantId (tenantId )
207
+ .build ();
208
+ sdkClient .searchDataObjectAsync (searchDataObjectRequest ).whenComplete ((r , throwable ) -> {
209
+ if (throwable == null ) {
210
+ try {
211
+ SearchResponse searchResponse = SearchResponse .fromXContent (r .parser ());
212
+ log .info ("Model search complete: {}" , searchResponse .getHits ().getTotalHits ());
213
+ modelGroupSearchActionListener .onResponse (searchResponse );
214
+ } catch (Exception e ) {
215
+ modelGroupSearchActionListener .onFailure (e );
216
+ }
217
+ } else {
218
+ Exception e = SdkClientUtils .unwrapAndConvertToException (throwable , OpenSearchStatusException .class );
219
+ modelGroupSearchActionListener .onFailure (e );
220
+ }
221
+ });
168
222
}
169
223
} catch (Exception e ) {
170
224
log .error (e .getMessage (), e );
0 commit comments