6
6
package org .opensearch .ml .common .connector ;
7
7
8
8
import static org .opensearch .core .xcontent .XContentParserUtils .ensureExpectedToken ;
9
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .BEDROCK_BATCH_JOB_ARN ;
10
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .BEDROCK_EMBEDDING ;
11
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .BEDROCK_RERANK ;
12
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .COHERE_EMBEDDING ;
13
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .COHERE_RERANK ;
14
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .DEFAULT_EMBEDDING ;
15
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .DEFAULT_RERANK ;
16
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .OPENAI_EMBEDDING ;
17
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT ;
18
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT ;
19
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT ;
20
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT ;
21
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT ;
22
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT ;
23
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT ;
24
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT ;
25
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_SIMILARITY_TO_DEFAULT_INPUT ;
9
26
10
27
import java .io .IOException ;
11
28
import java .util .HashSet ;
29
+ import java .util .List ;
12
30
import java .util .Locale ;
13
31
import java .util .Map ;
14
32
import java .util .Set ;
15
33
34
+ import org .apache .commons .text .StringSubstitutor ;
16
35
import org .opensearch .core .common .io .stream .StreamInput ;
17
36
import org .opensearch .core .common .io .stream .StreamOutput ;
18
37
import org .opensearch .core .common .io .stream .Writeable ;
@@ -35,6 +54,13 @@ public class ConnectorAction implements ToXContentObject, Writeable {
35
54
public static final String REQUEST_BODY_FIELD = "request_body" ;
36
55
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function" ;
37
56
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function" ;
57
+ public static final String OPENAI = "openai" ;
58
+ public static final String COHERE = "cohere" ;
59
+ public static final String BEDROCK = "bedrock" ;
60
+ public static final String SAGEMAKER = "sagemaker" ;
61
+ public static final List <String > SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List .of (SAGEMAKER , OPENAI , BEDROCK , COHERE );
62
+
63
+ private static final String INBUILT_FUNC_PREFIX = "connector." ;
38
64
39
65
private ActionType actionType ;
40
66
private String method ;
@@ -185,6 +211,128 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
185
211
.build ();
186
212
}
187
213
214
+ public void validatePrePostProcessFunctions (Map <String , String > parameters ) {
215
+ StringSubstitutor substitutor = new StringSubstitutor (parameters , "${parameters." , "}" );
216
+ String endPoint = substitutor .replace (url );
217
+ String remoteServer = getRemoteServerFromURL (endPoint );
218
+ validatePreProcessFunctions (remoteServer );
219
+ validatePostProcessFunctions (remoteServer );
220
+ }
221
+
222
+ private void validatePreProcessFunctions (String remoteServer ) {
223
+ if (isInBuiltFunction (preProcessFunction )) {
224
+ switch (remoteServer ) {
225
+ case OPENAI :
226
+ if (!TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT .equals (preProcessFunction )) {
227
+ throw new IllegalArgumentException (
228
+ "LLM service is " + OPENAI + ", so PreProcessFunction should be " + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT
229
+ );
230
+ }
231
+ break ;
232
+ case COHERE :
233
+ if (!(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT .equals (preProcessFunction )
234
+ || IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT .equals (preProcessFunction )
235
+ || TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT .equals (preProcessFunction ))) {
236
+ throw new IllegalArgumentException (
237
+ "LLM service is "
238
+ + COHERE
239
+ + ", so PreProcessFunction should be "
240
+ + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT
241
+ + " or "
242
+ + IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT
243
+ + " or "
244
+ + TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT
245
+ );
246
+ }
247
+ break ;
248
+ case BEDROCK :
249
+ if (!(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT .equals (preProcessFunction )
250
+ || TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT .equals (preProcessFunction )
251
+ || TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT .equals (preProcessFunction ))) {
252
+ throw new IllegalArgumentException (
253
+ "LLM service is "
254
+ + BEDROCK
255
+ + ", so PreProcessFunction should be "
256
+ + TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT
257
+ + " or "
258
+ + TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT
259
+ + " or "
260
+ + TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT
261
+ );
262
+ }
263
+ break ;
264
+ case SAGEMAKER :
265
+ if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT .equals (preProcessFunction )
266
+ || TEXT_SIMILARITY_TO_DEFAULT_INPUT .equals (preProcessFunction ))) {
267
+ throw new IllegalArgumentException (
268
+ "LLM service is "
269
+ + SAGEMAKER
270
+ + ", so PreProcessFunction should be "
271
+ + TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
272
+ + " or "
273
+ + TEXT_SIMILARITY_TO_DEFAULT_INPUT
274
+ );
275
+ }
276
+ }
277
+ }
278
+ }
279
+
280
+ private void validatePostProcessFunctions (String remoteServer ) {
281
+ if (isInBuiltFunction (postProcessFunction )) {
282
+ switch (remoteServer ) {
283
+ case OPENAI :
284
+ if (!OPENAI_EMBEDDING .equals (postProcessFunction )) {
285
+ throw new IllegalArgumentException (
286
+ "LLM service is " + OPENAI + ", so PostProcessFunction should be " + OPENAI_EMBEDDING
287
+ );
288
+ }
289
+ break ;
290
+ case COHERE :
291
+ if (!(COHERE_EMBEDDING .equals (postProcessFunction ) || COHERE_RERANK .equals (postProcessFunction ))) {
292
+ throw new IllegalArgumentException (
293
+ "LLM service is " + COHERE + ", so PostProcessFunction should be " + COHERE_EMBEDDING + " or " + COHERE_RERANK
294
+ );
295
+ }
296
+ break ;
297
+ case BEDROCK :
298
+ if (!(BEDROCK_EMBEDDING .equals (postProcessFunction )
299
+ || BEDROCK_BATCH_JOB_ARN .equals (postProcessFunction )
300
+ || BEDROCK_RERANK .equals (postProcessFunction ))) {
301
+ throw new IllegalArgumentException (
302
+ "LLM service is "
303
+ + BEDROCK
304
+ + ", so PostProcessFunction should be "
305
+ + BEDROCK_EMBEDDING
306
+ + " or "
307
+ + BEDROCK_BATCH_JOB_ARN
308
+ + " or "
309
+ + BEDROCK_RERANK
310
+ );
311
+ }
312
+ break ;
313
+ case SAGEMAKER :
314
+ if (!(DEFAULT_EMBEDDING .equals (postProcessFunction ) || DEFAULT_RERANK .equals (postProcessFunction ))) {
315
+ throw new IllegalArgumentException (
316
+ "LLM service is "
317
+ + SAGEMAKER
318
+ + ", so PostProcessFunction should be "
319
+ + DEFAULT_EMBEDDING
320
+ + " or "
321
+ + DEFAULT_RERANK
322
+ );
323
+ }
324
+ }
325
+ }
326
+ }
327
+
328
+ private boolean isInBuiltFunction (String function ) {
329
+ return (function != null && function .startsWith (INBUILT_FUNC_PREFIX ));
330
+ }
331
+
332
+ public static String getRemoteServerFromURL (String url ) {
333
+ return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES .stream ().filter (url ::contains ).findFirst ().orElse ("" );
334
+ }
335
+
188
336
public enum ActionType {
189
337
PREDICT ,
190
338
EXECUTE ,
0 commit comments