Skip to content

Commit af0fb54

Browse files
committed
ml inference ingest processor support for local models
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent c620964 commit af0fb54

File tree

5 files changed

+415
-150
lines changed

5 files changed

+415
-150
lines changed

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,10 @@ public void loadExtensions(ExtensionLoader loader) {
10011001
public Map<String, org.opensearch.ingest.Processor.Factory> getProcessors(org.opensearch.ingest.Processor.Parameters parameters) {
10021002
Map<String, org.opensearch.ingest.Processor.Factory> processors = new HashMap<>();
10031003
processors
1004-
.put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client));
1004+
.put(
1005+
MLInferenceIngestProcessor.TYPE,
1006+
new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry)
1007+
);
10051008
return Collections.unmodifiableMap(processors);
10061009
}
10071010
}

plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,4 @@ public InferenceProcessorAttributes(
8080
this.maxPredictionTask = maxPredictionTask;
8181
}
8282

83-
}
83+
}

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java

+167-28
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;
88

9+
import java.io.IOException;
910
import java.util.ArrayList;
1011
import java.util.Collection;
1112
import java.util.HashMap;
13+
import java.util.HashSet;
1214
import java.util.List;
1315
import java.util.Map;
1416
import java.util.Set;
@@ -19,11 +21,14 @@
1921
import org.opensearch.client.Client;
2022
import org.opensearch.core.action.ActionListener;
2123
import org.opensearch.core.common.Strings;
24+
import org.opensearch.core.xcontent.NamedXContentRegistry;
2225
import org.opensearch.ingest.AbstractProcessor;
2326
import org.opensearch.ingest.ConfigurationUtils;
2427
import org.opensearch.ingest.IngestDocument;
2528
import org.opensearch.ingest.Processor;
2629
import org.opensearch.ingest.ValueSource;
30+
import org.opensearch.ml.common.FunctionName;
31+
import org.opensearch.ml.common.output.MLOutput;
2732
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2833
import org.opensearch.ml.common.transport.MLTaskResponse;
2934
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -45,17 +50,26 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
4550
public static final String DOT_SYMBOL = ".";
4651
private final InferenceProcessorAttributes inferenceProcessorAttributes;
4752
private final boolean ignoreMissing;
53+
private final String functionName;
54+
private final boolean fullResponsePath;
4855
private final boolean ignoreFailure;
56+
private final boolean override;
57+
private final String modelInput;
4958
private final ScriptService scriptService;
5059
private static Client client;
5160
public static final String TYPE = "ml_inference";
5261
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
5362
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
5463
// prediction outcomes, return the whole prediction outcome by skipping filtering
5564
public static final String IGNORE_MISSING = "ignore_missing";
65+
public static final String OVERRIDE = "override";
66+
public static final String FUNCTION_NAME = "function_name";
67+
public static final String FULL_RESPONSE_PATH = "full_response_path";
68+
public static final String MODEL_INPUT = "model_input";
5669
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
5770
// it can be overwritten using max_prediction_tasks when creating processor
5871
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
72+
private final NamedXContentRegistry xContentRegistry;
5973

