5
5
6
6
package org .opensearch .ml .common .connector ;
7
7
8
- import com .google .common .collect .ImmutableList ;
9
- import org .opensearch .ml .common .output .model .MLResultDataType ;
8
+ import org .opensearch .ml .common .connector .functions .postprocess .BedrockEmbeddingPostProcessFunction ;
9
+ import org .opensearch .ml .common .connector .functions .postprocess .CohereRerankPostProcessFunction ;
10
+ import org .opensearch .ml .common .connector .functions .postprocess .EmbeddingPostProcessFunction ;
10
11
import org .opensearch .ml .common .output .model .ModelTensor ;
11
12
12
- import java .util .ArrayList ;
13
13
import java .util .HashMap ;
14
14
import java .util .List ;
15
15
import java .util .Map ;
@@ -20,58 +20,41 @@ public class MLPostProcessFunction {
20
20
// Names of the built-in post-process functions. The static initializer in this
// class uses them as keys when filling JSON_PATH_EXPRESSION and POST_PROCESS_FUNCTIONS.
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding" ;
21
21
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding" ;
22
22
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding" ;
23
// Added by this change: name for the Cohere rerank post-process function.
+ public static final String COHERE_RERANK = "connector.post_process.cohere.rerank" ;
23
24
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding" ;
25
// Added by this change: name for the default rerank post-process function.
+ public static final String DEFAULT_RERANK = "connector.post_process.default.rerank" ;
24
26
25
27
// Function name -> response filter string (the values registered in the static
// initializer are JSONPath-style expressions); read by getResponseFilter.
private static final Map <String , String > JSON_PATH_EXPRESSION = new HashMap <>();
26
28
27
// Function name -> function producing a List<ModelTensor>. The removed ("-")
// declaration took a List<?> as input; the added ("+") one takes an Object.
- private static final Map <String , Function <List <?>, List <ModelTensor >>> POST_PROCESS_FUNCTIONS = new HashMap <>();
28
-
29
+ private static final Map <String , Function <Object , List <ModelTensor >>> POST_PROCESS_FUNCTIONS = new HashMap <>();
29
30
30
31
// Static initializer: registers, for every built-in function name, a response
// filter in JSON_PATH_EXPRESSION and a post-process function in POST_PROCESS_FUNCTIONS.
// This span is a diff: "-" lines are the old implementation, "+" lines the new one.
static {
32
// Added: one shared instance per post-process implementation class.
+ EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction ();
33
+ BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction ();
34
+ CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction ();
31
35
// Unchanged: response filters for the four embedding function names.
JSON_PATH_EXPRESSION .put (OPENAI_EMBEDDING , "$.data[*].embedding" );
32
36
JSON_PATH_EXPRESSION .put (COHERE_EMBEDDING , "$.embeddings" );
33
37
JSON_PATH_EXPRESSION .put (DEFAULT_EMBEDDING , "$[*]" );
34
38
JSON_PATH_EXPRESSION .put (BEDROCK_EMBEDDING , "$.embedding" );
35
// Removed: every embedding name used to be registered with a fresh
// buildModelTensorList() function.
- POST_PROCESS_FUNCTIONS .put (OPENAI_EMBEDDING , buildModelTensorList ());
36
- POST_PROCESS_FUNCTIONS .put (COHERE_EMBEDDING , buildModelTensorList ());
37
- POST_PROCESS_FUNCTIONS .put (DEFAULT_EMBEDDING , buildModelTensorList ());
38
- POST_PROCESS_FUNCTIONS .put (BEDROCK_EMBEDDING , buildModelTensorList ());
39
- }
40
-
41
// Removed: the inline embedding converter. It rejected a null list, wrapped a
// flat list of numbers (first element is a Number) into a one-element list, and
// built one FLOAT32 "sentence_embedding" ModelTensor per inner list.
- public static Function <List <?>, List <ModelTensor >> buildModelTensorList () {
42
- return embeddings -> {
43
- List <ModelTensor > modelTensors = new ArrayList <>();
44
- if (embeddings == null ) {
45
- throw new IllegalArgumentException ("The list of embeddings is null when using the built-in post-processing function." );
46
- }
47
- if (embeddings .get (0 ) instanceof Number ) {
48
- embeddings = ImmutableList .of (embeddings );
49
- }
50
- embeddings .forEach (embedding -> {
51
- List <Number > eachEmbedding = (List <Number >) embedding ;
52
- modelTensors .add (
53
- ModelTensor
54
- .builder ()
55
- .name ("sentence_embedding" )
56
- .dataType (MLResultDataType .FLOAT32 )
57
- .shape (new long []{eachEmbedding .size ()})
58
- .data (eachEmbedding .toArray (new Number [0 ]))
59
- .build ()
60
- );
61
- });
62
- return modelTensors ;
63
- };
39
// Added: response filters for the two rerank function names.
+ JSON_PATH_EXPRESSION .put (COHERE_RERANK , "$.results" );
40
+ JSON_PATH_EXPRESSION .put (DEFAULT_RERANK , "$[*]" );
41
// Added: OPENAI/COHERE/DEFAULT embedding names share embeddingPostProcessFunction,
// BEDROCK gets its own function, and both rerank names (including DEFAULT_RERANK)
// use cohereRerankPostProcessFunction.
+ POST_PROCESS_FUNCTIONS .put (OPENAI_EMBEDDING , embeddingPostProcessFunction );
42
+ POST_PROCESS_FUNCTIONS .put (COHERE_EMBEDDING , embeddingPostProcessFunction );
43
+ POST_PROCESS_FUNCTIONS .put (DEFAULT_EMBEDDING , embeddingPostProcessFunction );
44
+ POST_PROCESS_FUNCTIONS .put (BEDROCK_EMBEDDING , bedrockEmbeddingPostProcessFunction );
45
+ POST_PROCESS_FUNCTIONS .put (COHERE_RERANK , cohereRerankPostProcessFunction );
46
+ POST_PROCESS_FUNCTIONS .put (DEFAULT_RERANK , cohereRerankPostProcessFunction );
64
47
}
65
48
66
49
public static String getResponseFilter (String postProcessFunction ) {
67
50
return JSON_PATH_EXPRESSION .get (postProcessFunction );
68
51
}
69
52
70
/**
 * Returns the post-process function registered under the given name.
 * The removed ("-") signature returned a function over {@code List<?>}; the
 * added ("+") signature returns a function over {@code Object}.
 *
 * @param postProcessFunction name of the post-process function to look up
 * @return the function stored in {@code POST_PROCESS_FUNCTIONS} for that name,
 *         or {@code null} when the map has no entry for it
 */
- public static Function <List <?> , List <ModelTensor >> get (String postProcessFunction ) {
53
+ public static Function <Object , List <ModelTensor >> get (String postProcessFunction ) {
71
54
return POST_PROCESS_FUNCTIONS .get (postProcessFunction );
72
55
}
73
56
74
57
public static boolean contains (String postProcessFunction ) {
75
58
return POST_PROCESS_FUNCTIONS .containsKey (postProcessFunction );
76
59
}
77
- }
60
+ }
0 commit comments