diff --git a/.github/workflows/test_bwc.yml b/.github/workflows/test_bwc.yml index 83c3ca7eb5..79aafd1ac8 100644 --- a/.github/workflows/test_bwc.yml +++ b/.github/workflows/test_bwc.yml @@ -33,7 +33,8 @@ jobs: echo plugin_version $plugin_version ./gradlew assemble echo "Creating ./plugin/src/test/resources/org/opensearch/ml/bwc..." - mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc + mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc/job-scheduler + mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc/ml - name: Run MLCommons Backwards Compatibility Tests run: | echo "Running backwards compatibility tests ..." diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index d4811731a0..ccbc12cba4 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -44,6 +44,7 @@ public class CommonValue { public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words"; + public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job"; public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters."; diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskState.java b/common/src/main/java/org/opensearch/ml/common/MLTaskState.java index dfd7b835d4..a15a470a6f 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskState.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskState.java @@ -30,5 +30,6 @@ public enum MLTaskState { CANCELLED, COMPLETED_WITH_ERROR, CANCELLING, - EXPIRED + EXPIRED, + UNREACHABLE } diff --git 
a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java index 7c81bb3af9..063b73d8e5 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java @@ -26,20 +26,28 @@ @InputDataSet(MLInputDataType.REMOTE) public class RemoteInferenceInputDataSet extends MLInputDataset { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG = CommonValue.VERSION_2_19_0; @Setter private Map parameters; @Setter private ActionType actionType; + @Setter + private Map dlq; @Builder(toBuilder = true) - public RemoteInferenceInputDataSet(Map parameters, ActionType actionType) { + public RemoteInferenceInputDataSet(Map parameters, ActionType actionType, Map dlq) { super(MLInputDataType.REMOTE); this.parameters = parameters; this.actionType = actionType; + this.dlq = dlq; + } + + public RemoteInferenceInputDataSet(Map parameters, ActionType actionType) { + this(parameters, actionType, null); } public RemoteInferenceInputDataSet(Map parameters) { - this(parameters, null); + this(parameters, null, null); } public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { @@ -55,6 +63,13 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { this.actionType = null; } } + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) { + if (streamInput.readBoolean()) { + dlq = streamInput.readMap(s -> s.readString(), s -> s.readString()); + } else { + this.dlq = null; + } + } } @Override @@ -75,6 +90,14 @@ public void writeTo(StreamOutput streamOutput) throws IOException { streamOutput.writeBoolean(false); } } + if 
(streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) { + if (dlq != null) { + streamOutput.writeBoolean(true); + streamOutput.writeMap(dlq, StreamOutput::writeString, StreamOutput::writeString); + } else { + streamOutput.writeBoolean(false); + } + } } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index f30d845179..8dce708d0d 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -23,6 +23,7 @@ public class RemoteInferenceMLInput extends MLInput { public static final String PARAMETERS_FIELD = "parameters"; public static final String ACTION_TYPE_FIELD = "action_type"; + public static final String DLQ_FIELD = "dlq"; public RemoteInferenceMLInput(StreamInput in) throws IOException { super(in); @@ -37,6 +38,7 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) super(); this.algorithm = functionName; Map parameters = null; + Map dlq = null; ActionType actionType = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -50,12 +52,15 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) case ACTION_TYPE_FIELD: actionType = ActionType.from(parser.text()); break; + case DLQ_FIELD: + dlq = StringUtils.getParameterMap(parser.map()); + break; default: parser.skipChildren(); break; } } - inputDataset = new RemoteInferenceInputDataSet(parameters, actionType); + inputDataset = new RemoteInferenceInputDataSet(parameters, actionType, dlq); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java 
b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java index b4579442c4..3c05d5f3a8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java @@ -27,13 +27,25 @@ public class MLTaskGetRequest extends ActionRequest { @Getter String taskId; + @Getter String tenantId; + // This is to identify if the get request is initiated by user or not. During batch task polling job, + // we also perform get operation. This field is to distinguish between + // these two situations. + @Getter + boolean isUserInitiatedGetTaskRequest; + @Builder public MLTaskGetRequest(String taskId, String tenantId) { + this(taskId, tenantId, true); + } + + public MLTaskGetRequest(String taskId, String tenantId, Boolean isUserInitiatedGetTaskRequest) { this.taskId = taskId; this.tenantId = tenantId; + this.isUserInitiatedGetTaskRequest = isUserInitiatedGetTaskRequest; } public MLTaskGetRequest(StreamInput in) throws IOException { @@ -41,6 +54,7 @@ public MLTaskGetRequest(StreamInput in) throws IOException { Version streamInputVersion = in.getVersion(); this.taskId = in.readString(); this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ?
in.readOptionalString() : null; + this.isUserInitiatedGetTaskRequest = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readBoolean() : true; } @Override @@ -51,6 +65,7 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { out.writeOptionalString(tenantId); } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { out.writeBoolean(isUserInitiatedGetTaskRequest); } } @Override diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 4c898fd8de..29395759ff 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -72,9 +72,11 @@ dependencies { exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on' } implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1' - implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12' - implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12' - implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12' + + compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12' + compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12' + compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12' + implementation 'com.jayway.jsonpath:json-path:2.9.0' implementation group: 'org.json', name: 'json', version: '20231013' implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.29.12' diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java index 27aafd72d8..d3e4e337bb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java @@ -28,16 +28,10 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput; import org.opensearch.ml.engine.annotation.Ingester; - -import
com.google.common.annotations.VisibleForTesting; +import org.opensearch.ml.engine.utils.S3Utils; import lombok.extern.log4j.Log4j2; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.ResponseInputStream; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; @@ -54,7 +48,12 @@ public S3DataIngestion(Client client) { @Override public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) { - S3Client s3 = initS3Client(mlBatchIngestionInput); + String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD); + String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD); + String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD); + String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD); + + S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region); List s3Uris = (List) mlBatchIngestionInput.getDataSources().get(SOURCE); if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) { @@ -77,8 +76,8 @@ public double ingestSingleSource( boolean isSoleSource, int bulkSize ) { - String bucketName = getS3BucketName(s3Uri); - String keyName = getS3KeyName(s3Uri); + String bucketName = S3Utils.getS3BucketName(s3Uri); + String keyName = S3Utils.getS3KeyName(s3Uri); GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build(); double successRate = 0; @@ -153,55 +152,4 @@ public double ingestSingleSource( return successRate; } - - private String getS3BucketName(String s3Uri) { - // Remove the "s3://"
prefix - String uriWithoutPrefix = s3Uri.substring(5); - // Find the first slash after the bucket name - int slashIndex = uriWithoutPrefix.indexOf('/'); - // If there is no slash, the entire remaining string is the bucket name - if (slashIndex == -1) { - return uriWithoutPrefix; - } - // Otherwise, the bucket name is the substring up to the first slash - return uriWithoutPrefix.substring(0, slashIndex); - } - - private String getS3KeyName(String s3Uri) { - String uriWithoutPrefix = s3Uri.substring(5); - // Find the first slash after the bucket name - int slashIndex = uriWithoutPrefix.indexOf('/'); - // If there is no slash, it means there is no key, return an empty string or handle as needed - if (slashIndex == -1) { - return ""; - } - // The key name is the substring after the first slash - return uriWithoutPrefix.substring(slashIndex + 1); - } - - @VisibleForTesting - public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) { - String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD); - String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD); - String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD); - String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD); - - AwsCredentials credentials = sessionToken == null - ? 
AwsBasicCredentials.create(accessKey, secretKey) - : AwsSessionCredentials.create(accessKey, secretKey, sessionToken); - - try { - S3Client s3 = AccessController - .doPrivileged( - (PrivilegedExceptionAction) () -> S3Client - .builder() - .region(Region.of(region)) // Specify the region here - .credentialsProvider(StaticCredentialsProvider.create(credentials)) - .build() - ); - return s3; - } catch (PrivilegedActionException e) { - throw new RuntimeException("Can't load credentials", e); - } - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java new file mode 100644 index 0000000000..2c6929bdf1 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.utils; + +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; + +import com.google.common.annotations.VisibleForTesting; + +import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; + +@Log4j2 +public class S3Utils { + @VisibleForTesting + public static S3Client initS3Client(String accessKey, String secretKey, String sessionToken, String region) { + AwsCredentials credentials = sessionToken == null + ? 
AwsBasicCredentials.create(accessKey, secretKey) + : AwsSessionCredentials.create(accessKey, secretKey, sessionToken); + + try { + S3Client s3 = AccessController + .doPrivileged( + (PrivilegedExceptionAction) () -> S3Client + .builder() + .region(Region.of(region)) // Specify the region here + .credentialsProvider(StaticCredentialsProvider.create(credentials)) + .build() + ); + return s3; + } catch (PrivilegedActionException e) { + throw new RuntimeException("Can't load credentials", e); + } + } + + public static void putObject(S3Client s3Client, String bucketName, String key, String content) { + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); + + s3Client.putObject(request, RequestBody.fromString(content)); + log.debug("Successfully uploaded file to S3: s3://{}/{}", bucketName, key); + return null; // Void return type for doPrivileged + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e); + } + } + + public static String getS3BucketName(String s3Uri) { + // Remove the "s3://" prefix + String uriWithoutPrefix = s3Uri.substring(5); + // Find the first slash after the bucket name + int slashIndex = uriWithoutPrefix.indexOf('/'); + // If there is no slash, the entire remaining string is the bucket name + if (slashIndex == -1) { + return uriWithoutPrefix; + } + // Otherwise, the bucket name is the substring up to the first slash + return uriWithoutPrefix.substring(0, slashIndex); + } + + public static String getS3KeyName(String s3Uri) { + String uriWithoutPrefix = s3Uri.substring(5); + // Find the first slash after the bucket name + int slashIndex = uriWithoutPrefix.indexOf('/'); + // If there is no slash, it means there is no key, return an empty string or handle as needed + if (slashIndex == -1) { + return ""; + } + // The key name is the substring after the 
first slash + return uriWithoutPrefix.substring(slashIndex + 1); + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/S3UtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/S3UtilsTest.java new file mode 100644 index 0000000000..18e5b6078e --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/S3UtilsTest.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +public class S3UtilsTest { + + @Mock + private S3Client s3Client; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void testInitS3Client() { + String accessKey = "test-access-key"; + String secretKey = "test-secret-key"; + String sessionToken = "test-session-token"; + String region = "us-west-2"; + + S3Client client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region); + assertNotNull(client); + } + + @Test + public void testInitS3ClientWithoutSessionToken() { + String accessKey = "test-access-key"; + String secretKey = "test-secret-key"; + String region = "us-west-2"; + + S3Client client = S3Utils.initS3Client(accessKey, secretKey, null, region); + assertNotNull(client); + } + + @Test + public void 
testPutObject() { + String bucketName = "test-bucket"; + String key = "test-key"; + String content = "test-content"; + + when(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))).thenReturn(PutObjectResponse.builder().build()); + + S3Utils.putObject(s3Client, bucketName, key, content); + + verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + } + + @Test + public void testGetS3BucketName() { + String s3Uri = "s3://test-bucket/path/to/file"; + assertEquals("test-bucket", S3Utils.getS3BucketName(s3Uri)); + + s3Uri = "s3://test-bucket"; + assertEquals("test-bucket", S3Utils.getS3BucketName(s3Uri)); + } + + @Test + public void testGetS3KeyName() { + String s3Uri = "s3://test-bucket/path/to/file"; + assertEquals("path/to/file", S3Utils.getS3KeyName(s3Uri)); + + s3Uri = "s3://test-bucket"; + assertEquals("", S3Utils.getS3KeyName(s3Uri)); + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 97ebdb12a2..b6cffa8158 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -39,6 +39,11 @@ opensearchplugin { name 'opensearch-ml' description 'machine learning plugin for opensearch' classname 'org.opensearch.ml.plugin.MachineLearningPlugin' + extendedPlugins = ['opensearch-job-scheduler'] +} + +configurations { + zipArchive } dependencies { @@ -47,7 +52,20 @@ dependencies { implementation project(':opensearch-ml-algorithms') implementation project(':opensearch-ml-search-processors') implementation project(':opensearch-ml-memory') + compileOnly "com.google.guava:guava:32.1.3-jre" + + implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12' + implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12' + implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12' + + implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: '2.29.12' + implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', 
version: '2.29.12' + + implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: '2.29.12' + + zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" + compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" // Multi-tenant SDK Client @@ -92,6 +110,9 @@ publishing { repositories { maven { + mavenLocal() + mavenCentral() + maven { url "https://ci.opensearch.org/ci/dbc/snapshots/lucene/" } name = "Snapshots" url = "https://aws.oss.sonatype.org/content/repositories/snapshots" credentials { @@ -212,6 +233,17 @@ testClusters.integTest { } } plugin(project.tasks.bundlePlugin.archiveFile) + plugin(provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.getSingleFile() + } + } + } + })) nodes.each { node -> def plugins = node.plugins @@ -221,6 +253,21 @@ testClusters.integTest { } } +testClusters.yamlRestTest { + + plugin(provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.getSingleFile() + } + } + } + })) +} + task integTestRemote(type: RestIntegTestTask) { testClassesDirs = sourceSets.test.output.classesDirs classpath = sourceSets.test.runtimeClasspath @@ -310,7 +357,9 @@ List jacocoExclusions = [ 'org.opensearch.ml.action.models.DeleteModelTransportAction.2', 'org.opensearch.ml.model.MLModelCacheHelper', 'org.opensearch.ml.model.MLModelCacheHelper.1', - 'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction' + 'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction', + 'org.opensearch.ml.jobs.MLBatchTaskUpdateExtension', + 
'org.opensearch.ml.jobs.MLBatchTaskUpdateJobRunner' ] @@ -340,6 +389,7 @@ check.dependsOn jacocoTestCoverageVerification configurations.all { exclude group: "org.jetbrains", module: "annotations" + exclude group: "com.google.guava", module: "failureaccess" resolutionStrategy.force 'org.apache.commons:commons-lang3:3.10' resolutionStrategy.force 'commons-logging:commons-logging:1.2' resolutionStrategy.force 'org.objenesis:objenesis:3.2' @@ -429,40 +479,73 @@ tasks.withType(licenseHeaders.class) { String bwcVersion = "2.4.0.0" String bwcShortVersion = bwcVersion[0..4] String baseName = "mlCommonsBwcCluster" -String bwcMlPlugin = "opensearch-ml-" + bwcVersion + ".zip" -String bwcFilePath = "src/test/resources/org/opensearch/ml/bwc/" -String bwcRemoteFile = "https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/" + bwcShortVersion + "/latest/linux/x64/tar/builds/opensearch/plugins/" + bwcMlPlugin +String bwcBasePath = "src/test/resources/org/opensearch/ml/bwc/" +String bwcMLFilePath = bwcBasePath + "ml/" +String bwcJobSchedulerFilePath = bwcBasePath + "job-scheduler/" String opensearchMlPlugin = "opensearch-ml-" + project.version + ".zip" +String bwcMlPlugin = "opensearch-ml-" + bwcVersion + ".zip" +String bwcJobSchedulerPlugin = "opensearch-job-scheduler-" + bwcVersion + ".zip" +String bwcRemoteFile = "https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/" + bwcShortVersion + "/latest/linux/x64/tar/builds/opensearch/plugins/" +String bwcRemoteMlFile = bwcRemoteFile + bwcMlPlugin +String bwcRemoteJobSchedulerFile = bwcRemoteFile + bwcJobSchedulerPlugin + 2.times {i -> testClusters { "${baseName}$i" { testDistribution = "ARCHIVE" versions = [bwcShortVersion, opensearch_version] numberOfNodes = 3 + + plugin(provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + + File bwcDir = new File('./plugin/' + bwcBasePath) + if (!bwcDir.exists()) { + bwcDir.mkdirs() + } + File 
dir = new File('./plugin/' + bwcJobSchedulerFilePath + bwcVersion) + if (!dir.exists()) { + dir.mkdirs() + } + File mlJobSchedulerFile = new File(dir, bwcJobSchedulerPlugin) + if (!mlJobSchedulerFile.exists()) { + new URL(bwcRemoteJobSchedulerFile).withInputStream{ ins -> mlJobSchedulerFile.withOutputStream{ it << ins }} + } + return fileTree(bwcJobSchedulerFilePath + bwcVersion).getSingleFile() + } + } + } + })) + plugin(provider(new Callable() { @Override RegularFile call() throws Exception { return new RegularFile() { @Override File getAsFile() { - File bwcDir = new File('./plugin/' + bwcFilePath) + File bwcDir = new File('./plugin/' + bwcBasePath) if (!bwcDir.exists()) { bwcDir.mkdirs() } - File dir = new File('./plugin/' + bwcFilePath + bwcVersion) + File dir = new File('./plugin/' + bwcMLFilePath + bwcVersion) if (!dir.exists()) { dir.mkdirs() } - File f = new File(dir, bwcMlPlugin) - if (!f.exists()) { - new URL(bwcRemoteFile).withInputStream{ ins -> f.withOutputStream{ it << ins }} + File mlPluginFile = new File(dir, bwcMlPlugin) + if (!mlPluginFile.exists()) { + new URL(bwcRemoteMlFile).withInputStream{ ins -> mlPluginFile.withOutputStream{ it << ins }} } - return fileTree(bwcFilePath + bwcVersion).getSingleFile() + return fileTree(bwcMLFilePath + bwcVersion).getSingleFile() } } } })) + setting 'path.repo', "${buildDir}/cluster/shared/repo/${baseName}" setting 'http.content_type.required', 'true' } @@ -476,12 +559,23 @@ List> plugins = [ return new RegularFile() { @Override File getAsFile() { - project.mkdir "$bwcFilePath/$project.version" + return configurations.zipArchive.asFileTree.getSingleFile() + } + } + } + }), + provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + project.mkdir "$bwcMLFilePath/$project.version" copy { from "$buildDir/distributions/$opensearchMlPlugin" - into "$bwcFilePath/$project.version" + into "$bwcMLFilePath/$project.version" } - return 
fileTree(bwcFilePath + project.version).getSingleFile() + return fileTree(bwcMLFilePath + project.version).getSingleFile() } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 5402d47456..6ba73af44f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -181,7 +181,11 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener { @@ -115,6 +128,10 @@ public class GetTaskTransportAction extends HandledTransportAction decryptedCredential; @Inject public GetTaskTransportAction( @@ -131,7 +148,8 @@ public GetTaskTransportAction( MLTaskManager mlTaskManager, MLModelManager mlModelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting, - Settings settings + Settings settings, + MLEngine mlEngine ) { super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new); this.client = client; @@ -145,6 +163,7 @@ public GetTaskTransportAction( this.mlTaskManager = mlTaskManager; this.mlModelManager = mlModelManager; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + this.mlEngine = mlEngine; remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it); @@ -172,6 +191,12 @@ public GetTaskTransportAction( clusterService, (regex) -> remoteJobExpiredStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE) ); + initializeRegexPattern( + ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX, + settings, + clusterService, + (regex) -> remoteJobFailedStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE) + ); } private void initializeRegexPattern( @@ -189,6 +214,8 @@ private void initializeRegexPattern( 
protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.fromActionRequest(request); String taskId = mlTaskGetRequest.getTaskId(); + Boolean isUserInitiatedGetTaskRequest = mlTaskGetRequest.isUserInitiatedGetTaskRequest(); + String tenantId = mlTaskGetRequest.getTenantId(); if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { @@ -207,7 +234,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { context.restore(); - handleAsyncResponse(r, throwable, taskId, tenantId, actionListener); + handleAsyncResponse(r, throwable, taskId, isUserInitiatedGetTaskRequest, tenantId, actionListener); }); } catch (Exception e) { log.error("Failed to get ML task {}", taskId, e); @@ -219,6 +246,7 @@ private void handleAsyncResponse( GetDataObjectResponse response, Throwable throwable, String taskId, + Boolean isUserInitiatedGetTaskRequest, String tenantId, ActionListener actionListener ) { @@ -228,8 +256,7 @@ private void handleAsyncResponse( handleThrowable(throwable, taskId, actionListener); return; } - - processResponse(response, taskId, tenantId, actionListener); + processResponse(response, taskId, isUserInitiatedGetTaskRequest, tenantId, actionListener); } private void handleThrowable(Throwable throwable, String taskId, ActionListener actionListener) { @@ -247,6 +274,7 @@ private void handleThrowable(Throwable throwable, String taskId, ActionListener< private void processResponse( GetDataObjectResponse response, String taskId, + Boolean isUserInitiatedGetTaskRequest, String tenantId, ActionListener actionListener ) { @@ -258,14 +286,20 @@ private void processResponse( return; } - parseAndHandleTask(gr, taskId, tenantId, actionListener); + parseAndHandleTask(gr, taskId, isUserInitiatedGetTaskRequest, tenantId, actionListener); } catch (Exception e) { log.error("Failed to parse GetDataObjectResponse for task {}", taskId, e); 
actionListener.onFailure(e); } } - private void parseAndHandleTask(GetResponse gr, String taskId, String tenantId, ActionListener actionListener) { + private void parseAndHandleTask( + GetResponse gr, + String taskId, + Boolean isUserInitiatedGetTaskRequest, + String tenantId, + ActionListener actionListener + ) { try ( XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) ) { @@ -282,7 +316,7 @@ private void parseAndHandleTask(GetResponse gr, String taskId, String tenantId, } if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) { - processRemoteBatchPrediction(mlTask, taskId, tenantId, actionListener); + processRemoteBatchPrediction(mlTask, taskId, isUserInitiatedGetTaskRequest, tenantId, actionListener); } else { actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); } @@ -295,6 +329,7 @@ private void parseAndHandleTask(GetResponse gr, String taskId, String tenantId, private void processRemoteBatchPrediction( MLTask mlTask, String taskId, + Boolean isUserInitiatedGetTaskRequest, String tenantId, ActionListener actionListener ) { @@ -319,7 +354,11 @@ private void processRemoteBatchPrediction( .orElse(null) ); - RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS); + RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet( + parameters, + ActionType.BATCH_PREDICT_STATUS, + null + ); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); String modelId = mlTask.getModelId(); User user = RestActionUtils.getUserContext(client); @@ -342,10 +381,26 @@ private void processRemoteBatchPrediction( } else { if (model.getConnector() != null) { Connector connector = model.getConnector(); - executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); 
+ executeConnector( + connector, + mlInput, + taskId, + isUserInitiatedGetTaskRequest, + mlTask, + remoteJob, + actionListener + ); } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { ActionListener listener = ActionListener.wrap(connector -> { - executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); + executeConnector( + connector, + mlInput, + taskId, + isUserInitiatedGetTaskRequest, + mlTask, + remoteJob, + actionListener + ); }, e -> { log.error("Failed to get connector {}", model.getConnectorId(), e); actionListener.onFailure(e); @@ -392,6 +447,7 @@ private void executeConnector( Connector connector, MLInput mlInput, String taskId, + Boolean isUserInitiatedGetTaskRequest, MLTask mlTask, Map remoteJob, ActionListener actionListener @@ -401,23 +457,58 @@ private void executeConnector( ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS); connector.addAction(connectorAction); } - // as we haven't implemented multi-tenancy in batch prediction yet, assigning null as tenantId - connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential, tenantId) -> encryptor.decrypt(credential, null), null); + + final Map decryptedCredential = connector.getDecryptedCredential() != null + && !connector.getDecryptedCredential().isEmpty() + ? 
mlEngine.getConnectorCredential(connector) + : connector.getDecryptedCredential(); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setScriptService(scriptService); connectorExecutor.setClusterService(clusterService); connectorExecutor.setClient(client); connectorExecutor.setXContentRegistry(xContentRegistry); connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { - processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener); - }, actionListener::onFailure)); + processTaskResponse( + mlTask, + taskId, + isUserInitiatedGetTaskRequest, + taskResponse, + remoteJob, + decryptedCredential, + actionListener + ); + }, e -> { + // When the request to remote service fails, we will retry the request for next 10 minutes (10 runs). + // If it fails even then, we mark it as unreachable in task index and send message to DLQ + if (!isUserInitiatedGetTaskRequest) { + Map updatedTask = new HashMap<>(); + Integer numberOfRetries = (Integer) remoteJob.getOrDefault("num_of_retries", 0); + remoteJob.put("num_of_retries", ++numberOfRetries); + if (numberOfRetries > 10) { + log + .debug( + "Limit exceeded trying to reach the task {} . 
Marking as UNREACHABLE in task index and removing from further execution", + taskId + ); + updatedTask.put(STATE_FIELD, UNREACHABLE); + mlTask.setState(UNREACHABLE); + mlTask.setError(e.getMessage()); + updateDLQ(mlTask, decryptedCredential); + } + updatedTask.put("remote_job", remoteJob); + mlTaskManager.updateMLTaskDirectly(taskId, updatedTask); + } + actionListener.onFailure(e); + })); } protected void processTaskResponse( MLTask mlTask, String taskId, + Boolean isUserInitiatedGetTaskRequest, MLTaskResponse taskResponse, Map remoteJob, + Map decryptedCredential, ActionListener actionListener ) { try { @@ -430,6 +521,7 @@ protected void processTaskResponse( remoteJob.putAll(remoteJobStatus); Map updatedTask = new HashMap<>(); updatedTask.put(REMOTE_JOB_FIELD, remoteJob); + mlTask.setRemoteJob(remoteJob); for (String statusField : remoteJobStatusFields) { String statusValue = String.valueOf(remoteJob.get(statusField)); @@ -437,7 +529,11 @@ protected void processTaskResponse( updateTaskState(updatedTask, mlTask, statusValue); } } + mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> { + if (mlTask.getState().equals(FAILED) && !isUserInitiatedGetTaskRequest) { + updateDLQ(mlTask, decryptedCredential); + } actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); }, e -> { logException("Failed to update task for batch predict model", e, log); @@ -460,6 +556,43 @@ protected void processTaskResponse( } } + @VisibleForTesting + protected void updateDLQ(MLTask mlTask, Map decryptedCredential) { + Map remoteJob = mlTask.getRemoteJob(); + Map dlq = (Map) remoteJob.get("dlq"); + if (dlq != null && !dlq.isEmpty()) { + String taskId = mlTask.getTaskId(); + try { + Map remoteJobDetails = mlTask.getRemoteJob(); + String accessKey = decryptedCredential.get(ACCESS_KEY_FIELD); + String secretKey = decryptedCredential.get(SECRET_KEY_FIELD); + String sessionToken = decryptedCredential.get(SESSION_TOKEN_FIELD); + + String 
bucketName = dlq.get("bucket"); + String region = dlq.get("region"); + + if (bucketName == null || region == null) { + log.error("Failed to get the bucket name and region from batch predict request"); + } + remoteJobDetails.remove("dlq"); + try (S3Client s3Client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region)) { + String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name")); + String s3ObjectKey = "BatchJobFailure_" + jobName; + String content = mlTask.getState().equals(UNREACHABLE) + ? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError()) + : remoteJobDetails.toString(); + + S3Utils.putObject(s3Client, bucketName, s3ObjectKey, content); + log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now()); + } + } catch (S3Exception e) { + log.error("Failed to update task status for task: {}. S3 Exception: {}", taskId, e.awsErrorDetails().errorMessage()); + } catch (Exception e) { + log.error("Failed to update task status for task: " + taskId, e); + } + } + } + private void updateTaskState(Map updatedTask, MLTask mlTask, String statusValue) { if (matchesPattern(remoteJobCancellingStatusRegexPattern, statusValue)) { updatedTask.put(STATE_FIELD, CANCELLING); @@ -473,6 +606,9 @@ private void updateTaskState(Map updatedTask, MLTask mlTask, Str } else if (matchesPattern(remoteJobExpiredStatusRegexPattern, statusValue)) { updatedTask.put(STATE_FIELD, EXPIRED); mlTask.setState(EXPIRED); + } else if (matchesPattern(remoteJobFailedStatusRegexPattern, statusValue)) { + updatedTask.put(STATE_FIELD, FAILED); + mlTask.setState(FAILED); } } diff --git a/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtension.java b/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtension.java new file mode 100644 index 0000000000..775a0e5714 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtension.java @@ 
-0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.jobs; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.ml.common.CommonValue; + +public class MLBatchTaskUpdateExtension implements JobSchedulerExtension { + + @Override + public String getJobType() { + return "checkBatchJobTaskStatus"; + } + + @Override + public ScheduledJobRunner getJobRunner() { + return MLBatchTaskUpdateJobRunner.getJobRunnerInstance(); + } + + @Override + public ScheduledJobParser getJobParser() { + return (parser, id, jobDocVersion) -> { + MLBatchTaskUpdateJobParameter jobParameter = new MLBatchTaskUpdateJobParameter(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + while (!parser.nextToken().equals(XContentParser.Token.END_OBJECT)) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case MLBatchTaskUpdateJobParameter.NAME_FIELD: + jobParameter.setJobName(parser.text()); + break; + case MLBatchTaskUpdateJobParameter.ENABLED_FILED: + jobParameter.setEnabled(parser.booleanValue()); + break; + case MLBatchTaskUpdateJobParameter.ENABLED_TIME_FILED: + jobParameter.setEnabledTime(parseInstantValue(parser)); + break; + case MLBatchTaskUpdateJobParameter.LAST_UPDATE_TIME_FIELD: + jobParameter.setLastUpdateTime(parseInstantValue(parser)); + break; + case MLBatchTaskUpdateJobParameter.SCHEDULE_FIELD: + jobParameter.setSchedule(ScheduleParser.parse(parser)); + break; + case 
MLBatchTaskUpdateJobParameter.LOCK_DURATION_SECONDS: + jobParameter.setLockDurationSeconds(parser.longValue()); + break; + case MLBatchTaskUpdateJobParameter.JITTER: + jobParameter.setJitter(parser.doubleValue()); + break; + default: + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + } + } + return jobParameter; + }; + } + + private Instant parseInstantValue(XContentParser parser) throws IOException { + if (XContentParser.Token.VALUE_NULL.equals(parser.currentToken())) { + return null; + } + if (parser.currentToken().isValue()) { + return Instant.ofEpochMilli(parser.longValue()); + } + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + return null; + } + + @Override + public String getJobIndex() { + return CommonValue.TASK_POLLING_JOB_INDEX; + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameter.java b/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameter.java new file mode 100644 index 0000000000..c12b66a1b7 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameter.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.jobs; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.Schedule; + +/** + * A sample job parameter. + *

+ * It adds an additional "indexToWatch" field to {@link ScheduledJobParameter}, which stores the index + * the job runner will watch. + */ +public class MLBatchTaskUpdateJobParameter implements ScheduledJobParameter { + public static final String NAME_FIELD = "name"; + public static final String ENABLED_FILED = "enabled"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String LAST_UPDATE_TIME_FIELD_READABLE = "last_update_time_field"; + public static final String SCHEDULE_FIELD = "schedule"; + public static final String ENABLED_TIME_FILED = "enabled_time"; + public static final String ENABLED_TIME_FILED_READABLE = "enabled_time_field"; + public static final String INDEX_NAME_FIELD = "index_name_to_watch"; + public static final String LOCK_DURATION_SECONDS = "lock_duration_seconds"; + public static final String JITTER = "jitter"; + + private String jobName; + private Instant lastUpdateTime; + private Instant enabledTime; + private boolean isEnabled; + private Schedule schedule; + private Long lockDurationSeconds; + private Double jitter; + + public MLBatchTaskUpdateJobParameter() {} + + public MLBatchTaskUpdateJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter) { + this.jobName = name; + this.schedule = schedule; + this.lockDurationSeconds = lockDurationSeconds; + this.jitter = jitter; + + Instant now = Instant.now(); + this.isEnabled = true; + this.enabledTime = now; + this.lastUpdateTime = now; + } + + @Override + public String getName() { + return this.jobName; + } + + @Override + public Instant getLastUpdateTime() { + return this.lastUpdateTime; + } + + @Override + public Instant getEnabledTime() { + return this.enabledTime; + } + + @Override + public Schedule getSchedule() { + return this.schedule; + } + + @Override + public boolean isEnabled() { + return this.isEnabled; + } + + @Override + public Long getLockDurationSeconds() { + return this.lockDurationSeconds; + } + + 
@Override + public Double getJitter() { + return jitter; + } + + public void setJobName(String jobName) { + this.jobName = jobName; + } + + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + + public void setEnabledTime(Instant enabledTime) { + this.enabledTime = enabledTime; + } + + public void setEnabled(boolean enabled) { + isEnabled = enabled; + } + + public void setSchedule(Schedule schedule) { + this.schedule = schedule; + } + + public void setLockDurationSeconds(Long lockDurationSeconds) { + this.lockDurationSeconds = lockDurationSeconds; + } + + public void setJitter(Double jitter) { + this.jitter = jitter; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NAME_FIELD, this.jobName).field(ENABLED_FILED, this.isEnabled).field(SCHEDULE_FIELD, this.schedule); + if (this.enabledTime != null) { + builder.timeField(ENABLED_TIME_FILED, ENABLED_TIME_FILED_READABLE, this.enabledTime.toEpochMilli()); + } + if (this.lastUpdateTime != null) { + builder.timeField(LAST_UPDATE_TIME_FIELD, LAST_UPDATE_TIME_FIELD_READABLE, this.lastUpdateTime.toEpochMilli()); + } + if (this.lockDurationSeconds != null) { + builder.field(LOCK_DURATION_SECONDS, this.lockDurationSeconds); + } + if (this.jitter != null) { + builder.field(JITTER, this.jitter); + } + builder.endObject(); + return builder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunner.java b/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunner.java new file mode 100644 index 0000000000..44eb15ee20 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunner.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.jobs; + +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; + +import 
java.time.Instant; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.transport.task.MLTaskGetAction; +import org.opensearch.ml.common.transport.task.MLTaskGetRequest; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.threadpool.ThreadPool; + +public class MLBatchTaskUpdateJobRunner implements ScheduledJobRunner { + private static final Logger log = LogManager.getLogger(ScheduledJobRunner.class); + + private static MLBatchTaskUpdateJobRunner INSTANCE; + + public static MLBatchTaskUpdateJobRunner getJobRunnerInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (MLBatchTaskUpdateJobRunner.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new MLBatchTaskUpdateJobRunner(); + return INSTANCE; + } + } + + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + private MLTaskManager taskManager; + private boolean initialized; + + private MLBatchTaskUpdateJobRunner() { + // Singleton class, use getJobRunner method instead of constructor + } + + public void 
setClusterService(ClusterService clusterService) { + this.clusterService = clusterService; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + public void setClient(Client client) { + this.client = client; + } + + public void initialize(final ClusterService clusterService, final ThreadPool threadPool, final Client client) { + this.clusterService = clusterService; + this.threadPool = threadPool; + this.client = client; + this.initialized = true; + } + + @Override + public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext jobExecutionContext) { + if (initialized == false) { + throw new AssertionError("this instance is not initialized"); + } + + final LockService lockService = jobExecutionContext.getLockService(); + + Runnable runnable = () -> { + lockService.acquireLock(scheduledJobParameter, jobExecutionContext, ActionListener.wrap(lock -> { + if (lock == null) { + return; + } + + try { + String jobName = scheduledJobParameter.getName(); + log.info("Starting job execution for job ID: {} at {}", jobName, Instant.now()); + + log.debug("Running batch task polling job"); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + BoolQueryBuilder boolQuery = QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery("task_type", MLTaskType.BATCH_PREDICTION)) + .must(QueryBuilders.termQuery("function_name", FunctionName.REMOTE)) + .must( + QueryBuilders + .boolQuery() + .should(QueryBuilders.termQuery("state", MLTaskState.RUNNING)) + .should(QueryBuilders.termQuery("state", MLTaskState.CANCELLING)) + ); + + sourceBuilder.query(boolQuery); + sourceBuilder.size(100); + sourceBuilder.fetchSource(new String[] { "_id" }, null); + + SearchRequest searchRequest = new SearchRequest(ML_TASK_INDEX); + searchRequest.source(sourceBuilder); + + client.search(searchRequest, ActionListener.wrap(response -> { + if (response == null || response.getHits() == null || response.getHits().getHits().length == 
0) { + log.info("No pending tasks found to be polled by the job"); + return; + } + + SearchHit[] searchHits = response.getHits().getHits(); + for (SearchHit searchHit : searchHits) { + String taskId = searchHit.getId(); + log.debug("Starting polling for task: {} at {}", taskId, Instant.now()); + MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest + .builder() + .taskId(taskId) + .isUserInitiatedGetTaskRequest(false) + .build(); + + client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(taskResponse -> { + log.info("Updated Task status for taskId: {} at {}", taskId, Instant.now()); + }, exception -> { log.error("Failed to get task status for task: " + taskId, exception); })); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + log.info("No tasks found to be polled by the job"); + } else { + log.error("Failed to search for tasks to be polled by the job ", e); + } + })); + + log.info("Completed job execution for job ID: {} at {}", jobName, Instant.now()); + } finally { + lockService + .release( + lock, + ActionListener + .wrap( + released -> { log.debug("Released lock for job {}", scheduledJobParameter.getName()); }, + exception -> { + throw new IllegalStateException("Failed to release lock."); + } + ) + ); + } + }, exception -> { throw new IllegalStateException("Failed to acquire lock."); })); + }; + + threadPool.generic().submit(runnable); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index fa6df11dd1..ae0c890d61 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -200,6 +200,7 @@ import org.opensearch.ml.engine.utils.AgentModelsSearcher; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; +import 
org.opensearch.ml.jobs.MLBatchTaskUpdateJobRunner; import org.opensearch.ml.memory.ConversationalMemoryHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationTransportAction; @@ -722,6 +723,8 @@ public Collection createComponents( .getClusterSettings() .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); + MLBatchTaskUpdateJobRunner.getJobRunnerInstance().initialize(clusterService, threadPool, client); + return ImmutableList .of( encryptor, @@ -1036,6 +1039,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX, MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX, + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX, MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED, MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 082fe04535..8f1d4d8ba8 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -244,7 +244,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX = Setting .simpleString( "plugins.ml_commons.remote_job.status_regex.completed", - "(complete|completed)", + "(complete|completed|partiallyCompleted)", Setting.Property.NodeScope, Setting.Property.Dynamic ); @@ -270,6 +270,14 @@ private MLCommonsSettings() {} Setting.Property.Dynamic ); + public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX = Setting + .simpleString( + 
"plugins.ml_commons.remote_job.status_regex.failed", + "(failed)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static final Setting ML_COMMONS_CONTROLLER_ENABLED = Setting .boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index cf359a1c53..8b8c9daeeb 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.TASK_POLLING_JOB_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; @@ -97,6 +98,9 @@ public class MLPredictTaskRunner extends MLTaskRunner listener ) { String modelId = request.getModelId(); + Map dlq; + String bucketName, region; + if (request.getMlInput().getInputDataset() instanceof RemoteInferenceInputDataSet) { + RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) request.getMlInput().getInputDataset(); + dlq = inputDataset.getDlq(); + if (dlq != null) { + bucketName = dlq.get(BUCKET_FIELD); + region = dlq.get(REGION_FIELD); + + if (bucketName == null || region == null) { + throw new IllegalArgumentException("DLQ bucketName or region cannot be null"); + } + // TODO: check if we are able to input an object into the s3 bucket. 
+ // Or check permissions to DLQ write access + } + } + try { ActionListener actionListener = ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { @@ -399,6 +420,8 @@ private void runPredict( .getDataAsMap(); if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) { remoteJob.putAll(dataAsMap); + // put dlq info in remote job + remoteJob.put("dlq", ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getDlq()); mlTask.setRemoteJob(remoteJob); mlTask.setTaskId(null); mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { @@ -410,6 +433,10 @@ private void runPredict( remoteJob ); + if (!clusterService.state().metadata().indices().containsKey(TASK_POLLING_JOB_INDEX)) { + mlTaskManager.startTaskPollingJob(); + } + MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build(); internalListener.onResponse(predictOutput); }, e -> { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index 90322edc87..2dd1426131 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -6,6 +6,7 @@ package org.opensearch.ml.task; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.TASK_POLLING_JOB_INDEX; import static org.opensearch.ml.common.MLTask.LAST_UPDATE_TIME_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_TYPE_FIELD; @@ -16,6 +17,7 @@ import java.io.IOException; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -28,6 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.opensearch.action.DocWriteResponse; +import 
org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -37,11 +40,13 @@ import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -49,6 +54,7 @@ import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.jobs.MLBatchTaskUpdateJobParameter; import org.opensearch.remote.metadata.client.PutDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; @@ -522,4 +528,27 @@ private ActionListener getUpdateResponseListener(String taskId, }); } + public void startTaskPollingJob() throws IOException { + String id = "ml_batch_task_polling_job"; + String jobName = "poll_batch_jobs"; + String interval = "1"; + Long lockDurationSeconds = 20L; + + MLBatchTaskUpdateJobParameter jobParameter = new MLBatchTaskUpdateJobParameter( + jobName, + new IntervalSchedule(Instant.now(), Integer.parseInt(interval), ChronoUnit.MINUTES), + lockDurationSeconds, + null + ); + IndexRequest indexRequest = new IndexRequest() + .index(TASK_POLLING_JOB_INDEX) + .id(id) + .source(jobParameter.toXContent(JsonXContent.contentBuilder(), null)) + 
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(r -> { log.info("Indexed ml task polling job successfully"); }, e -> { + log.error("Failed to index task polling job", e); + })); + } + } diff --git a/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension b/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension new file mode 100644 index 0000000000..48795cc2af --- /dev/null +++ b/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension @@ -0,0 +1,8 @@ +# +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +# + +# This file is needed to register MLBatchTaskUpdateExtension in job scheduler framework +# See https://github.com/opensearch-project/job-scheduler/blob/main/README.md#getting-started +org.opensearch.ml.jobs.MLBatchTaskUpdateExtension \ No newline at end of file diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 9109514f61..5bed3cd5ce 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -15,10 +15,12 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.AbstractConnector.*; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX; import static 
org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD; import java.io.IOException; @@ -69,6 +71,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; @@ -131,6 +134,9 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); + @Mock + MLEngine mlEngine; + GetTaskTransportAction getTaskTransportAction; MLTaskGetRequest mlTaskGetRequest; ThreadContext threadContext; @@ -147,6 +153,7 @@ public void setup() throws IOException { .put(ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX.getKey(), "(stopped|cancelled)") .put(ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX.getKey(), "(stopping|cancelling)") .put(ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX.getKey(), "(expired|timeout)") + .put(ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX.getKey(), "(failed)") .build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -167,7 +174,8 @@ public void setup() throws IOException { ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX, ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX, ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, - ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX + ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX, + ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX ) ) ); @@ -187,7 +195,8 @@ public void setup() throws IOException { mlTaskManager, mlModelManager, mlFeatureEnabledSetting, - 
settings + settings, + mlEngine ) ); @@ -453,6 +462,10 @@ public void test_processTaskResponse_expired() { processTaskResponse("status", "expired", MLTaskState.EXPIRED); } + public void test_processTaskResponse_failed() { + processTaskResponse("status", "failed", MLTaskState.FAILED); + } + public void test_processTaskResponse_WrongStatusField() { processTaskResponse("wrong_status_field", "expired", null); } @@ -484,7 +497,7 @@ private void processTaskResponse(String statusField, String remoteJobResponseSta ActionListener actionListener = mock(ActionListener.class); ArgumentCaptor> updatedTaskCaptor = ArgumentCaptor.forClass(Map.class); - getTaskTransportAction.processTaskResponse(mlTask, taskId, taskResponse, mlTask.getRemoteJob(), actionListener); + getTaskTransportAction.processTaskResponse(mlTask, taskId, true, taskResponse, mlTask.getRemoteJob(), null, actionListener); verify(mlTaskManager).updateMLTaskDirectly(any(), updatedTaskCaptor.capture(), any()); Map updatedTask = updatedTaskCaptor.getValue(); @@ -493,4 +506,78 @@ private void processTaskResponse(String statusField, String remoteJobResponseSta assertEquals(remoteJobResponseStatus, updatedRemoteJob.get(statusField)); assertEquals(remoteJobName, updatedRemoteJob.get("name")); } + + public void testUpdateDLQ_Success() throws IOException { + // Setup test data + Map remoteJob = new HashMap<>(); + remoteJob.put("TransformJobName", "test-job"); + Map dlq = new HashMap<>(); + dlq.put("bucket", "test-bucket"); + dlq.put("region", "us-west-2"); + remoteJob.put("dlq", dlq); + + MLTask mlTask = MLTask + .builder() + .taskId("test-task") + .state(MLTaskState.FAILED) + .error("Test error message") + .remoteJob(remoteJob) + .build(); + + // Setup decrypted credentials + Map decryptedCredential = new HashMap<>(); + decryptedCredential.put(ACCESS_KEY_FIELD, "test-key"); + decryptedCredential.put(SECRET_KEY_FIELD, "test-secret"); + decryptedCredential.put(SESSION_TOKEN_FIELD, "test-token"); + + // Call the method + 
getTaskTransportAction.updateDLQ(mlTask, decryptedCredential); + + // Verify remoteJob DLQ is removed + assertNull(mlTask.getRemoteJob().get("dlq")); + } + + public void testUpdateDLQ_MissingBucketOrRegion() { + // Setup test data with missing bucket/region + Map remoteJob = new HashMap<>(); + remoteJob.put("TransformJobName", "test-job"); + Map dlq = new HashMap<>(); + // Intentionally missing bucket and region + remoteJob.put("dlq", dlq); + + MLTask mlTask = MLTask + .builder() + .taskId("test-task") + .state(MLTaskState.FAILED) + .error("Test error message") + .remoteJob(remoteJob) + .build(); + + // Call the method - should not throw exception but log error + getTaskTransportAction.updateDLQ(mlTask, Collections.emptyMap()); + + // Verify DLQ still exists since update failed + assertNotNull(mlTask.getRemoteJob().get("dlq")); + } + + public void testUpdateDLQ_NullDLQ() { + // Setup test data with null DLQ + Map remoteJob = new HashMap<>(); + remoteJob.put("TransformJobName", "test-job"); + // No DLQ configuration + + MLTask mlTask = MLTask + .builder() + .taskId("test-task") + .state(MLTaskState.FAILED) + .error("Test error message") + .remoteJob(remoteJob) + .build(); + + // Call the method - should do nothing + getTaskTransportAction.updateDLQ(mlTask, null); + + // Verify remoteJob is unchanged + assertEquals("test-job", mlTask.getRemoteJob().get("TransformJobName")); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtensionTests.java b/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtensionTests.java new file mode 100644 index 0000000000..fdacdb8f22 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtensionTests.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.jobs; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import 
java.time.Instant; + +import org.junit.Ignore; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.jobscheduler.spi.JobDocVersion; +import org.opensearch.ml.common.CommonValue; + +public class MLBatchTaskUpdateExtensionTests { + + @Test + public void testBasic() { + MLBatchTaskUpdateExtension extension = new MLBatchTaskUpdateExtension(); + assertEquals("checkBatchJobTaskStatus", extension.getJobType()); + assertEquals(CommonValue.TASK_POLLING_JOB_INDEX, extension.getJobIndex()); + assertEquals(MLBatchTaskUpdateJobRunner.getJobRunnerInstance(), extension.getJobRunner()); + } + + @Ignore + @Test + public void testParser() throws IOException { + MLBatchTaskUpdateExtension extension = new MLBatchTaskUpdateExtension(); + + Instant enabledTime = Instant.now(); + Instant lastUpdateTime = Instant.now(); + + String json = "{" + + "\"name\": \"testJob\"," + + "\"enabled\": true," + + "\"enabled_time\": \"" + + enabledTime.toString() + + "\"," + + "\"last_update_time\": \"" + + lastUpdateTime.toString() + + "\"," + + "\"lock_duration_seconds\": 300," + + "\"jitter\": 0.1" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json); + + parser.nextToken(); + MLBatchTaskUpdateJobParameter parsedJobParameter = (MLBatchTaskUpdateJobParameter) extension + .getJobParser() + .parse(parser, "test_id", new JobDocVersion(1, 0, 0)); + + assertEquals("testJob", parsedJobParameter.getName()); + assertTrue(parsedJobParameter.isEnabled()); + } + + @Test(expected = IOException.class) + public void testParserWithInvalidJson() throws IOException { + MLBatchTaskUpdateExtension extension = new 
MLBatchTaskUpdateExtension(); + + String invalidJson = "{ invalid json }"; + + XContentParser parser = JsonXContent.jsonXContent.createParser(null, null, invalidJson); + extension.getJobParser().parse(parser, "test_id", new JobDocVersion(1, 0, 0)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameterTests.java b/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameterTests.java new file mode 100644 index 0000000000..e0f9d12958 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameterTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.jobs; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; + +public class MLBatchTaskUpdateJobParameterTests { + + private MLBatchTaskUpdateJobParameter jobParameter; + private String jobName; + private IntervalSchedule schedule; + private Long lockDurationSeconds; + private Double jitter; + + @Before + public void setUp() { + jobName = "test-job"; + schedule = new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES); + lockDurationSeconds = 20L; + jitter = 0.5; + jobParameter = new MLBatchTaskUpdateJobParameter(jobName, schedule, lockDurationSeconds, jitter); + } + + @Test + public void testConstructor() { + assertNotNull(jobParameter); + assertEquals(jobName, jobParameter.getName()); + assertEquals(schedule, jobParameter.getSchedule()); + assertEquals(lockDurationSeconds, jobParameter.getLockDurationSeconds()); + assertEquals(jitter, jobParameter.getJitter()); + assertTrue(jobParameter.isEnabled()); + 
assertNotNull(jobParameter.getEnabledTime()); + assertNotNull(jobParameter.getLastUpdateTime()); + } + + @Test + public void testToXContent() throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder(); + jobParameter.toXContent(builder, null); + String jsonString = builder.toString(); + + assertTrue(jsonString.contains(jobName)); + assertTrue(jsonString.contains("enabled")); + assertTrue(jsonString.contains("schedule")); + assertTrue(jsonString.contains("lock_duration_seconds")); + assertTrue(jsonString.contains("jitter")); + } + + @Test + public void testSetters() { + String newJobName = "new-job"; + jobParameter.setJobName(newJobName); + assertEquals(newJobName, jobParameter.getName()); + + Instant newTime = Instant.now(); + jobParameter.setLastUpdateTime(newTime); + assertEquals(newTime, jobParameter.getLastUpdateTime()); + + jobParameter.setEnabled(false); + assertEquals(false, jobParameter.isEnabled()); + + Long newLockDuration = 30L; + jobParameter.setLockDurationSeconds(newLockDuration); + assertEquals(newLockDuration, jobParameter.getLockDurationSeconds()); + + Double newJitter = 0.7; + jobParameter.setJitter(newJitter); + assertEquals(newJitter, jobParameter.getJitter()); + } + + @Test + public void testNullCase() throws IOException { + String newJobName = "test-job"; + + jobParameter = new MLBatchTaskUpdateJobParameter(newJobName, null, null, null); + jobParameter.setLastUpdateTime(null); + jobParameter.setEnabledTime(null); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + jobParameter.toXContent(builder, null); + String jsonString = builder.toString(); + + assertTrue(jsonString.contains(jobName)); + assertEquals(newJobName, jobParameter.getName()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunnerTests.java new file mode 100644 index 0000000000..bf1763c3dc --- /dev/null +++ 
b/plugin/src/test/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunnerTests.java @@ -0,0 +1,146 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.jobs; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.io.IOException; + +import org.apache.lucene.search.TotalHits; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.threadpool.ThreadPool; + +public class MLBatchTaskUpdateJobRunnerTests { + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private Client client; + + @Mock + private MLTaskManager mlTaskManager; + + @Mock + private JobExecutionContext jobExecutionContext; + + private LockService lockService; + + @Mock + private MLBatchTaskUpdateJobParameter jobParameter; + + private MLBatchTaskUpdateJobRunner jobRunner; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + jobRunner = MLBatchTaskUpdateJobRunner.getJobRunnerInstance(); + jobRunner.initialize(clusterService, threadPool, client); + + lockService = new LockService(client, clusterService); + when(jobExecutionContext.getLockService()).thenReturn(lockService); + } + + @Ignore + @Test + public void testRunJobWithoutInitialization() { + 
MLBatchTaskUpdateJobRunner uninitializedRunner = MLBatchTaskUpdateJobRunner.getJobRunnerInstance(); + AssertionError exception = Assert.assertThrows(AssertionError.class, () -> { + uninitializedRunner.runJob(jobParameter, jobExecutionContext); + }); + Assert.assertEquals("this instance is not initialized", exception.getMessage()); + } + + @Ignore + @Test + public void testRunJobFailedToAcquireLock() { + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + + jobRunner.runJob(jobParameter, jobExecutionContext); + + verify(jobExecutionContext).getLockService(); + verifyNoMoreInteractions(client); + } + + @Ignore + @Test + public void testRunJobWithLockAcquisitionException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Failed to acquire lock")); + return null; + }).when(client).get(any(), any()); + + Assert.assertThrows(IllegalStateException.class, () -> { jobRunner.runJob(jobParameter, jobExecutionContext); }); + + verify(jobExecutionContext).getLockService(); + verifyNoMoreInteractions(client); + } + + @Ignore + @Test + public void testRunJobWithTasksFound() throws IOException { + SearchResponse searchResponse = createTaskSearchResponse(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + when(jobExecutionContext.getLockService()).thenReturn(lockService); + + jobRunner.runJob(jobParameter, jobExecutionContext); + + verify(client).search(any(), isA(ActionListener.class)); + verify(lockService).acquireLock(any(), any(), any()); + } + + private SearchResponse createTaskSearchResponse() throws IOException { + SearchResponse searchResponse = mock(SearchResponse.class); + + String taskContent = "{\n" + + " \"task_type\": 
\"BATCH_PREDICTION\",\n" + + " \"state\": \"RUNNING\",\n" + + " \"function_name\": \"REMOTE\",\n" + + " \"task_id\": \"example-task-id\"\n" + + "}"; + + SearchHit taskHit = SearchHit.fromXContent(TestHelper.parser(taskContent)); + + SearchHits hits = new SearchHits(new SearchHit[] { taskHit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + + when(searchResponse.getHits()).thenReturn(hits); + + return searchResponse; + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 0254cd619f..c7fc013b58 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -14,11 +14,13 @@ import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Rule; @@ -32,7 +34,13 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -42,6 +50,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import 
org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; @@ -84,6 +93,7 @@ import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; public class MLPredictTaskRunnerTests extends OpenSearchTestCase { @@ -473,6 +483,40 @@ public void testValidateBatchPredictionSuccess() throws IOException { when(mlModelManager.getPredictor(anyString())).thenReturn(predictor); when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" }); + + Settings indexSettings = Settings + .builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); + + IndexMetadata indexMetaData = IndexMetadata.builder(".ml_commons_task_polling_job").settings(existingSettings).build(); + + final Map indices = Map.of(indexName, indexMetaData); + Metadata metadata = new Metadata.Builder().indices(indices).build(); + DiscoveryNode node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + ImmutableSet.of(DiscoveryNodeRole.DATA_ROLE), + Version.CURRENT + ); + ClusterState state = new ClusterState( + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false + ); + ; + when(clusterService.state()).thenReturn(state); taskRunner.dispatchTask(FunctionName.REMOTE, remoteInputRequest, transportService, listener); verify(client, 
never()).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); @@ -543,6 +587,87 @@ public void testValidateModelTensorOutputFailed() { taskRunner.validateOutputSchema("testId", modelTensorOutput); } + public void testValidateBatchPredictionSuccess_InitPollingJob() throws IOException { + setupMocks(true, false, false, false); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters( + Map + .of( + "messages", + "[{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"You are a helpful assistant.\\\"}," + + "{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"Hello!\\\"}]" + ) + ) + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .build(); + MLPredictionTaskRequest remoteInputRequest = MLPredictionTaskRequest + .builder() + .modelId("test_model") + .mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build()) + .build(); + Predictable predictor = mock(Predictable.class); + when(predictor.isModelReady()).thenReturn(true); + ModelTensor modelTensor = ModelTensor + .builder() + .name("response") + .dataAsMap(Map.of("TransformJobArn", "arn:aws:sagemaker:us-east-1:802041417063:transform-job/batch-transform-01")) + .build(); + Map modelInterface = Map + .of( + "output", + "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}" + ); + ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build(); + modelTensors.setStatusCode(200); + ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build()); + return null; + }).when(predictor).asyncPredict(any(), any()); + + IndexResponse 
indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockTaskId"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + when(mlModelManager.getModelInterface(any())).thenReturn(modelInterface); + + when(mlModelManager.getPredictor(anyString())).thenReturn(predictor); + when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" }); + + // Mocking clusterService to simulate missing TASK_POLLING_JOB_INDEX + Metadata metadata = new Metadata.Builder().indices(Map.of()).build(); + DiscoveryNode node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + ImmutableSet.of(DiscoveryNodeRole.DATA_ROLE), + Version.CURRENT + ); + ClusterState state = new ClusterState( + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false + ); + ; + when(clusterService.state()).thenReturn(state); + + taskRunner.dispatchTask(FunctionName.REMOTE, remoteInputRequest, transportService, listener); + verify(mlTaskManager).startTaskPollingJob(); + } + private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullGetResponse) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index f56423e829..9a5ce4bb42 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -16,7 +16,9 @@ import static 
org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.TASK_POLLING_JOB_INDEX; +import java.io.IOException; import java.time.Instant; import java.util.Arrays; import java.util.Collections; @@ -30,7 +32,9 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -341,4 +345,35 @@ public void testMLTaskCache() { mlTaskManager.addNodeError(task.getTaskId(), node2, error); assertTrue(mlTaskCache.allNodeFailed()); } + + public void testStartTaskPollingJob() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + mlTaskManager.startTaskPollingJob(); + + ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + verify(client).index(indexRequestCaptor.capture(), any()); + + IndexRequest capturedRequest = indexRequestCaptor.getValue(); + assertEquals(TASK_POLLING_JOB_INDEX, capturedRequest.index()); + assertNotNull(capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + public void testStartTaskPollingJob_IndexException() throws IOException { + String errorMessage = "Failed to index task polling job"; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(client).index(any(), any()); + + mlTaskManager.startTaskPollingJob(); + + 
verify(client).index(any(), any()); + } }