Skip to content

Commit 2ea4dcf

Browse files
authored
Add Forecaster class (opensearch-project#920)
* Add Forecaster class This PR adds class Forecaster that serves as the configuration POJO for forecasting. Shared code between AnomalyDetector and Forecaster is extracted and moved to the Config class to reduce duplication and promote reusability. References to the common code in related classes have also been adjusted. Testing done: 1. gradle build. Signed-off-by: Kaituo Li <kaituo@amazon.com> * fix compiler error due to a recent core change Signed-off-by: Kaituo Li <kaituo@amazon.com> * address Sudipto and Amit's comments Signed-off-by: Kaituo Li <kaituo@amazon.com> --------- Signed-off-by: Kaituo Li <kaituo@amazon.com>
1 parent ee04225 commit 2ea4dcf

File tree

226 files changed

+2888
-1544
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

226 files changed

+2888
-1544
lines changed

src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ protected void runAdJob(
246246
String user = userInfo.getName();
247247
List<String> roles = userInfo.getRoles();
248248

249-
String resultIndex = jobParameter.getResultIndex();
249+
String resultIndex = jobParameter.getCustomResultIndex();
250250
if (resultIndex == null) {
251251
runAnomalyDetectionJob(
252252
jobParameter,
@@ -536,7 +536,7 @@ private void stopAdJob(String detectorId, AnomalyDetectorFunction function) {
536536
Instant.now(),
537537
job.getLockDurationSeconds(),
538538
job.getUser(),
539-
job.getResultIndex()
539+
job.getCustomResultIndex()
540540
);
541541
IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX)
542542
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)

src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
import org.opensearch.core.xcontent.XContentParser;
176176
import org.opensearch.env.Environment;
177177
import org.opensearch.env.NodeEnvironment;
178+
import org.opensearch.forecast.model.Forecaster;
178179
import org.opensearch.jobscheduler.spi.JobSchedulerExtension;
179180
import org.opensearch.jobscheduler.spi.ScheduledJobParser;
180181
import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
@@ -955,7 +956,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
955956
AnomalyDetector.XCONTENT_REGISTRY,
956957
AnomalyResult.XCONTENT_REGISTRY,
957958
DetectorInternalState.XCONTENT_REGISTRY,
958-
AnomalyDetectorJob.XCONTENT_REGISTRY
959+
AnomalyDetectorJob.XCONTENT_REGISTRY,
960+
Forecaster.XCONTENT_REGISTRY
959961
);
960962
}
961963

src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java

