@@ -174,4 +174,60 @@ public void testCreateOptionalFields() throws Exception {
174
174
assertEquals (mLInferenceIngestProcessor .getTag (), processorTag );
175
175
assertEquals (mLInferenceIngestProcessor .getType (), MLInferenceIngestProcessor .TYPE );
176
176
}
177
+
178
+ public void testLocalModel () throws Exception {
179
+ Map <String , Processor .Factory > registry = new HashMap <>();
180
+ Map <String , Object > config = new HashMap <>();
181
+ config .put (MODEL_ID , "model2" );
182
+ config .put (FUNCTION_NAME , "text_embedding" );
183
+ Map <String , Object > model_config = new HashMap <>();
184
+ model_config .put ("return_number" , true );
185
+ config .put (MODEL_CONFIG , model_config );
186
+ config .put (MODEL_INPUT , "{ \" text_docs\" : ${ml_inference.text_docs} }" );
187
+ List <Map <String , String >> inputMap = new ArrayList <>();
188
+ Map <String , String > input = new HashMap <>();
189
+ input .put ("text_docs" , "chunks.*.chunk.text.*.context" );
190
+ inputMap .add (input );
191
+ List <Map <String , String >> outputMap = new ArrayList <>();
192
+ Map <String , String > output = new HashMap <>();
193
+ output .put ("chunks.*.chunk.text.*.embedding" , "$.inference_results.*.output[2].data" );
194
+ outputMap .add (output );
195
+ config .put (INPUT_MAP , inputMap );
196
+ config .put (OUTPUT_MAP , outputMap );
197
+ config .put (MAX_PREDICTION_TASKS , 5 );
198
+ String processorTag = randomAlphaOfLength (10 );
199
+
200
+ MLInferenceIngestProcessor mLInferenceIngestProcessor = factory .create (registry , processorTag , null , config );
201
+ assertNotNull (mLInferenceIngestProcessor );
202
+ assertEquals (mLInferenceIngestProcessor .getTag (), processorTag );
203
+ assertEquals (mLInferenceIngestProcessor .getType (), MLInferenceIngestProcessor .TYPE );
204
+ }
205
+
206
+ public void testModelInputIsNullForLocalModels () throws Exception {
207
+ Map <String , Processor .Factory > registry = new HashMap <>();
208
+ Map <String , Object > config = new HashMap <>();
209
+ config .put (MODEL_ID , "model2" );
210
+ config .put (FUNCTION_NAME , "text_embedding" );
211
+ Map <String , Object > model_config = new HashMap <>();
212
+ model_config .put ("return_number" , true );
213
+ config .put (MODEL_CONFIG , model_config );
214
+ List <Map <String , String >> inputMap = new ArrayList <>();
215
+ Map <String , String > input = new HashMap <>();
216
+ input .put ("text_docs" , "chunks.*.chunk.text.*.context" );
217
+ inputMap .add (input );
218
+ List <Map <String , String >> outputMap = new ArrayList <>();
219
+ Map <String , String > output = new HashMap <>();
220
+ output .put ("chunks.*.chunk.text.*.embedding" , "$.inference_results.*.output[2].data" );
221
+ outputMap .add (output );
222
+ config .put (INPUT_MAP , inputMap );
223
+ config .put (OUTPUT_MAP , outputMap );
224
+ config .put (MAX_PREDICTION_TASKS , 5 );
225
+ String processorTag = randomAlphaOfLength (10 );
226
+
227
+ try {
228
+ factory .create (registry , processorTag , null , config );
229
+ } catch (IllegalArgumentException e ) {
230
+ assertEquals (e .getMessage (), ("Please provide model input when using a local model in ML Inference Processor" ));
231
+ }
232
+ }
177
233
}
0 commit comments