6074
private Configuration suppressExceptionConfiguration = Configuration
6175
.builder()
@@ -71,9 +85,14 @@ protected MLInferenceIngestProcessor(
7185
String tag,
7286
String description,
7387
boolean ignoreMissing,
88+
String functionName,
89+
boolean fullResponsePath,
7490
boolean ignoreFailure,
91+
boolean override,
92+
String modelInput,
7593
ScriptService scriptService,
76-
Client client
94+
Client client,
95+
NamedXContentRegistry xContentRegistry
7796
) {
7897
super(tag, description);
7998
this.inferenceProcessorAttributes = new InferenceProcessorAttributes(
@@ -84,9 +103,14 @@ protected MLInferenceIngestProcessor(
84103
maxPredictionTask
85104
);
86105
this.ignoreMissing = ignoreMissing;
106+
this.functionName = functionName;
107+
this.fullResponsePath = fullResponsePath;
87108
this.ignoreFailure = ignoreFailure;
109+
this.override = override;
110+
this.modelInput = modelInput;
88111
this.scriptService = scriptService;
89112
this.client = client;
113+
this.xContentRegistry = xContentRegistry;
90114
}
91115

92116
/**
@@ -162,10 +186,44 @@ private void processPredictions(
162186
List<Map<String, String>> processOutputMap,
163187
int inputMapIndex,
164188
int inputMapSize
165-
) {
189+
) throws IOException {
166190
Map<String, String> modelParameters = new HashMap<>();
191+
Map<String, String> modelConfigs = new HashMap<>();
192+
167193
if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
168194
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
195+
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
196+
}
197+
Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
198+
199+
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
200+
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
201+
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
202+
203+
Map<String, List<String>> newOutputMapping = new HashMap<>();
204+
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
205+
String newDocumentFieldName = entry.getKey();
206+
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
207+
newOutputMapping.put(newDocumentFieldName, dotPathsInArray);
208+
}
209+
210+
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
211+
String newDocumentFieldName = entry.getKey();
212+
List<String> dotPaths = newOutputMapping.get(newDocumentFieldName);
213+
214+
int existingFields = 0;
215+
for (String path : dotPaths) {
216+
if (ingestDocument.hasField(path)) {
217+
existingFields++;
218+
}
219+
}
220+
if (!override && existingFields == dotPaths.size()) {
221+
newOutputMapping.remove(newDocumentFieldName);
222+
}
223+
}
224+
if (newOutputMapping.size() == 0) {
225+
batchPredictionListener.onResponse(null);
226+
return;
169227
}
170228
// when no input mapping is provided, default to read all fields from documents as model input
171229
if (inputMapSize == 0) {
@@ -184,15 +242,30 @@ private void processPredictions(
184242
}
185243
}
186244

187-
ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId());
245+
Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
246+
inputMapKeys.removeAll(modelConfigs.keySet());
247+
248+
Map<String, String> inputMappings = new HashMap<>();
249+
for (String k : inputMapKeys) {
250+
inputMappings.put(k, modelParameters.get(k));
251+
}
252+
ActionRequest request = getRemoteModelInferenceRequest(
253+
xContentRegistry,
254+
modelParameters,
255+
modelConfigs,
256+
inputMappings,
257+
inferenceProcessorAttributes.getModelId(),
258+
functionName,
259+
modelInput
260+
);
188261

189262
client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() {
190263

191264
@Override
192265
public void onResponse(MLTaskResponse mlTaskResponse) {
193-
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
266+
MLOutput mlOutput = mlTaskResponse.getOutput();
194267
if (processOutputMap == null || processOutputMap.isEmpty()) {
195-
appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
268+
appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
196269
} else {
197270
// outMapping serves as a filter to modelTensorOutput, the fields that are not specified
198271
// in the outputMapping will not write to document
@@ -202,14 +275,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
202275
// document field as key, model field as value
203276
String newDocumentFieldName = entry.getKey();
204277
String modelOutputFieldName = entry.getValue();
205-
if (ingestDocument.hasField(newDocumentFieldName)) {
206-
throw new IllegalArgumentException(
207-
"document already has field name "
208-
+ newDocumentFieldName
209-
+ ". Not allow to overwrite the same field name, please check output_map."
210-
);
278+
if (!newOutputMapping.containsKey(newDocumentFieldName)) {
279+
continue;
211280
}
212-
appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
281+
appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
213282
}
214283
}
215284
batchPredictionListener.onResponse(null);
@@ -322,16 +391,16 @@ private void appendFieldValue(
322391

323392
modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing);
324393

325-
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
326-
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
327-
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
328-
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
394+
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName);
329395

330396
if (dotPathsInArray.size() == 1) {
331-
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
332-
TemplateScript.Factory ingestField = ConfigurationUtils
333-
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
334-
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
397+
if (!ingestDocument.hasField(dotPathsInArray.get(0)) || override) {
398+
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
399+
TemplateScript.Factory ingestField = ConfigurationUtils
400+
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
401+
402+
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
403+
}
335404
} else {
336405
if (!(modelOutputValue instanceof List)) {
337406
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
@@ -353,18 +422,73 @@ private void appendFieldValue(
353422
// Iterate over dotPathInArray
354423
for (int i = 0; i < dotPathsInArray.size(); i++) {
355424
String dotPathInArray = dotPathsInArray.get(i);
356-
Object modelOutputValueInArray = modelOutputValueArray.get(i);
357-
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
358-
TemplateScript.Factory ingestField = ConfigurationUtils
359-
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
360-
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
425+
if (!ingestDocument.hasField(dotPathInArray) || override) {
426+
Object modelOutputValueInArray = modelOutputValueArray.get(i);
427+
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
428+
TemplateScript.Factory ingestField = ConfigurationUtils
429+
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
430+
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
431+
}
361432
}
362433
}
363434
} else {
364435
throw new RuntimeException("model inference output cannot be null");
365436
}
366437
}
367438

439+
private void appendFieldValue(
440+
MLOutput mlOutput,
441+
String modelOutputFieldName,
442+
String newDocumentFieldName,
443+
IngestDocument ingestDocument
444+
) {
445+
446+
if (mlOutput == null) {
447+
throw new RuntimeException("model inference output is null");
448+
}
449+
450+
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
451+
452+
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
453+
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
454+
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
455+
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
456+
457+
if (dotPathsInArray.size() == 1) {
458+
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
459+
TemplateScript.Factory ingestField = ConfigurationUtils
460+
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
461+
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
462+
} else {
463+
if (!(modelOutputValue instanceof List)) {
464+
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
465+
}
466+
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
467+
// check length of the prediction array to be the same of the document array
468+
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
469+
throw new RuntimeException(
470+
"the prediction field: "
471+
+ modelOutputFieldName
472+
+ " is an array in size of "
473+
+ modelOutputValueArray.size()
474+
+ " but the document field array from field "
475+
+ newDocumentFieldName
476+
+ " is in size of "
477+
+ dotPathsInArray.size()
478+
);
479+
}
480+
// Iterate over dotPathInArray
481+
for (int i = 0; i < dotPathsInArray.size(); i++) {
482+
String dotPathInArray = dotPathsInArray.get(i);
483+
Object modelOutputValueInArray = modelOutputValueArray.get(i);
484+
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
485+
TemplateScript.Factory ingestField = ConfigurationUtils
486+
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
487+
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
488+
}
489+
}
490+
}
491+
368492
@Override
369493
public String getType() {
370494
return TYPE;
@@ -374,16 +498,18 @@ public static class Factory implements Processor.Factory {
374498

375499
private final ScriptService scriptService;
376500
private final Client client;
501+
private final NamedXContentRegistry xContentRegistry;
377502

378503
/**
379504
* Constructs a new instance of the Factory class.
380505
*
381506
* @param scriptService the ScriptService instance to be used by the Factory
382507
* @param client the Client instance to be used by the Factory
383508
*/
384-
public Factory(ScriptService scriptService, Client client) {
509+
public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
385510
this.scriptService = scriptService;
386511
this.client = client;
512+
this.xContentRegistry = xContentRegistry;
387513
}
388514

389515
/**
@@ -410,6 +536,14 @@ public MLInferenceIngestProcessor create(
410536
int maxPredictionTask = ConfigurationUtils
411537
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
412538
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
539+
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
540+
String functionName = ConfigurationUtils
541+
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
542+
String modelInput = ConfigurationUtils
543+
.readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
544+
boolean defaultValue = !functionName.equals("remote");
545+
boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);
546+
413547
boolean ignoreFailure = ConfigurationUtils
414548
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
415549
// convert model config user input data structure to Map<String, String>
@@ -440,11 +574,16 @@ public MLInferenceIngestProcessor create(
440574
processorTag,
441575
description,
442576
ignoreMissing,
577+
functionName,
578+
fullResponsePath,
443579
ignoreFailure,
580+
override,
581+
modelInput,
444582
scriptService,
445-
client
583+
client,
584+
xContentRegistry
446585
);
447586
}
448587
}
449588

450-
}
589+
}

0 commit comments

Comments
 (0)