Skip to content

Commit 222ea7b

Browse files
committed
[Enhancement] Enhance validation for create connector API
This change addresses the second part of the validation: pre- and post-process function validation. Partially resolves opensearch-project#2993. Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>
1 parent 4d95466 commit 222ea7b

File tree

6 files changed

+453
-13
lines changed

6 files changed

+453
-13
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+145
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,32 @@
66
package org.opensearch.ml.common.connector;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_BATCH_JOB_ARN;
10+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING;
11+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_RERANK;
12+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING;
13+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_RERANK;
14+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING;
15+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK;
16+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;
17+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT;
18+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT;
19+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT;
20+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT;
21+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
22+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT;
23+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT;
24+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT;
25+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT;
926

1027
import java.io.IOException;
1128
import java.util.HashSet;
29+
import java.util.List;
1230
import java.util.Locale;
1331
import java.util.Map;
1432
import java.util.Set;
1533

34+
import org.apache.commons.text.StringSubstitutor;
1635
import org.opensearch.core.common.io.stream.StreamInput;
1736
import org.opensearch.core.common.io.stream.StreamOutput;
1837
import org.opensearch.core.common.io.stream.Writeable;
@@ -35,6 +54,13 @@ public class ConnectorAction implements ToXContentObject, Writeable {
3554
public static final String REQUEST_BODY_FIELD = "request_body";
3655
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function";
3756
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function";
57+
public static final String OPENAI = "openai";
58+
public static final String COHERE = "cohere";
59+
public static final String BEDROCK = "bedrock";
60+
public static final String SAGEMAKER = "sagemaker";
61+
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE);
62+
63+
private static final String INBUILT_FUNC_PREFIX = "connector.";
3864

3965
private ActionType actionType;
4066
private String method;
@@ -185,6 +211,125 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
185211
.build();
186212
}
187213

214+
public void validatePrePostProcessFunctions(Map<String, String> parameters) {
215+
var substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
216+
var endPoint = substitutor.replace(url);
217+
var remoteServer = getRemoteServerFromURL(endPoint);
218+
if (isInBuiltFunction(preProcessFunction)) {
219+
switch (remoteServer) {
220+
case OPENAI:
221+
if (!TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT.equals(preProcessFunction)) {
222+
throw new IllegalArgumentException(
223+
"LLM service is " + OPENAI + ", so PreProcessFunction should be " + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT
224+
);
225+
}
226+
break;
227+
case COHERE:
228+
if (!(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT.equals(preProcessFunction)
229+
|| IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT.equals(preProcessFunction)
230+
|| TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT.equals(preProcessFunction))) {
231+
throw new IllegalArgumentException(
232+
"LLM service is "
233+
+ COHERE
234+
+ ", so PreProcessFunction should be "
235+
+ TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT
236+
+ " or "
237+
+ IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT
238+
+ " or "
239+
+ TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT
240+
);
241+
}
242+
break;
243+
case BEDROCK:
244+
if (!(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
245+
|| TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
246+
|| TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT.equals(preProcessFunction))) {
247+
throw new IllegalArgumentException(
248+
"LLM service is "
249+
+ BEDROCK
250+
+ ", so PreProcessFunction should be "
251+
+ TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT
252+
+ " or "
253+
+ TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT
254+
+ " or "
255+
+ TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT
256+
);
257+
}
258+
break;
259+
case SAGEMAKER:
260+
if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction)
261+
|| TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) {
262+
throw new IllegalArgumentException(
263+
"LLM service is "
264+
+ SAGEMAKER
265+
+ ", so PreProcessFunction should be "
266+
+ TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
267+
+ " or "
268+
+ TEXT_SIMILARITY_TO_DEFAULT_INPUT
269+
);
270+
}
271+
}
272+
}
273+
if (isInBuiltFunction(postProcessFunction)) {
274+
switch (remoteServer) {
275+
case OPENAI:
276+
if (!OPENAI_EMBEDDING.equals(postProcessFunction)) {
277+
throw new IllegalArgumentException(
278+
"LLM service is " + OPENAI + ", so PostProcessFunction should be " + OPENAI_EMBEDDING
279+
);
280+
}
281+
break;
282+
case COHERE:
283+
if (!(COHERE_EMBEDDING.equals(postProcessFunction) || COHERE_RERANK.equals(postProcessFunction))) {
284+
throw new IllegalArgumentException(
285+
"LLM service is "
286+
+ COHERE
287+
+ ", so PostProcessFunction should be "
288+
+ COHERE_EMBEDDING
289+
+ " or "
290+
+ COHERE_RERANK
291+
);
292+
}
293+
break;
294+
case BEDROCK:
295+
if (!(BEDROCK_EMBEDDING.equals(postProcessFunction)
296+
|| BEDROCK_BATCH_JOB_ARN.equals(postProcessFunction)
297+
|| BEDROCK_RERANK.equals(postProcessFunction))) {
298+
throw new IllegalArgumentException(
299+
"LLM service is "
300+
+ BEDROCK
301+
+ ", so PostProcessFunction should be "
302+
+ BEDROCK_EMBEDDING
303+
+ " or "
304+
+ BEDROCK_BATCH_JOB_ARN
305+
+ " or "
306+
+ BEDROCK_RERANK
307+
);
308+
}
309+
break;
310+
case SAGEMAKER:
311+
if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) {
312+
throw new IllegalArgumentException(
313+
"LLM service is "
314+
+ SAGEMAKER
315+
+ ", so PostProcessFunction should be "
316+
+ DEFAULT_EMBEDDING
317+
+ " or "
318+
+ DEFAULT_RERANK
319+
);
320+
}
321+
}
322+
}
323+
}
324+
325+
private boolean isInBuiltFunction(String function) {
326+
return (function != null && function.startsWith(INBUILT_FUNC_PREFIX));
327+
}
328+
329+
public static String getRemoteServerFromURL(String url) {
330+
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
331+
}
332+
188333
public enum ActionType {
189334
PREDICT,
190335
EXECUTE,

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+5
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ public HttpConnector(
7070
String tenantId
7171
) {
7272
validateProtocol(protocol);
73+
if (actions != null) {
74+
for (ConnectorAction action : actions) {
75+
action.validatePrePostProcessFunctions(parameters);
76+
}
77+
}
7378
this.name = name;
7479
this.description = description;
7580
this.version = version;

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java

+5
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ public MLCreateConnectorInput(
102102
if (credential == null || credential.isEmpty()) {
103103
throw new IllegalArgumentException("Connector credential is null or empty list");
104104
}
105+
if (actions != null) {
106+
for (ConnectorAction action : actions) {
107+
action.validatePrePostProcessFunctions(parameters);
108+
}
109+
}
105110
}
106111
this.name = name;
107112
this.description = description;

0 commit comments

Comments
 (0)