Skip to content

Commit f8267b4

Browse files
Fix race condition in PageListener (#1351) (#1352)
* Fix race condition in PageListener This PR - Introduced an `AtomicInteger` called `pagesInFlight` to track the number of pages currently being processed.  - Incremented `pagesInFlight` before processing each page and decremented it after processing is complete - Adjusted the condition in `scheduleImputeHCTask` to check both `pagesInFlight.get() == 0` (all pages have been processed) and `sentOutPages.get() == receivedPages.get()` (all responses have been received) before scheduling the `imputeHC` task.  - Removed the previous final check in `onResponse` that decided when to schedule `imputeHC`, relying instead on the updated counters for accurate synchronization. These changes address the race condition where `sentOutPages` might not have been incremented in time before checking whether to schedule the `imputeHC` task. By accurately tracking the number of in-flight pages and sent pages, we ensure that `imputeHC` is executed only after all pages have been fully processed and all responses have been received. Testing done: 1. Reproduced the race condition by starting two detectors with imputation. This causes an out of order illegal argument exception from RCF due to this race condition. Also verified the change fixed the problem. 2. added an IT for the above scenario. * make sure increment before schedule --------- (cherry picked from commit f62885a) Signed-off-by: Kaituo Li <kaituo@amazon.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 5842bef commit f8267b4

File tree

6 files changed

+127
-25
lines changed

6 files changed

+127
-25
lines changed

src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java

-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ protected void doExecute(Task task, ResultBulkRequestType request, ActionListene
8585
// all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure).
8686
long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes();
8787
float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits;
88-
@SuppressWarnings("rawtypes")
8988
List<? extends ResultWriteRequest> results = request.getResults();
9089

9190
if (results == null || results.size() < 1) {

src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java

+22-12
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ class PageListener implements ActionListener<CompositeRetriever.Page> {
210210
private String taskId;
211211
private AtomicInteger receivedPages;
212212
private AtomicInteger sentOutPages;
213+
// By introducing pagesInFlight and incrementing it in the main thread before asynchronous processing begins,
214+
// we ensure that the count of in-flight pages is accurate at all times. This allows us to reliably determine
215+
// when all pages have been processed.
216+
private AtomicInteger pagesInFlight;
213217

214218
PageListener(PageIterator pageIterator, Config config, long dataStartTime, long dataEndTime, String taskId) {
215219
this.pageIterator = pageIterator;
@@ -220,14 +224,21 @@ class PageListener implements ActionListener<CompositeRetriever.Page> {
220224
this.taskId = taskId;
221225
this.receivedPages = new AtomicInteger();
222226
this.sentOutPages = new AtomicInteger();
227+
this.pagesInFlight = new AtomicInteger();
223228
}
224229

225230
@Override
226231
public void onResponse(CompositeRetriever.Page entityFeatures) {
232+
// Increment pagesInFlight to track the processing of this page
233+
pagesInFlight.incrementAndGet();
234+
227235
// start processing next page after sending out features for previous page
228236
if (pageIterator.hasNext()) {
229237
pageIterator.next(this);
238+
} else if (config.getImputationOption() != null) {
239+
scheduleImputeHCTask();
230240
}
241+
231242
if (entityFeatures != null && false == entityFeatures.isEmpty()) {
232243
LOG
233244
.info(
@@ -309,19 +320,15 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
309320
} catch (Exception e) {
310321
LOG.error("Unexpected exception", e);
311322
handleException(e);
323+
} finally {
324+
// Decrement pagesInFlight after processing is complete
325+
pagesInFlight.decrementAndGet();
312326
}
313327
});
314-
}
315-
316-
if (!pageIterator.hasNext() && config.getImputationOption() != null) {
317-
if (sentOutPages.get() > 0) {
318-
// at least 1 page sent out. Wait until all responses are back.
319-
scheduleImputeHCTask();
320-
} else {
321-
// no data in current interval. Send out impute request right away.
322-
imputeHC(dataStartTime, dataEndTime, configId, taskId);
323-
}
324-
328+
} else {
329+
// No entity features to process
330+
// Decrement pagesInFlight immediately
331+
pagesInFlight.decrementAndGet();
325332
}
326333
}
327334

@@ -358,7 +365,10 @@ private void scheduleImputeHCTask() {
358365

359366
@Override
360367
public void run() {
361-
if (sentOutPages.get() == receivedPages.get()) {
368+
// By using pagesInFlight in the condition within scheduleImputeHCTask, we ensure that imputeHC
369+
// is executed only after all pages have been processed (pagesInFlight.get() == 0) and all
370+
// responses have been received (sentOutPages.get() == receivedPages.get()).
371+
if (pagesInFlight.get() == 0 && sentOutPages.get() == receivedPages.get()) {
362372
if (!sent.get()) {
363373
// since we don't know when cancel will succeed, need sent to ensure imputeHC is only called once
364374
sent.set(true);

src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ protected String genDetector(
2828
long windowDelayMinutes,
2929
boolean hc,
3030
ImputationMethod imputation,
31-
long trainTimeMillis
31+
long trainTimeMillis,
32+
String name
3233
) {
3334
StringBuilder sb = new StringBuilder();
3435
// common part

src/test/java/org/opensearch/ad/e2e/MissingIT.java

+30-5
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,27 @@ protected TrainResult createAndStartRealTimeDetector(
7878
List<JsonObject> data,
7979
ImputationMethod imputation,
8080
boolean hc,
81-
long trainTimeMillis
81+
long trainTimeMillis,
82+
String name
8283
) throws Exception {
83-
TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis);
84+
TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, name);
8485
List<JsonObject> result = startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, true);
8586
recordLastSeenFromResult(result);
8687

8788
return trainResult;
8889
}
8990

91+
protected TrainResult createAndStartRealTimeDetector(
92+
int numberOfEntities,
93+
int trainTestSplit,
94+
List<JsonObject> data,
95+
ImputationMethod imputation,
96+
boolean hc,
97+
long trainTimeMillis
98+
) throws Exception {
99+
return createAndStartRealTimeDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test");
100+
}
101+
90102
protected TrainResult createAndStartHistoricalDetector(
91103
int numberOfEntities,
92104
int trainTestSplit,
@@ -115,12 +127,13 @@ protected TrainResult createDetector(
115127
List<JsonObject> data,
116128
ImputationMethod imputation,
117129
boolean hc,
118-
long trainTimeMillis
130+
long trainTimeMillis,
131+
String name
119132
) throws Exception {
120133
Instant trainTime = Instant.ofEpochMilli(trainTimeMillis);
121134

122135
Duration windowDelay = getWindowDelay(trainTimeMillis);
123-
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis);
136+
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis, name);
124137

125138
RestClient client = client();
126139
String detectorId = createDetector(client, detector);
@@ -129,6 +142,17 @@ protected TrainResult createDetector(
129142
return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime, "timestamp");
130143
}
131144

145+
protected TrainResult createDetector(
146+
int numberOfEntities,
147+
int trainTestSplit,
148+
List<JsonObject> data,
149+
ImputationMethod imputation,
150+
boolean hc,
151+
long trainTimeMillis
152+
) throws Exception {
153+
return createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test");
154+
}
155+
132156
protected Duration getWindowDelay(long trainTimeMillis) {
133157
/*
134158
* AD accepts windowDelay in the unit of minutes. Thus, we need to convert the delay in minutes. This will
@@ -156,7 +180,8 @@ protected abstract String genDetector(
156180
long windowDelayMinutes,
157181
boolean hc,
158182
ImputationMethod imputation,
159-
long trainTimeMillis
183+
long trainTimeMillis,
184+
String name
160185
);
161186

162187
protected abstract AbstractSyntheticDataTest.GenData genData(

src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java

+71-4
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,80 @@ public void testHCPrevious() throws Exception {
135135
);
136136
}
137137

138+
/**
139+
* test we start two HC detector with zero imputation consecutively.
140+
* We expect there is no out of order error from RCF.
141+
* @throws Exception
142+
*/
143+
public void testDoubleHCZero() throws Exception {
144+
lastSeen.clear();
145+
int numberOfEntities = 2;
146+
147+
AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA;
148+
ImputationMethod method = ImputationMethod.ZERO;
149+
150+
AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode);
151+
152+
// only ingest train data to avoid validation error as we use latest data time as starting point.
153+
// otherwise, we will have too many missing points.
154+
ingestUniformSingleFeatureData(
155+
trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train.
156+
dataGenerated.data
157+
);
158+
159+
TrainResult trainResult1 = createAndStartRealTimeDetector(
160+
numberOfEntities,
161+
trainTestSplit,
162+
dataGenerated.data,
163+
method,
164+
true,
165+
dataGenerated.testStartTime,
166+
"test1"
167+
);
168+
169+
TrainResult trainResult2 = createAndStartRealTimeDetector(
170+
numberOfEntities,
171+
trainTestSplit,
172+
dataGenerated.data,
173+
method,
174+
true,
175+
dataGenerated.testStartTime,
176+
"test2"
177+
);
178+
179+
runTest(
180+
dataGenerated.testStartTime,
181+
dataGenerated,
182+
trainResult1.windowDelay,
183+
trainResult1.detectorId,
184+
numberOfEntities,
185+
mode,
186+
method,
187+
3,
188+
true
189+
);
190+
191+
runTest(
192+
dataGenerated.testStartTime,
193+
dataGenerated,
194+
trainResult2.windowDelay,
195+
trainResult2.detectorId,
196+
numberOfEntities,
197+
mode,
198+
method,
199+
3,
200+
true
201+
);
202+
}
203+
138204
@Override
139205
protected String genDetector(
140206
int trainTestSplit,
141207
long windowDelayMinutes,
142208
boolean hc,
143209
ImputationMethod imputation,
144-
long trainTimeMillis
210+
long trainTimeMillis,
211+
String name
145212
) {
146213
StringBuilder sb = new StringBuilder();
147214

@@ -185,7 +252,7 @@ protected String genDetector(
185252
// common part
186253
sb
187254
.append(
188-
"{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\""
255+
"{ \"name\": \"%s\", \"description\": \"test\", \"time_field\": \"timestamp\""
189256
+ ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_id\": \"feature2\", \"feature_name\": \"feature 2\", \"feature_enabled\": "
190257
+ "\"true\", \"aggregation_query\": { \"Feature2\": { \"avg\": { \"field\": \"data\" } } } },"
191258
+ featureWithFilter
@@ -226,9 +293,9 @@ protected String genDetector(
226293
sb.append("\"schema_version\": 0}");
227294

228295
if (hc) {
229-
return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1, categoricalField);
296+
return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1, categoricalField);
230297
} else {
231-
return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1);
298+
return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1);
232299
}
233300
}
234301

src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void testSingleStream() throws Exception {
3535
);
3636

3737
Duration windowDelay = getWindowDelay(dataGenerated.testStartTime);
38-
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime);
38+
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime, "test");
3939

4040
Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong());
4141
Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());
@@ -63,7 +63,7 @@ public void testHC() throws Exception {
6363
);
6464

6565
Duration windowDelay = getWindowDelay(dataGenerated.testStartTime);
66-
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime);
66+
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime, "test");
6767

6868
Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong());
6969
Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());

0 commit comments

Comments
 (0)