Skip to content

Commit 9054f6c

Browse files
committed
feat(telemetry aware plugin): adding counters and implementing telemetry aware plugin
1 parent 204d498 commit 9054f6c

7 files changed

+184
-11
lines changed

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+50-9
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
import org.opensearch.env.Environment;
5656
import org.opensearch.env.NodeEnvironment;
5757
import org.opensearch.indices.SystemIndexDescriptor;
58+
import org.opensearch.jobscheduler.spi.JobSchedulerExtension;
59+
import org.opensearch.jobscheduler.spi.ScheduledJobParser;
60+
import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
5861
import org.opensearch.ml.action.agents.DeleteAgentTransportAction;
5962
import org.opensearch.ml.action.agents.GetAgentTransportAction;
6063
import org.opensearch.ml.action.agents.TransportRegisterAgentAction;
@@ -111,6 +114,7 @@
111114
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
112115
import org.opensearch.ml.cluster.MLCommonsClusterEventListener;
113116
import org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener;
117+
import org.opensearch.ml.common.CommonValue;
114118
import org.opensearch.ml.common.FunctionName;
115119
import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput;
116120
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
@@ -199,7 +203,8 @@
199203
import org.opensearch.ml.engine.utils.AgentModelsSearcher;
200204
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
201205
import org.opensearch.ml.helper.ModelAccessControlHelper;
202-
import org.opensearch.ml.jobs.MLBatchTaskUpdateJobRunner;
206+
import org.opensearch.ml.jobs.MLJobParameter;
207+
import org.opensearch.ml.jobs.MLJobRunner;
203208
import org.opensearch.ml.memory.ConversationalMemoryHandler;
204209
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
205210
import org.opensearch.ml.memory.action.conversation.CreateConversationTransportAction;
@@ -293,6 +298,7 @@
293298
import org.opensearch.ml.stats.MLNodeLevelStat;
294299
import org.opensearch.ml.stats.MLStat;
295300
import org.opensearch.ml.stats.MLStats;
301+
import org.opensearch.ml.stats.otel.counters.MLOperationalMetricsCounter;
296302
import org.opensearch.ml.stats.suppliers.CounterSupplier;
297303
import org.opensearch.ml.stats.suppliers.IndexStatusSupplier;
298304
import org.opensearch.ml.task.MLExecuteTaskRunner;
@@ -313,6 +319,7 @@
313319
import org.opensearch.plugins.SearchPipelinePlugin;
314320
import org.opensearch.plugins.SearchPlugin;
315321
import org.opensearch.plugins.SystemIndexPlugin;
322+
import org.opensearch.plugins.TelemetryAwarePlugin;
316323
import org.opensearch.remote.metadata.client.SdkClient;
317324
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
318325
import org.opensearch.repositories.RepositoriesService;
@@ -326,6 +333,8 @@
326333
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
327334
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
328335
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;
336+
import org.opensearch.telemetry.metrics.MetricsRegistry;
337+
import org.opensearch.telemetry.tracing.Tracer;
329338
import org.opensearch.threadpool.ExecutorBuilder;
330339
import org.opensearch.threadpool.FixedExecutorBuilder;
331340
import org.opensearch.threadpool.ThreadPool;
@@ -343,7 +352,9 @@ public class MachineLearningPlugin extends Plugin
343352
SearchPipelinePlugin,
344353
ExtensiblePlugin,
345354
IngestPlugin,
346-
SystemIndexPlugin {
355+
SystemIndexPlugin,
356+
TelemetryAwarePlugin,
357+
JobSchedulerExtension {
347358
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
348359
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
349360
public static final String SDK_CLIENT_THREAD_POOL = "opensearch_ml_sdkclient";
@@ -356,6 +367,8 @@ public class MachineLearningPlugin extends Plugin
356367
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
357368
public static final String ML_BASE_URI = "/_plugins/_ml";
358369

370+
public static final String ML_COMMONS_JOBS_TYPE = "opensearch_ml_commons_jobs";
371+
359372
private MLStats mlStats;
360373
private MLModelCacheHelper modelCacheHelper;
361374
private MLTaskManager mlTaskManager;
@@ -398,11 +411,13 @@ public class MachineLearningPlugin extends Plugin
398411
private ScriptService scriptService;
399412
private Encryptor encryptor;
400413

401-
public MachineLearningPlugin(Settings settings) {
402-
// Handle this here as this feature is tied to Search/Query API, not to a ml-common API
403-
// and as such, it can't be lazy-loaded when a ml-commons API is invoked.
404-
this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
405-
}
414+
// public MachineLearningPlugin(Settings settings) {
415+
// // Handle this here as this feature is tied to Search/Query API, not to a ml-common API
416+
// // and as such, it can't be lazy-loaded when a ml-commons API is invoked.
417+
// this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
418+
// }
419+
420+
public MachineLearningPlugin() {}
406421

407422
@Override
408423
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
@@ -485,7 +500,9 @@ public Collection<Object> createComponents(
485500
NodeEnvironment nodeEnvironment,
486501
NamedWriteableRegistry namedWriteableRegistry,
487502
IndexNameExpressionResolver indexNameExpressionResolver,
488-
Supplier<RepositoriesService> repositoriesServiceSupplier
503+
Supplier<RepositoriesService> repositoriesServiceSupplier,
504+
Tracer tracer,
505+
MetricsRegistry metricsRegistry
489506
) {
490507
this.indexUtils = new IndexUtils(client, clusterService);
491508
this.client = client;
@@ -723,7 +740,11 @@ public Collection<Object> createComponents(
723740
.getClusterSettings()
724741
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it);
725742

726-
MLBatchTaskUpdateJobRunner.getJobRunnerInstance().initialize(clusterService, threadPool, client);
743+
MLJobRunner.getInstance().initialize(clusterService, threadPool, client);
744+
745+
// todo: add setting
746+
MLOperationalMetricsCounter.initialize(clusterService.getClusterName().toString(), metricsRegistry);
747+
// MLAdoptionMetricsCounter.initialize(clusterService.getClusterName().toString(), metricsRegistry);
727748

728749
return ImmutableList
729750
.of(
@@ -1174,4 +1195,24 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
11741195
systemIndexDescriptors.add(new SystemIndexDescriptor(ML_STOP_WORDS_INDEX, "ML Commons Stop Words Index"));
11751196
return systemIndexDescriptors;
11761197
}
1198+
1199+
@Override
1200+
public String getJobType() {
1201+
return ML_COMMONS_JOBS_TYPE;
1202+
}
1203+
1204+
@Override
1205+
public String getJobIndex() {
1206+
return CommonValue.ML_JOBS_INDEX;
1207+
}
1208+
1209+
@Override
1210+
public ScheduledJobRunner getJobRunner() {
1211+
return MLJobRunner.getInstance();
1212+
}
1213+
1214+
@Override
1215+
public ScheduledJobParser getJobParser() {
1216+
return (parser, id, jobDocVersion) -> MLJobParameter.parse(parser);
1217+
}
11771218
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package org.opensearch.ml.stats.otel.counters;
2+
3+
import java.util.Map;
4+
import java.util.concurrent.ConcurrentHashMap;
5+
import java.util.stream.Stream;
6+
7+
import org.opensearch.telemetry.metrics.Counter;
8+
import org.opensearch.telemetry.metrics.MetricsRegistry;
9+
import org.opensearch.telemetry.metrics.tags.Tags;
10+
11+
public abstract class AbstractMLMetricsCounter<T extends Enum<T>> {
12+
private static final String PREFIX = "ml.commons.";
13+
private static final String UNIT = "1";
14+
private static final String CLUSTER_NAME_TAG = "cluster_name";
15+
16+
protected final String clusterName;
17+
protected final MetricsRegistry metricsRegistry;
18+
protected final Map<T, Counter> metricCounterMap;
19+
20+
protected AbstractMLMetricsCounter(String clusterName, MetricsRegistry metricsRegistry, Class<T> metricClass) {
21+
this.clusterName = clusterName;
22+
this.metricsRegistry = metricsRegistry;
23+
this.metricCounterMap = new ConcurrentHashMap<>();
24+
Stream.of(metricClass.getEnumConstants()).forEach(metric -> metricCounterMap.computeIfAbsent(metric, this::createMetricCounter));
25+
}
26+
27+
public void incrementCounter(T metric, Tags customTags) {
28+
Counter counter = metricCounterMap.computeIfAbsent(metric, this::createMetricCounter);
29+
Tags metricsTags = (customTags == null ? Tags.create() : customTags).addTag(CLUSTER_NAME_TAG, clusterName);
30+
counter.add(1, metricsTags);
31+
}
32+
33+
private Counter createMetricCounter(T metric) {
34+
return metricsRegistry.createCounter(PREFIX + metric.name(), getMetricDescription(metric), UNIT);
35+
}
36+
37+
protected abstract String getMetricDescription(T metric);
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package org.opensearch.ml.stats.otel.counters;
2+
3+
import org.opensearch.ml.stats.otel.metrics.AdoptionMetric;
4+
import org.opensearch.telemetry.metrics.MetricsRegistry;
5+
6+
public class MLAdoptionMetricsCounter extends AbstractMLMetricsCounter<AdoptionMetric> {
7+
8+
private static MLAdoptionMetricsCounter instance;
9+
10+
private MLAdoptionMetricsCounter(String clusterName, MetricsRegistry metricsRegistry) {
11+
super(clusterName, metricsRegistry, AdoptionMetric.class);
12+
}
13+
14+
public static synchronized void initialize(String clusterName, MetricsRegistry metricsRegistry) {
15+
instance = new MLAdoptionMetricsCounter(clusterName, metricsRegistry);
16+
}
17+
18+
public static synchronized MLAdoptionMetricsCounter getInstance() {
19+
if (instance == null) {
20+
throw new IllegalStateException("MLAdoptionMetricsCounter is not initialized. Call initialize() first.");
21+
}
22+
return instance;
23+
}
24+
25+
@Override
26+
protected String getMetricDescription(AdoptionMetric metric) {
27+
return metric.getDescription();
28+
}
29+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.opensearch.ml.stats.otel.counters;
2+
3+
import org.opensearch.ml.stats.otel.metrics.OperationalMetric;
4+
import org.opensearch.telemetry.metrics.MetricsRegistry;
5+
6+
public class MLOperationalMetricsCounter extends AbstractMLMetricsCounter<OperationalMetric> {
7+
8+
private static MLOperationalMetricsCounter instance;
9+
10+
private MLOperationalMetricsCounter(String clusterName, MetricsRegistry metricsRegistry) {
11+
super(clusterName, metricsRegistry, OperationalMetric.class);
12+
}
13+
14+
public static synchronized void initialize(String clusterName, MetricsRegistry metricsRegistry) {
15+
instance = new MLOperationalMetricsCounter(clusterName, metricsRegistry);
16+
}
17+
18+
public static synchronized MLOperationalMetricsCounter getInstance() {
19+
if (instance == null) {
20+
throw new IllegalStateException("MLOperationalMetricsCounter is not initialized. Call initialize() first.");
21+
}
22+
23+
return instance;
24+
}
25+
26+
@Override
27+
protected String getMetricDescription(OperationalMetric metric) {
28+
return metric.getDescription();
29+
}
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package org.opensearch.ml.stats.otel.metrics;
2+
3+
import lombok.Getter;
4+
5+
@Getter
6+
public enum AdoptionMetric {
7+
MODEL_COUNT("Number of models created"),
8+
CONNECTOR_COUNT("Number of connectors created");
9+
10+
private final String description;
11+
12+
AdoptionMetric(String description) {
13+
this.description = description;
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package org.opensearch.ml.stats.otel.metrics;
2+
3+
import lombok.Getter;
4+
5+
@Getter
6+
public enum OperationalMetric {
7+
MODEL_PREDICT_COUNT("Total number of predict calls made"),
8+
MODEL_PREDICT_LATENCY("Latency for model predict");
9+
10+
private final String description;
11+
12+
OperationalMetric(String description) {
13+
this.description = description;
14+
}
15+
}

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
package org.opensearch.ml.task;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.ML_JOBS_INDEX;
910
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
10-
import static org.opensearch.ml.common.CommonValue.TASK_POLLING_JOB_INDEX;
1111
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
1212
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
1313
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
@@ -73,7 +73,10 @@
7373
import org.opensearch.ml.stats.MLActionLevelStat;
7474
import org.opensearch.ml.stats.MLNodeLevelStat;
7575
import org.opensearch.ml.stats.MLStats;
76+
import org.opensearch.ml.stats.otel.counters.MLOperationalMetricsCounter;
77+
import org.opensearch.ml.stats.otel.metrics.OperationalMetric;
7678
import org.opensearch.ml.utils.MLNodeUtils;
79+
import org.opensearch.telemetry.metrics.tags.Tags;
7780
import org.opensearch.threadpool.ThreadPool;
7881
import org.opensearch.transport.TransportResponseHandler;
7982
import org.opensearch.transport.TransportService;
@@ -433,7 +436,8 @@ private void runPredict(
433436
remoteJob
434437
);
435438

436-
if (!clusterService.state().metadata().indices().containsKey(TASK_POLLING_JOB_INDEX)) {
439+
// todo: logic for starting the job
440+
if (!clusterService.state().metadata().indices().containsKey(ML_JOBS_INDEX)) {
437441
mlTaskManager.startTaskPollingJob();
438442
}
439443

@@ -459,6 +463,7 @@ private void runPredict(
459463
} else {
460464
handleAsyncMLTaskComplete(mlTask);
461465
mlModelManager.trackPredictDuration(modelId, startTime);
466+
MLOperationalMetricsCounter.getInstance().incrementCounter(OperationalMetric.MODEL_PREDICT_COUNT, Tags.create().addTag("MODEL_ID", modelId));
462467
internalListener.onResponse(output);
463468
}
464469
}, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName));

0 commit comments

Comments
 (0)