Skip to content

Commit ff6fe67

Browse files
authored
Enhance batch job task management by adding default action types (opensearch-project#3080)
* enhance batch job task management by adding default action types Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 75d454e commit ff6fe67

File tree

10 files changed

+266
-18
lines changed

10 files changed

+266
-18
lines changed

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

+5
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ public Optional<ConnectorAction> findAction(String action) {
125125
return Optional.empty();
126126
}
127127

128+
@Override
129+
public void addAction(ConnectorAction action) {
130+
actions.add(action);
131+
}
132+
128133
@Override
129134
public void removeCredential() {
130135
this.credential = null;

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public interface Connector extends ToXContentObject, Writeable {
6565

6666
List<ConnectorAction> getActions();
6767

68+
void addAction(ConnectorAction action);
69+
6870
ConnectorClientConfig getConnectorClientConfig();
6971

7072
String getActionEndpoint(String action, Map<String, String> parameters);

common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java

+26
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.common.output;
77

88
import java.io.IOException;
9+
import java.util.Map;
910

1011
import org.opensearch.core.common.io.stream.StreamInput;
1112
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -30,8 +31,12 @@ public class MLPredictionOutput extends MLOutput {
3031
public static final String STATUS_FIELD = "status";
3132
public static final String PREDICTION_RESULT_FIELD = "prediction_result";
3233

34+
// This field will be created for offline batch prediction tasks containing details of the batch job as outputted by the remote server.
35+
public static final String REMOTE_JOB_FIELD = "remote_job";
36+
3337
String taskId;
3438
String status;
39+
Map<String, Object> remoteJob;
3540

3641
@ToString.Exclude
3742
DataFrame predictionResult;
@@ -44,6 +49,14 @@ public MLPredictionOutput(String taskId, String status, DataFrame predictionResu
4449
this.predictionResult = predictionResult;
4550
}
4651

52+
@Builder
53+
public MLPredictionOutput(String taskId, String status, Map<String, Object> remoteJob) {
54+
super(OUTPUT_TYPE);
55+
this.taskId = taskId;
56+
this.status = status;
57+
this.remoteJob = remoteJob;
58+
}
59+
4760
public MLPredictionOutput(StreamInput in) throws IOException {
4861
super(OUTPUT_TYPE);
4962
this.taskId = in.readOptionalString();
@@ -56,6 +69,9 @@ public MLPredictionOutput(StreamInput in) throws IOException {
5669
break;
5770
}
5871
}
72+
if (in.readBoolean()) {
73+
this.remoteJob = in.readMap(s -> s.readString(), s -> s.readGenericValue());
74+
}
5975
}
6076

6177
@Override
@@ -69,6 +85,12 @@ public void writeTo(StreamOutput out) throws IOException {
6985
} else {
7086
out.writeBoolean(false);
7187
}
88+
if (remoteJob != null) {
89+
out.writeBoolean(true);
90+
out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue);
91+
} else {
92+
out.writeBoolean(false);
93+
}
7294
}
7395

7496
@Override
@@ -87,6 +109,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
87109
builder.endObject();
88110
}
89111

112+
if (remoteJob != null) {
113+
builder.field(REMOTE_JOB_FIELD, remoteJob);
114+
}
115+
90116
builder.endObject();
91117
return builder;
92118
}

common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java

+14
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
import java.io.IOException;
1111
import java.util.ArrayList;
12+
import java.util.HashMap;
1213
import java.util.List;
14+
import java.util.Map;
1315

