Skip to content

Commit c5ab162

Browse files
[WLM] Add wlm support for scroll API (#16981)
* add wlm support for scroll API Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> * add CHANGELOG entry Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> * remove untagged tasks from WLM tracking Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> * add UTs for invalid tasks Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> * fix UT failures Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> * rename a field in QueryGroupTask Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com> --------- Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com>
1 parent de59264 commit c5ab162

File tree

6 files changed

+86
-1
lines changed

6 files changed

+86
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
115115
- The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993))
116116
- Stop processing search requests when _msearch request is cancelled ([#17005](https://github.com/opensearch-project/OpenSearch/pull/17005))
117117
- Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037))
118+
- [WLM] Add WLM support for search scroll API ([#16981](https://github.com/opensearch-project/OpenSearch/pull/16981))
118119
- Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803))
119120
- Use OpenSearch version to deserialize remote custom metadata([#16494](https://github.com/opensearch-project/OpenSearch/pull/16494))
120121
- Fix AutoDateHistogramAggregator rounding assertion failure ([#17023](https://github.com/opensearch-project/OpenSearch/pull/17023))

server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java

+11-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939
import org.opensearch.core.action.ActionListener;
4040
import org.opensearch.core.common.io.stream.Writeable;
4141
import org.opensearch.tasks.Task;
42+
import org.opensearch.threadpool.ThreadPool;
4243
import org.opensearch.transport.TransportService;
44+
import org.opensearch.wlm.QueryGroupTask;
4345

4446
/**
4547
* Perform the search scroll
@@ -51,24 +53,32 @@ public class TransportSearchScrollAction extends HandledTransportAction<SearchSc
5153
private final ClusterService clusterService;
5254
private final SearchTransportService searchTransportService;
5355
private final SearchPhaseController searchPhaseController;
56+
private final ThreadPool threadPool;
5457

5558
@Inject
5659
public TransportSearchScrollAction(
5760
TransportService transportService,
5861
ClusterService clusterService,
5962
ActionFilters actionFilters,
6063
SearchTransportService searchTransportService,
61-
SearchPhaseController searchPhaseController
64+
SearchPhaseController searchPhaseController,
65+
ThreadPool threadPool
6266
) {
6367
super(SearchScrollAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchScrollRequest>) SearchScrollRequest::new);
6468
this.clusterService = clusterService;
6569
this.searchTransportService = searchTransportService;
6670
this.searchPhaseController = searchPhaseController;
71+
this.threadPool = threadPool;
6772
}
6873

6974
@Override
7075
protected void doExecute(Task task, SearchScrollRequest request, ActionListener<SearchResponse> listener) {
7176
try {
77+
78+
if (task instanceof QueryGroupTask) {
79+
((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
80+
}
81+
7282
ParsedScrollId scrollId = TransportSearchHelper.parseScrollId(request.scrollId());
7383
Runnable action;
7484
switch (scrollId.getType()) {

server/src/main/java/org/opensearch/wlm/QueryGroupTask.java

+6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class QueryGroupTask extends CancellableTask {
3333
public static final Supplier<String> DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP";
3434
private final LongSupplier nanoTimeSupplier;
3535
private String queryGroupId;
36+
private boolean isQueryGroupSet = false;
3637

3738
public QueryGroupTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
3839
this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT, System::nanoTime);
@@ -81,6 +82,7 @@ public final String getQueryGroupId() {
8182
* @param threadContext current threadContext
8283
*/
8384
public final void setQueryGroupId(final ThreadContext threadContext) {
85+
isQueryGroupSet = true;
8486
if (threadContext != null && threadContext.getHeader(QUERY_GROUP_ID_HEADER) != null) {
8587
this.queryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);
8688
} else {
@@ -92,6 +94,10 @@ public long getElapsedTime() {
9294
return nanoTimeSupplier.getAsLong() - getStartTimeNanos();
9395
}
9496

97+
public boolean isQueryGroupSet() {
98+
return isQueryGroupSet;
99+
}
100+
95101
@Override
96102
public boolean shouldCancelChildrenOnCancellation() {
97103
return false;

server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ private Map<String, List<QueryGroupTask>> getTasksGroupedByQueryGroup() {
7676
.stream()
7777
.filter(QueryGroupTask.class::isInstance)
7878
.map(QueryGroupTask.class::cast)
79+
.filter(QueryGroupTask::isQueryGroupSet)
7980
.collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> task, Collectors.toList())));
8081
}
8182
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.wlm.tracker;
10+
11+
import org.opensearch.action.search.SearchTask;
12+
import org.opensearch.common.settings.ClusterSettings;
13+
import org.opensearch.common.settings.Settings;
14+
import org.opensearch.common.util.concurrent.ThreadContext;
15+
import org.opensearch.core.tasks.TaskId;
16+
import org.opensearch.tasks.TaskResourceTrackingService;
17+
import org.opensearch.test.OpenSearchTestCase;
18+
import org.opensearch.threadpool.TestThreadPool;
19+
import org.opensearch.threadpool.ThreadPool;
20+
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;
21+
import org.opensearch.wlm.QueryGroupTask;
22+
23+
import java.util.HashMap;
24+
import java.util.Map;
25+
26+
public class QueryGroupTaskResourceTrackingTests extends OpenSearchTestCase {
27+
ThreadPool threadPool;
28+
QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService;
29+
TaskResourceTrackingService taskResourceTrackingService;
30+
31+
@Override
32+
public void setUp() throws Exception {
33+
super.setUp();
34+
threadPool = new TestThreadPool("workload-management-tracking-thread-pool");
35+
taskResourceTrackingService = new TaskResourceTrackingService(
36+
Settings.EMPTY,
37+
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
38+
threadPool
39+
);
40+
queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(taskResourceTrackingService);
41+
}
42+
43+
public void tearDown() throws Exception {
44+
super.tearDown();
45+
threadPool.shutdownNow();
46+
}
47+
48+
public void testValidQueryGroupTasksCase() {
49+
taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
50+
QueryGroupTask task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>());
51+
taskResourceTrackingService.startTracking(task);
52+
53+
// since the query group id is not set we should not track this task
54+
Map<String, QueryGroupLevelResourceUsageView> resourceUsageViewMap = queryGroupResourceUsageTrackerService
55+
.constructQueryGroupLevelUsageViews();
56+
assertTrue(resourceUsageViewMap.isEmpty());
57+
58+
// Now since this task has a valid queryGroupId header it should be tracked
59+
try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
60+
threadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, "testHeader");
61+
task.setQueryGroupId(threadPool.getThreadContext());
62+
resourceUsageViewMap = queryGroupResourceUsageTrackerService.constructQueryGroupLevelUsageViews();
63+
assertFalse(resourceUsageViewMap.isEmpty());
64+
}
65+
}
66+
}

server/src/test/java/org/opensearch/wlm/tracker/ResourceUsageCalculatorTrackerServiceTests.java

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ private <T extends QueryGroupTask> T createMockTask(Class<T> type, long cpuUsage
146146
when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage);
147147
when(task.getStartTimeNanos()).thenReturn((long) 0);
148148
when(task.getElapsedTime()).thenReturn(clock.getTime());
149+
when(task.isQueryGroupSet()).thenReturn(true);
149150

150151
AtomicBoolean isCancelled = new AtomicBoolean(false);
151152
doAnswer(invocation -> {

0 commit comments

Comments
 (0)