Skip to content

Commit 8fea510

Browse files
committed
add UT
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 9f7947e commit 8fea510

File tree

1 file changed

+127
-72
lines changed

1 file changed

+127
-72
lines changed

plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java

+127-72
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,20 @@
1818
import static org.opensearch.ml.action.models.DeleteModelTransportAction.OS_STATUS_EXCEPTION_MESSAGE;
1919
import static org.opensearch.ml.action.models.DeleteModelTransportAction.SEARCH_FAILURE_MSG;
2020
import static org.opensearch.ml.action.models.DeleteModelTransportAction.TIMEOUT_MSG;
21+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
2122
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
2223

24+
import java.io.ByteArrayOutputStream;
2325
import java.io.IOException;
26+
import java.io.ObjectOutputStream;
27+
import java.nio.ByteBuffer;
28+
import java.nio.charset.StandardCharsets;
2429
import java.util.ArrayList;
2530
import java.util.Arrays;
31+
import java.util.List;
32+
import java.util.Map;
2633

34+
import org.apache.lucene.search.TotalHits;
2735
import org.junit.Before;
2836
import org.junit.Ignore;
2937
import org.junit.Rule;
@@ -36,14 +44,26 @@
3644
import org.opensearch.action.bulk.BulkItemResponse;
3745
import org.opensearch.action.delete.DeleteResponse;
3846
import org.opensearch.action.get.GetResponse;
47+
import org.opensearch.action.ingest.GetPipelineAction;
48+
import org.opensearch.action.ingest.GetPipelineResponse;
49+
import org.opensearch.action.search.GetSearchPipelineAction;
50+
import org.opensearch.action.search.GetSearchPipelineResponse;
51+
import org.opensearch.action.search.SearchRequest;
52+
import org.opensearch.action.search.SearchResponse;
3953
import org.opensearch.action.support.ActionFilters;
4054
import org.opensearch.client.Client;
55+
import org.opensearch.client.Response;
4156
import org.opensearch.cluster.service.ClusterService;
4257
import org.opensearch.common.settings.Settings;
4358
import org.opensearch.common.util.concurrent.ThreadContext;
4459
import org.opensearch.common.xcontent.XContentFactory;
60+
import org.opensearch.common.xcontent.XContentType;
4561
import org.opensearch.core.action.ActionListener;
62+
import org.opensearch.core.action.ActionResponse;
63+
import org.opensearch.core.common.bytes.BytesArray;
4664
import org.opensearch.core.common.bytes.BytesReference;
65+
import org.opensearch.core.xcontent.MediaType;
66+
import org.opensearch.core.xcontent.MediaTypeRegistry;
4767
import org.opensearch.core.xcontent.NamedXContentRegistry;
4868
import org.opensearch.core.xcontent.ToXContent;
4969
import org.opensearch.core.xcontent.XContentBuilder;
@@ -58,6 +78,9 @@
5878
import org.opensearch.ml.engine.tools.RelatedModelIdHelper;
5979
import org.opensearch.ml.helper.ModelAccessControlHelper;
6080
import org.opensearch.ml.model.MLModelManager;
81+
import org.opensearch.search.SearchHit;
82+
import org.opensearch.search.SearchHits;
83+
import org.opensearch.search.pipeline.PipelineConfiguration;
6184
import org.opensearch.test.OpenSearchTestCase;
6285
import org.opensearch.threadpool.ThreadPool;
6386
import org.opensearch.transport.TransportService;
@@ -84,6 +107,9 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase {
84107
@Mock
85108
BulkByScrollResponse bulkByScrollResponse;
86109

110+
@Mock
111+
SearchResponse searchResponse;
112+
87113
@Mock
88114
NamedXContentRegistry xContentRegistry;
89115

@@ -107,6 +133,20 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase {
107133
@Mock
108134
private RelatedModelIdHelper relatedModelIdHelper;
109135

136+
@Mock
137+
private GetSearchPipelineResponse getSearchPipelineResponse;
138+
139+
@Mock
140+
private org.opensearch.search.pipeline.PipelineConfiguration searchPipelineConfiguration;
141+
142+
143+
@Mock
144+
GetPipelineResponse getIngestionPipelineResponse;
145+
146+
private BulkByScrollResponse emptyBulkByScrollResponse;
147+
148+
private Map<String, Object> configDataMap;
149+
110150
@Before
111151
public void setup() throws IOException {
112152
MockitoAnnotations.openMocks(this);
@@ -137,6 +177,7 @@ public void setup() throws IOException {
137177
when(clusterService.getSettings()).thenReturn(settings);
138178
when(client.threadPool()).thenReturn(threadPool);
139179
when(threadPool.getThreadContext()).thenReturn(threadContext);
180+
prepare();
140181
}
141182

142183
public void testDeleteModel_Success() throws IOException {
@@ -146,13 +187,6 @@ public void testDeleteModel_Success() throws IOException {
146187
return null;
147188
}).when(client).delete(any(), any());
148189

149-
doAnswer(invocation -> {
150-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
151-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
152-
listener.onResponse(response);
153-
return null;
154-
}).when(client).execute(any(), any(), any());
155-
156190
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);
157191
doAnswer(invocation -> {
158192
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
@@ -164,20 +198,63 @@ public void testDeleteModel_Success() throws IOException {
164198
verify(actionListener).onResponse(deleteResponse);
165199
}
166200

201+
public void testDeleteModel_BlockedBySearchPipeline() throws IOException {
202+
//org.opensearch.search.pipeline.PipelineConfiguration pipelineConfiguration = new PipelineConfiguration();
203+
when(searchPipelineConfiguration.getId()).thenReturn("1");
204+
when(searchPipelineConfiguration.getConfigAsMap()).thenReturn(configDataMap);
205+
when(getSearchPipelineResponse.pipelines()).thenReturn(List.of(searchPipelineConfiguration));
206+
doAnswer(invocation -> {
207+
ActionListener<GetSearchPipelineResponse> listener = invocation.getArgument(2);
208+
listener.onResponse(getSearchPipelineResponse);
209+
return null;
210+
}).when(client).execute(eq(GetSearchPipelineAction.INSTANCE), any(), any());
211+
212+
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
213+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
214+
verify(actionListener).onFailure(argumentCaptor.capture());
215+
assertEquals("1 search pipelines are still using this model, please delete or update the pipelines first: [1]", argumentCaptor.getValue().getMessage());
216+
}
217+
218+
public void testDeleteModel_BlockedByIngestPipeline() throws IOException {
219+
org.opensearch.ingest.PipelineConfiguration ingestPipelineConfiguration = new org.opensearch.ingest.PipelineConfiguration(
220+
"1", new BytesArray("{\"model_id\": \"test_id\"}".getBytes(StandardCharsets.UTF_8)), MediaTypeRegistry.JSON
221+
);
222+
when(getIngestionPipelineResponse.pipelines()).thenReturn(List.of(ingestPipelineConfiguration));
223+
doAnswer(invocation -> {
224+
ActionListener<GetPipelineResponse> listener = invocation.getArgument(2);
225+
listener.onResponse(getIngestionPipelineResponse);
226+
return null;
227+
}).when(client).execute(eq(GetPipelineAction.INSTANCE), any(), any());
228+
229+
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
230+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
231+
verify(actionListener).onFailure(argumentCaptor.capture());
232+
assertEquals("1 ingest pipelines are still using this model, please delete or update the pipelines first: [1]", argumentCaptor.getValue().getMessage());
233+
}
234+
235+
public void testDeleteModel_BlockedByAgent() throws IOException {
236+
SearchHit hit = new SearchHit(1, "1", null, null);
237+
SearchHits searchHits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f);
238+
when(searchResponse.getHits()).thenReturn(searchHits);
239+
doAnswer(invocation -> {
240+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
241+
listener.onResponse(searchResponse);
242+
return null;
243+
}).when(client).search(any(), any());
244+
245+
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
246+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
247+
verify(actionListener).onFailure(argumentCaptor.capture());
248+
assertEquals("1 agents are still using this model, please delete or update the agents first: [1]", argumentCaptor.getValue().getMessage());
249+
}
250+
167251
public void testDeleteRemoteModel_Success() throws IOException {
168252
doAnswer(invocation -> {
169253
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
170254
listener.onResponse(deleteResponse);
171255
return null;
172256
}).when(client).delete(any(), any());
173257

174-
doAnswer(invocation -> {
175-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
176-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
177-
listener.onResponse(response);
178-
return null;
179-
}).when(client).execute(any(), any(), any());
180-
181258
GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
182259
doAnswer(invocation -> {
183260
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
@@ -200,12 +277,6 @@ public void testDeleteRemoteModel_deleteModelController_failed() throws IOExcept
200277
return null;
201278
}).when(client).delete(any(), any());
202279

203-
doAnswer(invocation -> {
204-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
205-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
206-
listener.onResponse(response);
207-
return null;
208-
}).when(client).execute(any(), any(), any());
209280

210281
GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
211282
doAnswer(invocation -> {
@@ -231,13 +302,6 @@ public void testDeleteLocalModel_deleteModelController_failed() throws IOExcepti
231302
return null;
232303
}).when(client).delete(any(), any());
233304

234-
doAnswer(invocation -> {
235-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
236-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
237-
listener.onResponse(response);
238-
return null;
239-
}).when(client).execute(any(), any(), any());
240-
241305
GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.TEXT_EMBEDDING);
242306
doAnswer(invocation -> {
243307
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
@@ -284,13 +348,6 @@ public void testDeleteHiddenModel_Success() throws IOException {
284348
return null;
285349
}).when(client).delete(any(), any());
286350

287-
doAnswer(invocation -> {
288-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
289-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
290-
listener.onResponse(response);
291-
return null;
292-
}).when(client).execute(any(), any(), any());
293-
294351
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, true);
295352
doAnswer(invocation -> {
296353
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
@@ -310,13 +367,6 @@ public void testDeleteHiddenModel_NoSuperAdminPermission() throws IOException {
310367
return null;
311368
}).when(client).delete(any(), any());
312369

313-
doAnswer(invocation -> {
314-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
315-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
316-
listener.onResponse(response);
317-
return null;
318-
}).when(client).execute(any(), any(), any());
319-
320370
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, true);
321371
doAnswer(invocation -> {
322372
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
@@ -338,13 +388,6 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException {
338388
return null;
339389
}).when(client).delete(any(), any());
340390

341-
doAnswer(invocation -> {
342-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
343-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
344-
listener.onResponse(response);
345-
return null;
346-
}).when(client).execute(any(), any(), any());
347-
348391
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);
349392
doAnswer(invocation -> {
350393
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
@@ -417,12 +460,6 @@ public void testDeleteModel_deleteModelController_ResourceNotFoundException() th
417460
return null;
418461
}).when(client).delete(any(), any());
419462

420-
doAnswer(invocation -> {
421-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
422-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
423-
listener.onResponse(response);
424-
return null;
425-
}).when(client).execute(any(), any(), any());
426463

427464
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);
428465
doAnswer(invocation -> {
@@ -467,12 +504,6 @@ public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() thro
467504
return null;
468505
}).when(client).delete(any(), any());
469506

470-
doAnswer(invocation -> {
471-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
472-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
473-
listener.onResponse(response);
474-
return null;
475-
}).when(client).execute(any(), any(), any());
476507

477508
GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
478509
doAnswer(invocation -> {
@@ -498,12 +529,6 @@ public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOExce
498529
return null;
499530
}).when(client).delete(any(), any());
500531

501-
doAnswer(invocation -> {
502-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
503-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
504-
listener.onResponse(response);
505-
return null;
506-
}).when(client).execute(any(), any(), any());
507532

508533
GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
509534
doAnswer(invocation -> {
@@ -531,12 +556,6 @@ public void testModelNotFound_modelChunks_modelController_delete_success() throw
531556
return null;
532557
}).when(client).delete(any(), any());
533558

534-
doAnswer(invocation -> {
535-
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
536-
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
537-
listener.onResponse(response);
538-
return null;
539-
}).when(client).execute(any(), any(), any());
540559
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
541560
ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
542561
verify(actionListener).onFailure(argumentCaptor.capture());
@@ -661,4 +680,40 @@ private GetResponse buildResponse(MLModel mlModel) throws IOException {
661680
GetResponse getResponse = new GetResponse(getResult);
662681
return getResponse;
663682
}
683+
684+
private void prepare() {
685+
emptyBulkByScrollResponse = new BulkByScrollResponse(new ArrayList<>(), null);
686+
SearchHits hits = new SearchHits(new SearchHit[] { }, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f);
687+
when(searchResponse.getHits()).thenReturn(hits);
688+
when(getIngestionPipelineResponse.pipelines()).thenReturn(List.of());
689+
690+
doAnswer(invocation -> {
691+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
692+
listener.onResponse(searchResponse);
693+
return null;
694+
}).when(client).search(any(), any());
695+
696+
697+
doAnswer(invocation -> {
698+
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
699+
listener.onResponse(emptyBulkByScrollResponse);
700+
return null;
701+
}).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any());
702+
703+
doAnswer(invocation -> {
704+
ActionListener<GetPipelineResponse> listener = invocation.getArgument(2);
705+
listener.onResponse(getIngestionPipelineResponse);
706+
return null;
707+
}).when(client).execute(eq(GetPipelineAction.INSTANCE), any(), any());
708+
709+
when(getSearchPipelineResponse.pipelines()).thenReturn(List.of());
710+
doAnswer(invocation -> {
711+
ActionListener<GetSearchPipelineResponse> listener = invocation.getArgument(2);
712+
listener.onResponse(getSearchPipelineResponse);
713+
return null;
714+
}).when(client).execute(eq(GetSearchPipelineAction.INSTANCE), any(), any());
715+
configDataMap = Map.of("model_id", "test_id", "list_model_id", List.of("test_list_id"),
716+
"test_map_id", Map.of("test_key", "test_map_id"));
717+
doAnswer(invocation -> new SearchRequest()).when(relatedModelIdHelper).constructQueryRequest(any());
718+
}
664719
}

0 commit comments

Comments
 (0)