21
21
import org .opensearch .core .xcontent .XContentParser ;
22
22
import org .opensearch .index .query .MatchAllQueryBuilder ;
23
23
import org .opensearch .ml .common .FunctionName ;
24
+ import org .opensearch .ml .common .connector .ConnectorAction ;
24
25
import org .opensearch .ml .common .dataframe .*;
25
26
import org .opensearch .ml .common .dataset .DataFrameInputDataset ;
26
27
import org .opensearch .ml .common .dataset .MLInputDataset ;
27
28
import org .opensearch .ml .common .dataset .SearchQueryInputDataset ;
28
29
import org .opensearch .ml .common .dataset .TextDocsInputDataSet ;
29
30
import org .opensearch .ml .common .dataset .TextSimilarityInputDataSet ;
31
+ import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
30
32
import org .opensearch .ml .common .input .nlp .TextSimilarityMLInput ;
31
33
import org .opensearch .ml .common .input .parameter .regression .LinearRegressionParams ;
32
34
import org .opensearch .ml .common .output .model .ModelResultFilter ;
37
39
import java .util .ArrayList ;
38
40
import java .util .Arrays ;
39
41
import java .util .Collections ;
42
+ import java .util .HashMap ;
40
43
import java .util .List ;
44
+ import java .util .Map ;
41
45
import java .util .function .Consumer ;
42
46
import java .util .function .Function ;
43
47
@@ -160,6 +164,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
160
164
parse_NLPModel_NullResultFilter (FunctionName .SPARSE_ENCODING );
161
165
}
162
166
167
+ @ Test
168
+ public void parse_Remote_Model () throws IOException {
169
+ Map <String , String > parameters = Map .of ("TransformJobName" , "new name" );
170
+ RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder ()
171
+ .parameters (parameters )
172
+ .actionType (ConnectorAction .ActionType .PREDICT )
173
+ .build ();
174
+
175
+ String expectedInputStr = "{\" algorithm\" :\" REMOTE\" ,\" parameters\" :{\" TransformJobName\" :\" new name\" },\" action_type\" :\" PREDICT\" }" ;
176
+
177
+ testParse (FunctionName .REMOTE , remoteInferenceInputDataSet , expectedInputStr , parsedInput -> {
178
+ assertNotNull (parsedInput .getInputDataset ());
179
+ RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet ) parsedInput .getInputDataset ();
180
+ assertEquals (ConnectorAction .ActionType .PREDICT , parsedInputDataSet .getActionType ());
181
+ });
182
+ }
183
+
184
+ @ Test
185
+ public void parse_Remote_Model_With_ActionType () throws IOException {
186
+ Map <String , String > parameters = Map .of ("TransformJobName" , "new name" );
187
+ RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder ()
188
+ .parameters (parameters )
189
+ .actionType (ConnectorAction .ActionType .BATCH_PREDICT )
190
+ .build ();
191
+
192
+ String expectedInputStr = "{\" algorithm\" :\" REMOTE\" ,\" parameters\" :{\" TransformJobName\" :\" new name\" },\" action_type\" :\" BATCH_PREDICT\" }" ;
193
+
194
+ testParseWithActionType (FunctionName .REMOTE , remoteInferenceInputDataSet , ConnectorAction .ActionType .BATCH_PREDICT , expectedInputStr , parsedInput -> {
195
+ assertNotNull (parsedInput .getInputDataset ());
196
+ RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet ) parsedInput .getInputDataset ();
197
+ assertEquals (ConnectorAction .ActionType .BATCH_PREDICT , parsedInputDataSet .getActionType ());
198
+ });
199
+ }
200
+
163
201
private void testParse (FunctionName algorithm , MLInputDataset inputDataset , String expectedInputStr , Consumer <MLInput > verify ) throws IOException {
164
202
MLInput input = MLInput .builder ().inputDataset (inputDataset ).algorithm (algorithm ).build ();
165
203
XContentBuilder builder = MediaTypeRegistry .contentBuilder (XContentType .JSON );
@@ -178,6 +216,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
178
216
verify .accept (parsedInput );
179
217
}
180
218
219
+ private void testParseWithActionType (FunctionName algorithm , MLInputDataset inputDataset , ConnectorAction .ActionType actionType , String expectedInputStr , Consumer <MLInput > verify ) throws IOException {
220
+ MLInput input = MLInput .builder ().inputDataset (inputDataset ).algorithm (algorithm ).build ();
221
+ XContentBuilder builder = MediaTypeRegistry .contentBuilder (XContentType .JSON );
222
+ input .toXContent (builder , ToXContent .EMPTY_PARAMS );
223
+ assertNotNull (builder );
224
+ String jsonStr = builder .toString ();
225
+ assertEquals (expectedInputStr , jsonStr );
226
+
227
+ XContentParser parser = XContentType .JSON .xContent ()
228
+ .createParser (new NamedXContentRegistry (new SearchModule (Settings .EMPTY ,
229
+ Collections .emptyList ()).getNamedXContents ()), null , jsonStr );
230
+ parser .nextToken ();
231
+ MLInput parsedInput = MLInput .parse (parser , algorithm .name (), actionType );
232
+ assertEquals (input .getFunctionName (), parsedInput .getFunctionName ());
233
+ assertEquals (input .getInputDataset ().getInputDataType (), parsedInput .getInputDataset ().getInputDataType ());
234
+ verify .accept (parsedInput );
235
+ }
236
+
181
237
@ Test
182
238
public void readInputStream_Success () throws IOException {
183
239
readInputStream (input , parsedInput -> {
0 commit comments