|
6 | 6 | package org.opensearch.ml.engine.algorithms.remote;
|
7 | 7 |
|
8 | 8 | import static org.apache.commons.text.StringEscapeUtils.escapeJson;
|
| 9 | +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT; |
9 | 10 | import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
|
10 | 11 | import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD;
|
11 | 12 | import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING;
|
|
19 | 20 | import java.net.URI;
|
20 | 21 | import java.nio.charset.Charset;
|
21 | 22 | import java.util.ArrayList;
|
| 23 | +import java.util.Collections; |
22 | 24 | import java.util.HashMap;
|
23 | 25 | import java.util.List;
|
24 | 26 | import java.util.Map;
|
@@ -61,6 +63,9 @@ public class ConnectorUtils {
|
61 | 63 | private static final Aws4Signer signer;
|
62 | 64 | public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";
|
63 | 65 |
|
| 66 | + public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List |
| 67 | + .of("sagemaker", "openai", "bedrock", "cohere"); |
| 68 | + |
64 | 69 | static {
|
65 | 70 | signer = Aws4Signer.create();
|
66 | 71 | }
|
@@ -313,4 +318,63 @@ public static SdkHttpFullRequest buildSdkRequest(
|
313 | 318 | }
|
314 | 319 | return builder.build();
|
315 | 320 | }
|
| 321 | + |
| 322 | + public static ConnectorAction createConnectorAction(Connector connector, ConnectorAction.ActionType actionType) { |
| 323 | + Optional<ConnectorAction> batchPredictAction = connector.findAction(BATCH_PREDICT.name()); |
| 324 | + String predictEndpoint = batchPredictAction.get().getUrl(); |
| 325 | + Map<String, String> parameters = connector.getParameters() != null |
| 326 | + ? new HashMap<>(connector.getParameters()) |
| 327 | + : Collections.emptyMap(); |
| 328 | + |
| 329 | + // Apply parameter substitution only if needed |
| 330 | + if (!parameters.isEmpty()) { |
| 331 | + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); |
| 332 | + predictEndpoint = substitutor.replace(predictEndpoint); |
| 333 | + } |
| 334 | + |
| 335 | + boolean isCancelAction = actionType == CANCEL_BATCH_PREDICT; |
| 336 | + |
| 337 | + // Initialize the default method and requestBody |
| 338 | + String method = "POST"; |
| 339 | + String requestBody = null; |
| 340 | + String url = ""; |
| 341 | + |
| 342 | + switch (getRemoteServerFromURL(predictEndpoint)) { |
| 343 | + case "sagemaker": |
| 344 | + url = isCancelAction |
| 345 | + ? predictEndpoint.replace("CreateTransformJob", "StopTransformJob") |
| 346 | + : predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob"); |
| 347 | + requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}"; |
| 348 | + break; |
| 349 | + case "openai": |
| 350 | + case "cohere": |
| 351 | + url = isCancelAction ? predictEndpoint + "/${parameters.id}/cancel" : predictEndpoint + "/${parameters.id}"; |
| 352 | + method = isCancelAction ? "POST" : "GET"; |
| 353 | + break; |
| 354 | + case "bedrock": |
| 355 | + url = isCancelAction |
| 356 | + ? predictEndpoint + "/${parameters.processedJobArn}/stop" |
| 357 | + : predictEndpoint + "/${parameters.processedJobArn}"; |
| 358 | + method = isCancelAction ? "POST" : "GET"; |
| 359 | + break; |
| 360 | + default: |
| 361 | + String errorMessage = isCancelAction |
| 362 | + ? "Please configure the action type to cancel the batch job in the connector" |
| 363 | + : "Please configure the action type to get the batch job details in the connector"; |
| 364 | + throw new UnsupportedOperationException(errorMessage); |
| 365 | + } |
| 366 | + |
| 367 | + return ConnectorAction |
| 368 | + .builder() |
| 369 | + .actionType(actionType) |
| 370 | + .method(method) |
| 371 | + .url(url) |
| 372 | + .requestBody(requestBody) |
| 373 | + .headers(batchPredictAction.get().getHeaders()) |
| 374 | + .build(); |
| 375 | + } |
| 376 | + |
| 377 | + public static String getRemoteServerFromURL(String url) { |
| 378 | + return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse(""); |
| 379 | + } |
316 | 380 | }
|
0 commit comments