Skip to content

Commit 6a250dd

Browse files
add model input validation for local models in ml processor (opensearch-project#2610) (opensearch-project#2615)
* add model input validation for local models in ml processor

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>

---------

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit 2b953cd)
Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 63aeaab commit 6a250dd

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java

+14-4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
7272
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
7373
// it can be overwritten using max_prediction_tasks when creating processor
7474
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
75+
public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
7576
private final NamedXContentRegistry xContentRegistry;
7677

7778
private Configuration suppressExceptionConfiguration = Configuration
@@ -489,10 +490,19 @@ public MLInferenceIngestProcessor create(
489490
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
490491
String functionName = ConfigurationUtils
491492
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
492-
String modelInput = ConfigurationUtils
493-
.readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
494-
boolean defaultValue = !functionName.equalsIgnoreCase("remote");
495-
boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);
493+
494+
String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);
495+
496+
// if model input is not provided for remote models, use default value
497+
if (functionName.equalsIgnoreCase("remote")) {
498+
modelInput = (modelInput != null) ? modelInput : DEFAULT_MODEl_INPUT;
499+
} else if (modelInput == null) {
500+
// if model input is not provided for local models, throw exception since it is mandatory here
501+
throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
502+
}
503+
boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
504+
boolean fullResponsePath = ConfigurationUtils
505+
.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath);
496506

497507
boolean ignoreFailure = ConfigurationUtils
498508
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java

+56
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,60 @@ public void testCreateOptionalFields() throws Exception {
174174
assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
175175
assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
176176
}
177+
178+
public void testLocalModel() throws Exception {
179+
Map<String, Processor.Factory> registry = new HashMap<>();
180+
Map<String, Object> config = new HashMap<>();
181+
config.put(MODEL_ID, "model2");
182+
config.put(FUNCTION_NAME, "text_embedding");
183+
Map<String, Object> model_config = new HashMap<>();
184+
model_config.put("return_number", true);
185+
config.put(MODEL_CONFIG, model_config);
186+
config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }");
187+
List<Map<String, String>> inputMap = new ArrayList<>();
188+
Map<String, String> input = new HashMap<>();
189+
input.put("text_docs", "chunks.*.chunk.text.*.context");
190+
inputMap.add(input);
191+
List<Map<String, String>> outputMap = new ArrayList<>();
192+
Map<String, String> output = new HashMap<>();
193+
output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
194+
outputMap.add(output);
195+
config.put(INPUT_MAP, inputMap);
196+
config.put(OUTPUT_MAP, outputMap);
197+
config.put(MAX_PREDICTION_TASKS, 5);
198+
String processorTag = randomAlphaOfLength(10);
199+
200+
MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config);
201+
assertNotNull(mLInferenceIngestProcessor);
202+
assertEquals(mLInferenceIngestProcessor.getTag(), processorTag);
203+
assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
204+
}
205+
206+
public void testModelInputIsNullForLocalModels() throws Exception {
207+
Map<String, Processor.Factory> registry = new HashMap<>();
208+
Map<String, Object> config = new HashMap<>();
209+
config.put(MODEL_ID, "model2");
210+
config.put(FUNCTION_NAME, "text_embedding");
211+
Map<String, Object> model_config = new HashMap<>();
212+
model_config.put("return_number", true);
213+
config.put(MODEL_CONFIG, model_config);
214+
List<Map<String, String>> inputMap = new ArrayList<>();
215+
Map<String, String> input = new HashMap<>();
216+
input.put("text_docs", "chunks.*.chunk.text.*.context");
217+
inputMap.add(input);
218+
List<Map<String, String>> outputMap = new ArrayList<>();
219+
Map<String, String> output = new HashMap<>();
220+
output.put("chunks.*.chunk.text.*.embedding", "$.inference_results.*.output[2].data");
221+
outputMap.add(output);
222+
config.put(INPUT_MAP, inputMap);
223+
config.put(OUTPUT_MAP, outputMap);
224+
config.put(MAX_PREDICTION_TASKS, 5);
225+
String processorTag = randomAlphaOfLength(10);
226+
227+
try {
228+
factory.create(registry, processorTag, null, config);
229+
} catch (IllegalArgumentException e) {
230+
assertEquals(e.getMessage(), ("Please provide model input when using a local model in ML Inference Processor"));
231+
}
232+
}
177233
}

0 commit comments

Comments
 (0)