6
6
package org .opensearch .ml .processor ;
7
7
8
8
import java .io .IOException ;
9
+ import java .lang .reflect .Type ;
9
10
import java .util .ArrayList ;
10
11
import java .util .Arrays ;
11
12
import java .util .List ;
13
+ import java .util .Locale ;
12
14
import java .util .Map ;
13
15
import java .util .stream .Collectors ;
14
16
17
+ import com .google .gson .Gson ;
18
+ import org .apache .logging .log4j .LogManager ;
19
+ import org .apache .logging .log4j .Logger ;
15
20
import org .opensearch .action .ActionRequest ;
16
21
import org .opensearch .ml .common .FunctionName ;
22
+ import org .opensearch .ml .common .dataset .TextDocsInputDataSet ;
17
23
import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
18
24
import org .opensearch .ml .common .input .MLInput ;
25
+ import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
26
+ import org .opensearch .ml .common .input .parameter .textembedding .AsymmetricTextEmbeddingParameters ;
27
+ import org .opensearch .ml .common .output .model .ModelResultFilter ;
19
28
import org .opensearch .ml .common .output .model .ModelTensor ;
20
29
import org .opensearch .ml .common .output .model .ModelTensorOutput ;
21
30
import org .opensearch .ml .common .output .model .ModelTensors ;
25
34
import com .jayway .jsonpath .Configuration ;
26
35
import com .jayway .jsonpath .JsonPath ;
27
36
import com .jayway .jsonpath .Option ;
37
+ import org .opensearch .ml .repackage .com .google .common .reflect .TypeToken ;
38
+
39
+ import static org .opensearch .ml .common .utils .StringUtils .gson ;
28
40
29
41
/**
30
42
* General ModelExecutor interface.
31
43
*/
32
44
public interface ModelExecutor {
33
45
46
+ Logger logger = LogManager .getLogger (ModelExecutor .class );
47
+
34
48
Configuration suppressExceptionConfiguration = Configuration
35
49
.builder ()
36
50
.options (Option .SUPPRESS_EXCEPTIONS , Option .DEFAULT_PATH_LEAF_TO_NULL )
@@ -45,13 +59,31 @@ public interface ModelExecutor {
45
59
* @return an ActionRequest instance for remote model inference
46
60
* @throws IllegalArgumentException if the input parameters are null
47
61
*/
48
- default <T > ActionRequest getRemoteModelInferenceRequest (Map <String , String > parameters , String modelId ) {
62
+ default <T > ActionRequest getRemoteModelInferenceRequest (Map <String , String > parameters , String modelId , String functionName ) {
63
+ MLInput mlInput = new MLInput ();
49
64
if (parameters == null ) {
50
65
throw new IllegalArgumentException ("wrong input. The model input cannot be empty." );
51
66
}
52
- RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder ().parameters (parameters ).build ();
67
+ if (functionName .equals ("remote" )) {
68
+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder ().parameters (parameters ).build ();
69
+ mlInput = MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inputDataSet ).build ();
70
+ } else if (functionName .equals ("text_embedding" ) || functionName .equals ("sparse_encoding" )) {
71
+ Gson gson = new Gson ();
72
+ String textDocs = parameters .getOrDefault ("text_docs" , "" );
73
+ if (!textDocs .startsWith ("[" ) || !textDocs .endsWith ("]" ) ) {
74
+ textDocs = "[\" " + textDocs + "\" ]" ;
75
+ }
76
+ List <String > docs = gson .fromJson (textDocs , List .class );
77
+ Boolean returnBytes = gson .fromJson (parameters .getOrDefault ("return_bytes" , "false" ), Boolean .class );
78
+ Boolean returnNumber = gson .fromJson (parameters .getOrDefault ("return_number" , "true" ), Boolean .class );
79
+ List <String > targetResponse = gson .fromJson (parameters .getOrDefault ("target_response" , "[]" ), List .class );
80
+ Type listType = new TypeToken <List <Integer >>() {}.getType ();
81
+ List <Integer > targetResponsePositions = gson .fromJson (parameters .getOrDefault ("target_response_positions" , "[]" ), listType );
82
+ ModelResultFilter resultFilter = new ModelResultFilter (returnBytes , returnNumber , targetResponse , targetResponsePositions );
53
83
54
- MLInput mlInput = MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inputDataSet ).build ();
84
+ TextDocsInputDataSet inputDataSet = TextDocsInputDataSet .builder ().docs (docs ).resultFilter (resultFilter ).build ();
85
+ mlInput = MLInput .builder ().algorithm (FunctionName .TEXT_EMBEDDING ).inputDataset (inputDataSet ).build ();
86
+ }
55
87
56
88
ActionRequest request = new MLPredictionTaskRequest (modelId , mlInput , null );
57
89
0 commit comments