Skip to content

Commit 77e91c2

Browse files
authored
Stop processing search requests when _msearch is canceled (#17005)
Prior to this fix, the _msearch API would keep running search requests even after being canceled. With this change, we explicitly check if the task has been canceled before kicking off subsequent requests. --------- Signed-off-by: Michael Froh <froh@amazon.com>
1 parent 0d7ac2c commit 77e91c2

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
111111
- Fix Shallow copy snapshot failures on closed index ([#16868](https://github.com/opensearch-project/OpenSearch/pull/16868))
112112
- Fix multi-value sort for unsigned long ([#16732](https://github.com/opensearch-project/OpenSearch/pull/16732))
113113
- The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993))
114+
- Stop processing search requests when _msearch request is cancelled ([#17005](https://github.com/opensearch-project/OpenSearch/pull/17005))
114115
- Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037))
115116
- Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803))
116117

server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java

+24
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
4545
import org.opensearch.core.action.ActionListener;
4646
import org.opensearch.core.common.io.stream.Writeable;
47+
import org.opensearch.core.tasks.TaskCancelledException;
48+
import org.opensearch.core.tasks.TaskId;
49+
import org.opensearch.tasks.CancellableTask;
4750
import org.opensearch.tasks.Task;
4851
import org.opensearch.threadpool.ThreadPool;
4952
import org.opensearch.transport.TransportService;
@@ -193,6 +196,19 @@ private void handleResponse(final int responseSlot, final MultiSearchResponse.It
193196
if (responseCounter.decrementAndGet() == 0) {
194197
assert requests.isEmpty();
195198
finish();
199+
} else if (isCancelled(request.request.getParentTask())) {
200+
// Drain the rest of the queue
201+
SearchRequestSlot request;
202+
while ((request = requests.poll()) != null) {
203+
responses.set(
204+
request.responseSlot,
205+
new MultiSearchResponse.Item(null, new TaskCancelledException("Parent task was cancelled"))
206+
);
207+
if (responseCounter.decrementAndGet() == 0) {
208+
assert requests.isEmpty();
209+
finish();
210+
}
211+
}
196212
} else {
197213
if (thread == Thread.currentThread()) {
198214
// we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread
@@ -220,6 +236,14 @@ private long buildTookInMillis() {
220236
});
221237
}
222238

239+
private boolean isCancelled(TaskId taskId) {
240+
if (taskId.isSet()) {
241+
CancellableTask task = taskManager.getCancellableTask(taskId.getId());
242+
return task != null && task.isCancelled();
243+
}
244+
return false;
245+
}
246+
223247
/**
224248
* Slots a search request
225249
*

server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java

+118
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
import org.opensearch.common.settings.Settings;
5050
import org.opensearch.core.action.ActionListener;
5151
import org.opensearch.search.internal.InternalSearchResponse;
52+
import org.opensearch.tasks.CancellableTask;
5253
import org.opensearch.tasks.Task;
54+
import org.opensearch.tasks.TaskListener;
5355
import org.opensearch.tasks.TaskManager;
5456
import org.opensearch.telemetry.tracing.noop.NoopTracer;
5557
import org.opensearch.test.OpenSearchTestCase;
@@ -62,7 +64,9 @@
6264
import java.util.IdentityHashMap;
6365
import java.util.List;
6466
import java.util.Set;
67+
import java.util.concurrent.CountDownLatch;
6568
import java.util.concurrent.ExecutorService;
69+
import java.util.concurrent.TimeUnit;
6670
import java.util.concurrent.atomic.AtomicInteger;
6771
import java.util.concurrent.atomic.AtomicReference;
6872

@@ -289,4 +293,118 @@ public void testDefaultMaxConcurrentSearches() {
289293
assertThat(result, equalTo(1));
290294
}
291295

296+
public void testCancellation() {
297+
// Initialize dependencies of TransportMultiSearchAction
298+
Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build();
299+
ActionFilters actionFilters = mock(ActionFilters.class);
300+
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
301+
ThreadPool threadPool = new ThreadPool(settings);
302+
TransportService transportService = new TransportService(
303+
Settings.EMPTY,
304+
mock(Transport.class),
305+
threadPool,
306+
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
307+
boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()),
308+
null,
309+
Collections.emptySet(),
310+
NoopTracer.INSTANCE
311+
) {
312+
@Override
313+
public TaskManager getTaskManager() {
314+
return taskManager;
315+
}
316+
};
317+
ClusterService clusterService = mock(ClusterService.class);
318+
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build());
319+
320+
// Keep track of the number of concurrent searches started by multi search api,
321+
// and if there are more searches than is allowed create an error and remember that.
322+
int maxAllowedConcurrentSearches = 1; // Allow 1 search at a time.
323+
AtomicInteger counter = new AtomicInteger();
324+
AtomicReference<AssertionError> errorHolder = new AtomicReference<>();
325+
// randomize whether or not requests are executed asynchronously
326+
ExecutorService executorService = threadPool.executor(ThreadPool.Names.GENERIC);
327+
final Set<SearchRequest> requests = Collections.newSetFromMap(Collections.synchronizedMap(new IdentityHashMap<>()));
328+
CountDownLatch countDownLatch = new CountDownLatch(1);
329+
CancellableTask[] parentTask = new CancellableTask[1];
330+
NodeClient client = new NodeClient(settings, threadPool) {
331+
@Override
332+
public void search(final SearchRequest request, final ActionListener<SearchResponse> listener) {
333+
if (parentTask[0] != null && parentTask[0].isCancelled()) {
334+
fail("Should not execute search after parent task is cancelled");
335+
}
336+
try {
337+
countDownLatch.await(10, TimeUnit.MILLISECONDS);
338+
} catch (InterruptedException e) {
339+
throw new RuntimeException(e);
340+
}
341+
342+
requests.add(request);
343+
executorService.execute(() -> {
344+
counter.decrementAndGet();
345+
listener.onResponse(
346+
new SearchResponse(
347+
InternalSearchResponse.empty(),
348+
null,
349+
0,
350+
0,
351+
0,
352+
0L,
353+
ShardSearchFailure.EMPTY_ARRAY,
354+
SearchResponse.Clusters.EMPTY
355+
)
356+
);
357+
});
358+
}
359+
360+
@Override
361+
public String getLocalNodeId() {
362+
return "local_node_id";
363+
}
364+
};
365+
366+
TransportMultiSearchAction action = new TransportMultiSearchAction(
367+
threadPool,
368+
actionFilters,
369+
transportService,
370+
clusterService,
371+
10,
372+
System::nanoTime,
373+
client
374+
);
375+
376+
// Execute the multi search api and fail if we find an error after executing:
377+
try {
378+
/*
379+
* Allow for a large number of search requests in a single batch as previous implementations could stack overflow if the number
380+
* of requests in a single batch was large
381+
*/
382+
int numSearchRequests = scaledRandomIntBetween(1024, 8192);
383+
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
384+
multiSearchRequest.maxConcurrentSearchRequests(maxAllowedConcurrentSearches);
385+
for (int i = 0; i < numSearchRequests; i++) {
386+
multiSearchRequest.add(new SearchRequest());
387+
}
388+
MultiSearchResponse[] responses = new MultiSearchResponse[1];
389+
Exception[] exceptions = new Exception[1];
390+
parentTask[0] = (CancellableTask) action.execute(multiSearchRequest, new TaskListener<>() {
391+
@Override
392+
public void onResponse(Task task, MultiSearchResponse items) {
393+
responses[0] = items;
394+
}
395+
396+
@Override
397+
public void onFailure(Task task, Exception e) {
398+
exceptions[0] = e;
399+
}
400+
});
401+
parentTask[0].cancel("Giving up");
402+
countDownLatch.countDown();
403+
404+
assertNull(responses[0]);
405+
assertNull(exceptions[0]);
406+
} finally {
407+
assertTrue(OpenSearchTestCase.terminate(threadPool));
408+
}
409+
}
292410
}

0 commit comments

Comments
 (0)