+19-19
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ private void prepareProfile(
149149
ActionListener<DetectorProfile> listener,
150150
Set<DetectorProfileName> profilesToCollect
151151
) {
152-
String detectorId = detector.getDetectorId();
152+
String detectorId = detector.getId();
153153
GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, detectorId);
154154
client.get(getRequest, ActionListener.wrap(getResponse -> {
155155
if (getResponse != null && getResponse.isExists()) {
@@ -162,7 +162,7 @@ private void prepareProfile(
162162
AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
163163
long enabledTimeMs = job.getEnabledTime().toEpochMilli();
164164

165-
boolean isMultiEntityDetector = detector.isMultientityDetector();
165+
boolean isMultiEntityDetector = detector.isHighCardinality();
166166

167167
int totalResponsesToWait = 0;
168168
if (profilesToCollect.contains(DetectorProfileName.ERROR)) {
@@ -284,8 +284,8 @@ private void prepareProfile(
284284
}
285285

286286
private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorProfile> listener, AnomalyDetector detector) {
287-
List<String> categoryField = detector.getCategoryField();
288-
if (!detector.isMultientityDetector() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) {
287+
List<String> categoryField = detector.getCategoryFields();
288+
if (!detector.isHighCardinality() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) {
289289
listener.onResponse(new DetectorProfile.Builder().build());
290290
} else {
291291
if (categoryField.size() == 1) {
@@ -304,7 +304,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
304304
DetectorProfile profile = profileBuilder.totalEntities(value).build();
305305
listener.onResponse(profile);
306306
}, searchException -> {
307-
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId());
307+
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId());
308308
listener.onFailure(searchException);
309309
});
310310
// using the original context in listener as user roles have no permissions for internal operations like fetching a
@@ -313,7 +313,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
313313
.<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
314314
request,
315315
client::search,
316-
detector.getDetectorId(),
316+
detector.getId(),
317317
client,
318318
searchResponseListener
319319
);
@@ -322,7 +322,11 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
322322
AggregationBuilder bucketAggs = AggregationBuilders
323323
.composite(
324324
ADCommonName.TOTAL_ENTITIES,
325-
detector.getCategoryField().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList())
325+
detector
326+
.getCategoryFields()
327+
.stream()
328+
.map(f -> new TermsValuesSourceBuilder(f).field(f))
329+
.collect(Collectors.toList())
326330
)
327331
.size(maxTotalEntitiesToTrack);
328332
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0);
@@ -353,7 +357,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
353357
DetectorProfile profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build();
354358
listener.onResponse(profile);
355359
}, searchException -> {
356-
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId());
360+
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId());
357361
listener.onFailure(searchException);
358362
});
359363
// using the original context in listener as user roles have no permissions for internal operations like fetching a
@@ -362,7 +366,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
362366
.<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
363367
searchRequest,
364368
client::search,
365-
detector.getDetectorId(),
369+
detector.getId(),
366370
client,
367371
searchResponseListener
368372
);
@@ -400,7 +404,7 @@ private void profileStateRelated(
400404
Set<DetectorProfileName> profilesToCollect
401405
) {
402406
if (enabled) {
403-
RCFPollingRequest request = new RCFPollingRequest(detector.getDetectorId());
407+
RCFPollingRequest request = new RCFPollingRequest(detector.getId());
404408
client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener));
405409
} else {
406410
DetectorProfile.Builder builder = new DetectorProfile.Builder();
@@ -419,7 +423,7 @@ private void profileModels(
419423
MultiResponsesDelegateActionListener<DetectorProfile> listener
420424
) {
421425
DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes();
422-
ProfileRequest profileRequest = new ProfileRequest(detector.getDetectorId(), profiles, forMultiEntityDetector, dataNodes);
426+
ProfileRequest profileRequest = new ProfileRequest(detector.getId(), profiles, forMultiEntityDetector, dataNodes);
423427
client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress
424428
}
425429

@@ -429,7 +433,7 @@ private ActionListener<ProfileResponse> onModelResponse(
429433
AnomalyDetectorJob job,
430434
MultiResponsesDelegateActionListener<DetectorProfile> listener
431435
) {
432-
boolean isMultientityDetector = detector.isMultientityDetector();
436+
boolean isMultientityDetector = detector.isHighCardinality();
433437
return ActionListener.wrap(profileResponse -> {
434438
DetectorProfile.Builder profile = new DetectorProfile.Builder();
435439
if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE)) {
@@ -516,7 +520,7 @@ private ActionListener<SearchResponse> onInittedEver(
516520
logger
517521
.error(
518522
"Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}",
519-
detector.getDetectorId()
523+
detector.getId()
520524
);
521525
listener.onFailure(exception);
522526
}
@@ -565,11 +569,7 @@ private ActionListener<RCFPollingResponse> onPollRCFUpdates(
565569
// data exists.
566570
processInitResponse(detector, profilesToCollect, 0L, true, new DetectorProfile.Builder(), listener);
567571
} else {
568-
logger
569-
.error(
570-
new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getDetectorId()),
571-
exception
572-
);
572+
logger.error(new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getId()), exception);
573573
listener.onFailure(exception);
574574
}
575575
});
@@ -603,7 +603,7 @@ private void processInitResponse(
603603
InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0);
604604
builder.initProgress(initProgress);
605605
} else {
606-
long intervalMins = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMinutes();
606+
long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes();
607607
InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins);
608608
builder.initProgress(initProgress);
609609
}

