Skip to content

Commit 7df638e

Browse files
committed
[Enhancement] Enhance validation for create connector API
This change will address the second part of validation "pre and post processing function validation". Partially resolves opensearch-project#2993 Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>
1 parent d4ed7f5 commit 7df638e

File tree

6 files changed

+464
-13
lines changed

6 files changed

+464
-13
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+148
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,32 @@
66
package org.opensearch.ml.common.connector;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_BATCH_JOB_ARN;
10+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING;
11+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_RERANK;
12+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING;
13+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_RERANK;
14+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING;
15+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK;
16+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;
17+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT;
18+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT;
19+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT;
20+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT;
21+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
22+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT;
23+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT;
24+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT;
25+
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT;
926

1027
import java.io.IOException;
1128
import java.util.HashSet;
29+
import java.util.List;
1230
import java.util.Locale;
1331
import java.util.Map;
1432
import java.util.Set;
1533

34+
import org.apache.commons.text.StringSubstitutor;
1635
import org.opensearch.core.common.io.stream.StreamInput;
1736
import org.opensearch.core.common.io.stream.StreamOutput;
1837
import org.opensearch.core.common.io.stream.Writeable;
@@ -35,6 +54,13 @@ public class ConnectorAction implements ToXContentObject, Writeable {
3554
public static final String REQUEST_BODY_FIELD = "request_body";
3655
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function";
3756
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function";
57+
public static final String OPENAI = "openai";
58+
public static final String COHERE = "cohere";
59+
public static final String BEDROCK = "bedrock";
60+
public static final String SAGEMAKER = "sagemaker";
61+
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE);
62+
63+
private static final String INBUILT_FUNC_PREFIX = "connector.";
3864

3965
private ActionType actionType;
4066
private String method;
@@ -185,6 +211,128 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
185211
.build();
186212
}
187213

214+
public void validatePrePostProcessFunctions(Map<String, String> parameters) {
215+
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
216+
String endPoint = substitutor.replace(url);
217+
String remoteServer = getRemoteServerFromURL(endPoint);
218+
validatePreProcessFunctions(remoteServer);
219+
validatePostProcessFunctions(remoteServer);
220+
}
221+
222+
private void validatePreProcessFunctions(String remoteServer) {
223+
if (isInBuiltFunction(preProcessFunction)) {
224+
switch (remoteServer) {
225+
case OPENAI:
226+
if (!TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT.equals(preProcessFunction)) {
227+
throw new IllegalArgumentException(
228+
"LLM service is " + OPENAI + ", so PreProcessFunction should be " + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT
229+
);
230+
}
231+
break;
232+
case COHERE:
233+
if (!(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT.equals(preProcessFunction)
234+
|| IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT.equals(preProcessFunction)
235+
|| TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT.equals(preProcessFunction))) {
236+
throw new IllegalArgumentException(
237+
"LLM service is "
238+
+ COHERE
239+
+ ", so PreProcessFunction should be "
240+
+ TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT
241+
+ " or "
242+
+ IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT
243+
+ " or "
244+
+ TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT
245+
);
246+
}
247+
break;
248+
case BEDROCK:
249+
if (!(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
250+
|| TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
251+
|| TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT.equals(preProcessFunction))) {
252+
throw new IllegalArgumentException(
253+
"LLM service is "
254+
+ BEDROCK
255+
+ ", so PreProcessFunction should be "
256+
+ TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT
257+
+ " or "
258+
+ TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT
259+
+ " or "
260+
+ TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT
261+
);
262+
}
263+
break;
264+
case SAGEMAKER:
265+
if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction)
266+
|| TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) {
267+
throw new IllegalArgumentException(
268+
"LLM service is "
269+
+ SAGEMAKER
270+
+ ", so PreProcessFunction should be "
271+
+ TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
272+
+ " or "
273+
+ TEXT_SIMILARITY_TO_DEFAULT_INPUT
274+
);
275+
}
276+
}
277+
}
278+
}
279+
280+
private void validatePostProcessFunctions(String remoteServer) {
281+
if (isInBuiltFunction(postProcessFunction)) {
282+
switch (remoteServer) {
283+
case OPENAI:
284+
if (!OPENAI_EMBEDDING.equals(postProcessFunction)) {
285+
throw new IllegalArgumentException(
286+
"LLM service is " + OPENAI + ", so PostProcessFunction should be " + OPENAI_EMBEDDING
287+
);
288+
}
289+
break;
290+
case COHERE:
291+
if (!(COHERE_EMBEDDING.equals(postProcessFunction) || COHERE_RERANK.equals(postProcessFunction))) {
292+
throw new IllegalArgumentException(
293+
"LLM service is " + COHERE + ", so PostProcessFunction should be " + COHERE_EMBEDDING + " or " + COHERE_RERANK
294+
);
295+
}
296+
break;
297+
case BEDROCK:
298+
if (!(BEDROCK_EMBEDDING.equals(postProcessFunction)
299+
|| BEDROCK_BATCH_JOB_ARN.equals(postProcessFunction)
300+
|| BEDROCK_RERANK.equals(postProcessFunction))) {
301+
throw new IllegalArgumentException(
302+
"LLM service is "
303+
+ BEDROCK
304+
+ ", so PostProcessFunction should be "
305+
+ BEDROCK_EMBEDDING
306+
+ " or "
307+
+ BEDROCK_BATCH_JOB_ARN
308+
+ " or "
309+
+ BEDROCK_RERANK
310+
);
311+
}
312+
break;
313+
case SAGEMAKER:
314+
if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) {
315+
throw new IllegalArgumentException(
316+
"LLM service is "
317+
+ SAGEMAKER
318+
+ ", so PostProcessFunction should be "
319+
+ DEFAULT_EMBEDDING
320+
+ " or "
321+
+ DEFAULT_RERANK
322+
);
323+
}
324+
}
325+
}
326+
}
327+
328+
private boolean isInBuiltFunction(String function) {
329+
return (function != null && function.startsWith(INBUILT_FUNC_PREFIX));
330+
}
331+
332+
public static String getRemoteServerFromURL(String url) {
333+
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
334+
}
335+
188336
public enum ActionType {
189337
PREDICT,
190338
EXECUTE,

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+5
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ public HttpConnector(
7070
String tenantId
7171
) {
7272
validateProtocol(protocol);
73+
if (actions != null) {
74+
for (ConnectorAction action : actions) {
75+
action.validatePrePostProcessFunctions(parameters);
76+
}
77+
}
7378
this.name = name;
7479
this.description = description;
7580
this.version = version;

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java

+5
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ public MLCreateConnectorInput(
102102
if (credential == null || credential.isEmpty()) {
103103
throw new IllegalArgumentException("Connector credential is null or empty list");
104104
}
105+
if (actions != null) {
106+
for (ConnectorAction action : actions) {
107+
action.validatePrePostProcessFunctions(parameters);
108+
}
109+
}
105110
}
106111
this.name = name;
107112
this.description = description;

0 commit comments

Comments
 (0)