5
5
6
6
package org .opensearch .ml .common .connector ;
7
7
8
+ import com .google .common .collect .ImmutableList ;
8
9
import org .opensearch .ml .common .output .model .MLResultDataType ;
9
10
import org .opensearch .ml .common .output .model .ModelTensor ;
10
11
@@ -18,38 +19,46 @@ public class MLPostProcessFunction {
18
19
19
20
// Names of the built-in post-process functions. Each name is used as the key into both
// JSON_PATH_EXPRESSION and POST_PROCESS_FUNCTIONS (populated in the static initializer).
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding" ;
20
21
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding" ;
21
-
22
+ public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding" ;
22
23
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding" ;
23
24
24
25
// Function name -> JSONPath expression returned by getResponseFilter(String).
private static final Map <String , String > JSON_PATH_EXPRESSION = new HashMap <>();
25
26
26
- private static final Map <String , Function <List <List < Float > >, List <ModelTensor >>> POST_PROCESS_FUNCTIONS = new HashMap <>();
27
// Function name -> function that converts the filtered list into ModelTensor objects; returned by get(String).
// The input is typed List<?> because buildModelTensorList() accepts either a list of vectors or one flat vector of numbers.
+ private static final Map <String , Function <List <? >, List <ModelTensor >>> POST_PROCESS_FUNCTIONS = new HashMap <>();
27
28
28
29
29
30
static {
    // Register each built-in function name together with its JSONPath response filter
    // and its tensor-building function. All names share the same builder.

    // OpenAI
    JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
    POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());

    // Cohere
    JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
    POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());

    // Default: the response root is itself the list of elements
    JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
    POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());

    // Bedrock
    JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
    POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
}
37
40
38
- public static Function <List <List < Float > >, List <ModelTensor >> buildModelTensorList () {
41
+ public static Function <List <? >, List <ModelTensor >> buildModelTensorList () {
39
42
return embeddings -> {
40
43
List <ModelTensor > modelTensors = new ArrayList <>();
41
44
if (embeddings == null ) {
42
45
throw new IllegalArgumentException ("The list of embeddings is null when using the built-in post-processing function." );
43
46
}
44
- embeddings .forEach (embedding -> modelTensors .add (
45
- ModelTensor
46
- .builder ()
47
- .name ("sentence_embedding" )
48
- .dataType (MLResultDataType .FLOAT32 )
49
- .shape (new long []{embedding .size ()})
50
- .data (embedding .toArray (new Number [0 ]))
51
- .build ()
52
- ));
47
+ if (embeddings .get (0 ) instanceof Number ) {
48
+ embeddings = ImmutableList .of (embeddings );
49
+ }
50
+ embeddings .forEach (embedding -> {
51
+ List <Number > eachEmbedding = (List <Number >) embedding ;
52
+ modelTensors .add (
53
+ ModelTensor
54
+ .builder ()
55
+ .name ("sentence_embedding" )
56
+ .dataType (MLResultDataType .FLOAT32 )
57
+ .shape (new long []{eachEmbedding .size ()})
58
+ .data (eachEmbedding .toArray (new Number [0 ]))
59
+ .build ()
60
+ );
61
+ });
53
62
return modelTensors ;
54
63
};
55
64
}
@@ -58,7 +67,7 @@ public static String getResponseFilter(String postProcessFunction) {
58
67
return JSON_PATH_EXPRESSION .get (postProcessFunction );
59
68
}
60
69
61
- public static Function <List <List < Float > >, List <ModelTensor >> get (String postProcessFunction ) {
70
+ public static Function <List <? >, List <ModelTensor >> get (String postProcessFunction ) {
62
71
return POST_PROCESS_FUNCTIONS .get (postProcessFunction );
63
72
}
64
73
0 commit comments