src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public void executeDetector(
7272
ActionListener<List<AnomalyResult>> listener
7373
) throws IOException {
7474
context.restore();
75-
List<String> categoryField = detector.getCategoryField();
75+
List<String> categoryField = detector.getCategoryFields();
7676
if (categoryField != null && !categoryField.isEmpty()) {
7777
featureManager.getPreviewEntities(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(entities -> {
7878

@@ -86,13 +86,13 @@ public void executeDetector(
8686
ActionListener<EntityAnomalyResult> entityAnomalyResultListener = ActionListener
8787
.wrap(
8888
entityAnomalyResult -> { listener.onResponse(entityAnomalyResult.getAnomalyResults()); },
89-
e -> onFailure(e, listener, detector.getDetectorId())
89+
e -> onFailure(e, listener, detector.getId())
9090
);
9191
MultiResponsesDelegateActionListener<EntityAnomalyResult> multiEntitiesResponseListener =
9292
new MultiResponsesDelegateActionListener<EntityAnomalyResult>(
9393
entityAnomalyResultListener,
9494
entities.size(),
95-
String.format(Locale.ROOT, "Fail to get preview result for multi entity detector %s", detector.getDetectorId()),
95+
String.format(Locale.ROOT, "Fail to get preview result for multi entity detector %s", detector.getId()),
9696
true
9797
);
9898
for (Entity entity : entities) {
@@ -113,17 +113,17 @@ public void executeDetector(
113113
}, e -> multiEntitiesResponseListener.onFailure(e))
114114
);
115115
}
116-
}, e -> onFailure(e, listener, detector.getDetectorId())));
116+
}, e -> onFailure(e, listener, detector.getId())));
117117
} else {
118118
featureManager.getPreviewFeatures(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(features -> {
119119
try {
120120
List<ThresholdingResult> results = modelManager
121121
.getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize());
122122
listener.onResponse(sample(parsePreviewResult(detector, features, results, null), maxPreviewResults));
123123
} catch (Exception e) {
124-
onFailure(e, listener, detector.getDetectorId());
124+
onFailure(e, listener, detector.getId());
125125
}
126-
}, e -> onFailure(e, listener, detector.getDetectorId())));
126+
}, e -> onFailure(e, listener, detector.getId())));
127127
}
128128
}
129129

