6
6
package org .opensearch .ml .common .connector ;
7
7
8
8
import static org .opensearch .core .xcontent .XContentParserUtils .ensureExpectedToken ;
9
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .BEDROCK_BATCH_JOB_ARN ;
10
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .BEDROCK_EMBEDDING ;
11
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .BEDROCK_RERANK ;
12
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .COHERE_EMBEDDING ;
13
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .COHERE_RERANK ;
14
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .DEFAULT_EMBEDDING ;
15
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .DEFAULT_RERANK ;
16
+ import static org .opensearch .ml .common .connector .MLPostProcessFunction .OPENAI_EMBEDDING ;
17
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT ;
18
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT ;
19
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT ;
20
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT ;
21
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT ;
22
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT ;
23
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT ;
24
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT ;
25
+ import static org .opensearch .ml .common .connector .MLPreProcessFunction .TEXT_SIMILARITY_TO_DEFAULT_INPUT ;
9
26
10
27
import java .io .IOException ;
11
28
import java .util .HashSet ;
29
+ import java .util .List ;
12
30
import java .util .Locale ;
13
31
import java .util .Map ;
14
32
import java .util .Set ;
15
33
34
+ import org .apache .commons .text .StringSubstitutor ;
16
35
import org .opensearch .core .common .io .stream .StreamInput ;
17
36
import org .opensearch .core .common .io .stream .StreamOutput ;
18
37
import org .opensearch .core .common .io .stream .Writeable ;
@@ -35,6 +54,13 @@ public class ConnectorAction implements ToXContentObject, Writeable {
35
54
public static final String REQUEST_BODY_FIELD = "request_body" ;
36
55
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function" ;
37
56
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function" ;
57
+ public static final String OPENAI = "openai" ;
58
+ public static final String COHERE = "cohere" ;
59
+ public static final String BEDROCK = "bedrock" ;
60
+ public static final String SAGEMAKER = "sagemaker" ;
61
+ public static final List <String > SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List .of (SAGEMAKER , OPENAI , BEDROCK , COHERE );
62
+
63
+ private static final String INBUILT_FUNC_PREFIX = "connector." ;
38
64
39
65
private ActionType actionType ;
40
66
private String method ;
@@ -185,6 +211,125 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
185
211
.build ();
186
212
}
187
213
214
+ public void validatePrePostProcessFunctions (Map <String , String > parameters ) {
215
+ var substitutor = new StringSubstitutor (parameters , "${parameters." , "}" );
216
+ var endPoint = substitutor .replace (url );
217
+ var remoteServer = getRemoteServerFromURL (endPoint );
218
+ if (isInBuiltFunction (preProcessFunction )) {
219
+ switch (remoteServer ) {
220
+ case OPENAI :
221
+ if (!TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT .equals (preProcessFunction )) {
222
+ throw new IllegalArgumentException (
223
+ "LLM service is " + OPENAI + ", so PreProcessFunction should be " + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT
224
+ );
225
+ }
226
+ break ;
227
+ case COHERE :
228
+ if (!(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT .equals (preProcessFunction )
229
+ || IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT .equals (preProcessFunction )
230
+ || TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT .equals (preProcessFunction ))) {
231
+ throw new IllegalArgumentException (
232
+ "LLM service is "
233
+ + COHERE
234
+ + ", so PreProcessFunction should be "
235
+ + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT
236
+ + " or "
237
+ + IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT
238
+ + " or "
239
+ + TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT
240
+ );
241
+ }
242
+ break ;
243
+ case BEDROCK :
244
+ if (!(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT .equals (preProcessFunction )
245
+ || TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT .equals (preProcessFunction )
246
+ || TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT .equals (preProcessFunction ))) {
247
+ throw new IllegalArgumentException (
248
+ "LLM service is "
249
+ + BEDROCK
250
+ + ", so PreProcessFunction should be "
251
+ + TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT
252
+ + " or "
253
+ + TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT
254
+ + " or "
255
+ + TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT
256
+ );
257
+ }
258
+ break ;
259
+ case SAGEMAKER :
260
+ if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT .equals (preProcessFunction )
261
+ || TEXT_SIMILARITY_TO_DEFAULT_INPUT .equals (preProcessFunction ))) {
262
+ throw new IllegalArgumentException (
263
+ "LLM service is "
264
+ + SAGEMAKER
265
+ + ", so PreProcessFunction should be "
266
+ + TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
267
+ + " or "
268
+ + TEXT_SIMILARITY_TO_DEFAULT_INPUT
269
+ );
270
+ }
271
+ }
272
+ }
273
+ if (isInBuiltFunction (postProcessFunction )) {
274
+ switch (remoteServer ) {
275
+ case OPENAI :
276
+ if (!OPENAI_EMBEDDING .equals (postProcessFunction )) {
277
+ throw new IllegalArgumentException (
278
+ "LLM service is " + OPENAI + ", so PostProcessFunction should be " + OPENAI_EMBEDDING
279
+ );
280
+ }
281
+ break ;
282
+ case COHERE :
283
+ if (!(COHERE_EMBEDDING .equals (postProcessFunction ) || COHERE_RERANK .equals (postProcessFunction ))) {
284
+ throw new IllegalArgumentException (
285
+ "LLM service is "
286
+ + COHERE
287
+ + ", so PostProcessFunction should be "
288
+ + COHERE_EMBEDDING
289
+ + " or "
290
+ + COHERE_RERANK
291
+ );
292
+ }
293
+ break ;
294
+ case BEDROCK :
295
+ if (!(BEDROCK_EMBEDDING .equals (postProcessFunction )
296
+ || BEDROCK_BATCH_JOB_ARN .equals (postProcessFunction )
297
+ || BEDROCK_RERANK .equals (postProcessFunction ))) {
298
+ throw new IllegalArgumentException (
299
+ "LLM service is "
300
+ + BEDROCK
301
+ + ", so PostProcessFunction should be "
302
+ + BEDROCK_EMBEDDING
303
+ + " or "
304
+ + BEDROCK_BATCH_JOB_ARN
305
+ + " or "
306
+ + BEDROCK_RERANK
307
+ );
308
+ }
309
+ break ;
310
+ case SAGEMAKER :
311
+ if (!(DEFAULT_EMBEDDING .equals (postProcessFunction ) || DEFAULT_RERANK .equals (postProcessFunction ))) {
312
+ throw new IllegalArgumentException (
313
+ "LLM service is "
314
+ + SAGEMAKER
315
+ + ", so PostProcessFunction should be "
316
+ + DEFAULT_EMBEDDING
317
+ + " or "
318
+ + DEFAULT_RERANK
319
+ );
320
+ }
321
+ }
322
+ }
323
+ }
324
+
325
+ private boolean isInBuiltFunction (String function ) {
326
+ return (function != null && function .startsWith (INBUILT_FUNC_PREFIX ));
327
+ }
328
+
329
+ public static String getRemoteServerFromURL (String url ) {
330
+ return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES .stream ().filter (url ::contains ).findFirst ().orElse ("" );
331
+ }
332
+
188
333
public enum ActionType {
189
334
PREDICT ,
190
335
EXECUTE ,
0 commit comments