27
27
import org .opensearch .ml .common .dataframe .DefaultDataFrame ;
28
28
import org .opensearch .ml .common .dataframe .DoubleValue ;
29
29
import org .opensearch .ml .common .dataframe .Row ;
30
- import org .opensearch .ml .common .dataset .DataFrameInputDataset ;
31
30
import org .opensearch .ml .common .FunctionName ;
31
+ import org .opensearch .ml .common .connector .ConnectorAction ;
32
+ import org .opensearch .ml .common .dataset .DataFrameInputDataset ;
32
33
import org .opensearch .ml .common .dataset .MLInputDataset ;
33
34
import org .opensearch .ml .common .dataset .SearchQueryInputDataset ;
34
35
import org .opensearch .ml .common .dataset .TextDocsInputDataSet ;
35
36
import org .opensearch .ml .common .dataset .TextSimilarityInputDataSet ;
37
+ import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
36
38
import org .opensearch .ml .common .input .nlp .TextSimilarityMLInput ;
37
39
import org .opensearch .ml .common .input .parameter .regression .LinearRegressionParams ;
38
40
import org .opensearch .ml .common .output .model .ModelResultFilter ;
44
46
import java .util .Arrays ;
45
47
import java .util .Collections ;
46
48
import java .util .List ;
49
+ import java .util .Map ;
47
50
import java .util .function .Consumer ;
48
51
import java .util .function .Function ;
49
52
@@ -168,6 +171,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
168
171
parse_NLPModel_NullResultFilter (FunctionName .SPARSE_ENCODING );
169
172
}
170
173
174
+ @ Test
175
+ public void parse_Remote_Model () throws IOException {
176
+ Map <String , String > parameters = Map .of ("TransformJobName" , "new name" );
177
+ RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder ()
178
+ .parameters (parameters )
179
+ .actionType (ConnectorAction .ActionType .PREDICT )
180
+ .build ();
181
+
182
+ String expectedInputStr = "{\" algorithm\" :\" REMOTE\" ,\" parameters\" :{\" TransformJobName\" :\" new name\" },\" action_type\" :\" PREDICT\" }" ;
183
+
184
+ testParse (FunctionName .REMOTE , remoteInferenceInputDataSet , expectedInputStr , parsedInput -> {
185
+ assertNotNull (parsedInput .getInputDataset ());
186
+ RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet ) parsedInput .getInputDataset ();
187
+ assertEquals (ConnectorAction .ActionType .PREDICT , parsedInputDataSet .getActionType ());
188
+ });
189
+ }
190
+
191
+ @ Test
192
+ public void parse_Remote_Model_With_ActionType () throws IOException {
193
+ Map <String , String > parameters = Map .of ("TransformJobName" , "new name" );
194
+ RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder ()
195
+ .parameters (parameters )
196
+ .actionType (ConnectorAction .ActionType .BATCH_PREDICT )
197
+ .build ();
198
+
199
+ String expectedInputStr = "{\" algorithm\" :\" REMOTE\" ,\" parameters\" :{\" TransformJobName\" :\" new name\" },\" action_type\" :\" BATCH_PREDICT\" }" ;
200
+
201
+ testParseWithActionType (FunctionName .REMOTE , remoteInferenceInputDataSet , ConnectorAction .ActionType .BATCH_PREDICT , expectedInputStr , parsedInput -> {
202
+ assertNotNull (parsedInput .getInputDataset ());
203
+ RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet ) parsedInput .getInputDataset ();
204
+ assertEquals (ConnectorAction .ActionType .BATCH_PREDICT , parsedInputDataSet .getActionType ());
205
+ });
206
+ }
207
+
171
208
private void testParse (FunctionName algorithm , MLInputDataset inputDataset , String expectedInputStr , Consumer <MLInput > verify ) throws IOException {
172
209
MLInput input = MLInput .builder ().inputDataset (inputDataset ).algorithm (algorithm ).build ();
173
210
XContentBuilder builder = MediaTypeRegistry .contentBuilder (XContentType .JSON );
@@ -186,6 +223,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
186
223
verify .accept (parsedInput );
187
224
}
188
225
226
+ private void testParseWithActionType (FunctionName algorithm , MLInputDataset inputDataset , ConnectorAction .ActionType actionType , String expectedInputStr , Consumer <MLInput > verify ) throws IOException {
227
+ MLInput input = MLInput .builder ().inputDataset (inputDataset ).algorithm (algorithm ).build ();
228
+ XContentBuilder builder = MediaTypeRegistry .contentBuilder (XContentType .JSON );
229
+ input .toXContent (builder , ToXContent .EMPTY_PARAMS );
230
+ assertNotNull (builder );
231
+ String jsonStr = builder .toString ();
232
+ assertEquals (expectedInputStr , jsonStr );
233
+
234
+ XContentParser parser = XContentType .JSON .xContent ()
235
+ .createParser (new NamedXContentRegistry (new SearchModule (Settings .EMPTY ,
236
+ Collections .emptyList ()).getNamedXContents ()), null , jsonStr );
237
+ parser .nextToken ();
238
+ MLInput parsedInput = MLInput .parse (parser , algorithm .name (), actionType );
239
+ assertEquals (input .getFunctionName (), parsedInput .getFunctionName ());
240
+ assertEquals (input .getInputDataset ().getInputDataType (), parsedInput .getInputDataset ().getInputDataType ());
241
+ verify .accept (parsedInput );
242
+ }
243
+
189
244
@ Test
190
245
public void readInputStream_Success () throws IOException {
191
246
readInputStream (input , parsedInput -> {
0 commit comments