@@ -184,7 +184,7 @@ private List<AnomalyResult> parsePreviewResult(
184184
);
185185
} else {
186186
result = new AnomalyResult(
187-
detector.getDetectorId(),
187+
detector.getId(),
188188
null,
189189
featureDatas,
190190
Instant.ofEpochMilli(timeRange.getKey()),

src/main/java/org/opensearch/ad/EntityProfileRunner.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public void profile(
105105
) {
106106
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
107107
AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId);
108-
List<String> categoryFields = detector.getCategoryField();
108+
List<String> categoryFields = detector.getCategoryFields();
109109
int maxCategoryFields = ADNumericSetting.maxCategoricalFields();
110110
if (categoryFields == null || categoryFields.size() == 0) {
111111
listener.onFailure(new IllegalArgumentException(NOT_HC_DETECTOR_ERR_MSG));
@@ -186,7 +186,7 @@ private void validateEntity(
186186
.<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
187187
searchRequest,
188188
client::search,
189-
detector.getDetectorId(),
189+
detector.getId(),
190190
client,
191191
searchResponseListener
192192
);
@@ -277,7 +277,7 @@ private void getJob(
277277
detectorId,
278278
enabledTimeMs,
279279
entityValue,
280-
detector.getResultIndex()
280+
detector.getCustomResultIndex()
281281
);
282282

283283
EntityProfile.Builder builder = new EntityProfile.Builder();
@@ -397,7 +397,7 @@ private void sendInitState(
397397
builder.state(EntityState.INIT);
398398
}
399399
if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)) {
400-
long intervalMins = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMinutes();
400+
long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes();
401401
InitProgressProfile initProgress = computeInitProgressProfile(updates, intervalMins);
402402
builder.initProgress(initProgress);
403403
}

src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java

+8-20
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public void indexAnomalyResult(
9292
AnomalyResultResponse response,
9393
AnomalyDetector detector
9494
) {
95-
String detectorId = detector.getDetectorId();
95+
String detectorId = detector.getId();
9696
try {
9797
// skipping writing to the result index if not necessary
9898
// For a single-entity detector, the result is not useful if error is null
@@ -124,7 +124,7 @@ public void indexAnomalyResult(
124124
response.getError()
125125
);
126126

127-
String resultIndex = detector.getResultIndex();
127+
String resultIndex = detector.getCustomResultIndex();
128128
anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
129129
updateRealtimeTask(response, detectorId);
130130
} catch (EndRunException e) {
@@ -156,13 +156,7 @@ private void updateRealtimeTask(AnomalyResultResponse response, String detectorI
156156
Runnable profileHCInitProgress = () -> {
157157
client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> {
158158
log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates());
159-
updateLatestRealtimeTask(
160-
detectorId,
161-
null,
162-
r.getTotalUpdates(),
163-
response.getDetectorIntervalInMinutes(),
164-
response.getError()
165-
);
159+
updateLatestRealtimeTask(detectorId, null, r.getTotalUpdates(), response.getIntervalInMinutes(), response.getError());
166160
}, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); }));
167161
};
168162
if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) {
@@ -181,13 +175,7 @@ private void updateRealtimeTask(AnomalyResultResponse response, String detectorI
181175
detectorId,
182176
response.getRcfTotalUpdates()
183177
);
184-
updateLatestRealtimeTask(
185-
detectorId,
186-
null,
187-
response.getRcfTotalUpdates(),
188-
response.getDetectorIntervalInMinutes(),
189-
response.getError()
190-
);
178+
updateLatestRealtimeTask(detectorId, null, response.getRcfTotalUpdates(), response.getIntervalInMinutes(), response.getError());
191179
}
192180
}
193181

@@ -278,7 +266,7 @@ public void indexAnomalyResultException(
278266
String taskState,
279267
AnomalyDetector detector
280268
) {
281-
String detectorId = detector.getDetectorId();
269+
String detectorId = detector.getId();
282270
try {
283271
IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay();
284272
Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
@@ -299,15 +287,15 @@ public void indexAnomalyResultException(
299287
anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT),
300288
null // no model id
301289
);
302-
String resultIndex = detector.getResultIndex();
290+
String resultIndex = detector.getCustomResultIndex();
303291
if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) {
304292
// Set result index as null, will write exception to default result index.
305293
anomalyResultHandler.index(anomalyResult, detectorId, null);
306294
} else {
307295
anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
308296
}
309297

310-
if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isMultiCategoryDetector()) {
298+
if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isHighCardinality()) {
311299
// single stream detector raises ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG
312300
// when there is no checkpoint.
313301
// Delay real time cache update by one minute so we will have trained models by then and update the state
@@ -321,7 +309,7 @@ public void indexAnomalyResultException(
321309
detectorId,
322310
taskState,
323311
totalUpdates,
324-
detector.getDetectorIntervalInMinutes(),
312+
detector.getIntervalInMinutes(),
325313
totalUpdates > 0 ? "" : errorMessage
326314
);
327315
}, e -> {

src/main/java/org/opensearch/ad/NodeState.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public NodeState(String detectorId, Clock clock) {
5858
this.detectorJob = null;
5959
}
6060

61-
public String getDetectorId() {
61+
public String getId() {
6262
return detectorId;
6363
}
6464

0 commit comments

Comments
 (0)