1416
import org.junit.Before;
1517
import org.junit.Test;
@@ -30,6 +32,7 @@
3032
public class MLPredictionOutputTest {
3133

3234
MLPredictionOutput output;
35+
MLPredictionOutput outputWithRemoteJob;
3336

3437
@Before
3538
public void setUp() {
@@ -38,12 +41,17 @@ public void setUp() {
3841
rows.add(new Row(new ColumnValue[] { new IntValue(1) }));
3942
rows.add(new Row(new ColumnValue[] { new IntValue(2) }));
4043
DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows);
44+
Map<String, Object> remoteJob = new HashMap<>();
45+
remoteJob.put("status", "INPROGRESS");
46+
remoteJob.put("job_id", "testJobID");
4147
output = MLPredictionOutput.builder().taskId("test_task_id").status("test_status").predictionResult(dataFrame).build();
48+
outputWithRemoteJob = new MLPredictionOutput("test_task_id", "test_status", remoteJob);
4249
}
4350

4451
@Test
4552
public void toXContent() throws IOException {
4653
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
54+
XContentBuilder builderWithRemoteJob = MediaTypeRegistry.contentBuilder(XContentType.JSON);
4755
output.toXContent(builder, ToXContent.EMPTY_PARAMS);
4856
String jsonStr = builder.toString();
4957
assertEquals(
@@ -53,6 +61,12 @@ public void toXContent() throws IOException {
5361
+ "\"value\":2}]}]}}",
5462
jsonStr
5563
);
64+
outputWithRemoteJob.toXContent(builderWithRemoteJob, ToXContent.EMPTY_PARAMS);
65+
String jsonStr2 = builderWithRemoteJob.toString();
66+
assertEquals(
67+
"{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"remote_job\":{\"job_id\":\"testJobID\",\"status\":\"INPROGRESS\"}}",
68+
jsonStr2
69+
);
5670
}
5771

5872
@Test

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+64
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine.algorithms.remote;
77

88
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
9+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT;
910
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
1011
import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD;
1112
import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING;
@@ -19,6 +20,7 @@
1920
import java.net.URI;
2021
import java.nio.charset.Charset;
2122
import java.util.ArrayList;
23+
import java.util.Collections;
2224
import java.util.HashMap;
2325
import java.util.List;
2426
import java.util.Map;
@@ -61,6 +63,9 @@ public class ConnectorUtils {
6163
private static final Aws4Signer signer;
6264
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";
6365

66+
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List
67+
.of("sagemaker", "openai", "bedrock", "cohere");
68+
6469
static {
6570
signer = Aws4Signer.create();
6671
}
@@ -313,4 +318,63 @@ public static SdkHttpFullRequest buildSdkRequest(
313318
}
314319
return builder.build();
315320
}
321+
322+
public static ConnectorAction createConnectorAction(Connector connector, ConnectorAction.ActionType actionType) {
323+
Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name());
324+
String predictEndpoint = batchPredictAction.get().getUrl();
325+
Map<String, String> parameters = connector.getParameters() != null
326+
? new HashMap<>(connector.getParameters())
327+
: Collections.emptyMap();
328+
329+
// Apply parameter substitution only if needed
330+
if (!parameters.isEmpty()) {
331+
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
332+
predictEndpoint = substitutor.replace(predictEndpoint);
333+
}
334+
335+
boolean isCancelAction = actionType == CANCEL_BATCH_PREDICT;
336+
337+
// Initialize the default method and requestBody
338+
String method = "POST";
339+
String requestBody = null;
340+
String url = "";
341+
342+
switch (getRemoteServerFromURL(predictEndpoint)) {
343+
case "sagemaker":
344+
url = isCancelAction
345+
? predictEndpoint.replace("CreateTransformJob", "StopTransformJob")
346+
: predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob");
347+
requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}";
348+
break;
349+
case "openai":
350+
case "cohere":
351+
url = isCancelAction ? predictEndpoint + "/${parameters.id}/cancel" : predictEndpoint + "/${parameters.id}";
352+
method = isCancelAction ? "POST" : "GET";
353+
break;
354+
case "bedrock":
355+
url = isCancelAction
356+
? predictEndpoint + "/${parameters.processedJobArn}/stop"
357+
: predictEndpoint + "/${parameters.processedJobArn}";
358+
method = isCancelAction ? "POST" : "GET";
359+
break;
360+
default:
361+
String errorMessage = isCancelAction
362+
? "Please configure the action type to cancel the batch job in the connector"
363+
: "Please configure the action type to get the batch job details in the connector";
364+
throw new UnsupportedOperationException(errorMessage);
365+
}
366+
367+
return ConnectorAction
368+
.builder()
369+
.actionType(actionType)
370+
.method(method)
371+
.url(url)
372+
.requestBody(requestBody)
373+
.headers(batchPredictAction.get().getHeaders())
374+
.build();
375+
}
376+
377+
public static String getRemoteServerFromURL(String url) {
378+
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
379+
}
316380
}

0 commit comments

Comments
 (0)