9
9
import org .opensearch .core .common .io .stream .StreamOutput ;
10
10
import org .opensearch .core .xcontent .XContentParser ;
11
11
import org .opensearch .ml .common .FunctionName ;
12
+ import org .opensearch .ml .common .PredictMode ;
12
13
import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
13
14
import org .opensearch .ml .common .input .MLInput ;
14
15
import org .opensearch .ml .common .utils .StringUtils ;
21
22
@ org .opensearch .ml .common .annotation .MLInput (functionNames = {FunctionName .REMOTE })
22
23
public class RemoteInferenceMLInput extends MLInput {
23
24
public static final String PARAMETERS_FIELD = "parameters" ;
25
+ public static final String PREDICT_MODE_FIELD = "mode" ;
24
26
25
27
public RemoteInferenceMLInput (StreamInput in ) throws IOException {
26
28
super (in );
@@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
34
36
public RemoteInferenceMLInput (XContentParser parser , FunctionName functionName ) throws IOException {
35
37
super ();
36
38
this .algorithm = functionName ;
39
+ Map <String , String > parameters = null ;
40
+ PredictMode predictMode = null ;
37
41
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
38
42
while (parser .nextToken () != XContentParser .Token .END_OBJECT ) {
39
43
String fieldName = parser .currentName ();
40
44
parser .nextToken ();
41
45
42
46
switch (fieldName ) {
43
47
case PARAMETERS_FIELD :
44
- Map <String , String > parameters = StringUtils .getParameterMap (parser .map ());
45
- inputDataset = new RemoteInferenceInputDataSet (parameters );
48
+ parameters = StringUtils .getParameterMap (parser .map ());
46
49
break ;
50
+ case PREDICT_MODE_FIELD :
51
+ predictMode = PredictMode .from (parser .text ());
47
52
default :
48
53
parser .skipChildren ();
49
54
break ;
50
55
}
51
56
}
57
+ predictMode = predictMode == null ? PredictMode .PREDICT :predictMode ;
58
+ inputDataset = new RemoteInferenceInputDataSet (parameters , predictMode );
52
59
}
53
60
54
61
}
0 commit comments