Skip to content

Commit 1d36acc

Browse files
ml inference ingest processor support for local models (opensearch-project#2508) (opensearch-project#2532)
* ml inference ingest processor support for local models Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> (cherry picked from commit 7cd5291) Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 7daf457 commit 1d36acc

File tree

6 files changed

+1265
-128
lines changed

6 files changed

+1265
-128
lines changed

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,10 @@ public void loadExtensions(ExtensionLoader loader) {
10071007
public Map<String, org.opensearch.ingest.Processor.Factory> getProcessors(org.opensearch.ingest.Processor.Parameters parameters) {
10081008
Map<String, org.opensearch.ingest.Processor.Factory> processors = new HashMap<>();
10091009
processors
1010-
.put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client));
1010+
.put(
1011+
MLInferenceIngestProcessor.TYPE,
1012+
new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry)
1013+
);
10111014
return Collections.unmodifiableMap(processors);
10121015
}
10131016
}

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java

+145-56
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,31 @@
66

77
import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;
88

9+
import java.io.IOException;
910
import java.util.ArrayList;
1011
import java.util.Collection;
1112
import java.util.HashMap;
13+
import java.util.HashSet;
1214
import java.util.List;
1315
import java.util.Map;
1416
import java.util.Set;
1517
import java.util.function.BiConsumer;
1618

19+
import org.apache.logging.log4j.LogManager;
20+
import org.apache.logging.log4j.Logger;
1721
import org.opensearch.action.ActionRequest;
1822
import org.opensearch.action.support.GroupedActionListener;
1923
import org.opensearch.client.Client;
2024
import org.opensearch.core.action.ActionListener;
2125
import org.opensearch.core.common.Strings;
26+
import org.opensearch.core.xcontent.NamedXContentRegistry;
2227
import org.opensearch.ingest.AbstractProcessor;
2328
import org.opensearch.ingest.ConfigurationUtils;
2429
import org.opensearch.ingest.IngestDocument;
2530
import org.opensearch.ingest.Processor;
2631
import org.opensearch.ingest.ValueSource;
27-
import org.opensearch.ml.common.output.model.ModelTensorOutput;
32+
import org.opensearch.ml.common.FunctionName;
33+
import org.opensearch.ml.common.output.MLOutput;
2834
import org.opensearch.ml.common.transport.MLTaskResponse;
2935
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3036
import org.opensearch.ml.common.utils.StringUtils;
@@ -42,20 +48,31 @@
4248
*/
4349
public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor {
4450

51+
private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class);
52+
4553
public static final String DOT_SYMBOL = ".";
4654
private final InferenceProcessorAttributes inferenceProcessorAttributes;
4755
private final boolean ignoreMissing;
56+
private final String functionName;
57+
private final boolean fullResponsePath;
4858
private final boolean ignoreFailure;
59+
private final boolean override;
60+
private final String modelInput;
4961
private final ScriptService scriptService;
5062
private static Client client;
5163
public static final String TYPE = "ml_inference";
5264
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
5365
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
5466
// prediction outcomes, return the whole prediction outcome by skipping filtering
5567
public static final String IGNORE_MISSING = "ignore_missing";
68+
public static final String OVERRIDE = "override";
69+
public static final String FUNCTION_NAME = "function_name";
70+
public static final String FULL_RESPONSE_PATH = "full_response_path";
71+
public static final String MODEL_INPUT = "model_input";
5672
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
5773
// it can be overwritten using max_prediction_tasks when creating processor
5874
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
75+
private final NamedXContentRegistry xContentRegistry;
5976

6077
private Configuration suppressExceptionConfiguration = Configuration
6178
.builder()
@@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor(
7188
String tag,
7289
String description,
7390
boolean ignoreMissing,
91+
String functionName,
92+
boolean fullResponsePath,
7493
boolean ignoreFailure,
94+
boolean override,
95+
String modelInput,
7596
ScriptService scriptService,
76-
Client client
97+
Client client,
98+
NamedXContentRegistry xContentRegistry
7799
) {
78100
super(tag, description);
79101
this.inferenceProcessorAttributes = new InferenceProcessorAttributes(
@@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor(
84106
maxPredictionTask
85107
);
86108
this.ignoreMissing = ignoreMissing;
109+
this.functionName = functionName;
110+
this.fullResponsePath = fullResponsePath;
87111
this.ignoreFailure = ignoreFailure;
112+
this.override = override;
113+
this.modelInput = modelInput;
88114
this.scriptService = scriptService;
89115
this.client = client;
116+
this.xContentRegistry = xContentRegistry;
90117
}
91118

92119
/**
@@ -162,10 +189,48 @@ private void processPredictions(
162189
List<Map<String, String>> processOutputMap,
163190
int inputMapIndex,
164191
int inputMapSize
165-
) {
192+
) throws IOException {
166193
Map<String, String> modelParameters = new HashMap<>();
194+
Map<String, String> modelConfigs = new HashMap<>();
195+
167196
if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
168197
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
198+
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
199+
}
200+
201+
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
202+
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
203+
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
204+
205+
Map<String, List<String>> newOutputMapping = new HashMap<>();
206+
if (processOutputMap != null) {
207+
208+
Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
209+
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
210+
String newDocumentFieldName = entry.getKey();
211+
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
212+
newOutputMapping.put(newDocumentFieldName, dotPathsInArray);
213+
}
214+
215+
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
216+
String newDocumentFieldName = entry.getKey();
217+
List<String> dotPaths = newOutputMapping.get(newDocumentFieldName);
218+
219+
int existingFields = 0;
220+
for (String path : dotPaths) {
221+
if (ingestDocument.hasField(path)) {
222+
existingFields++;
223+
}
224+
}
225+
if (!override && existingFields == dotPaths.size()) {
226+
logger.debug("{} already exists in the ingest document. Removing it from output mapping", newDocumentFieldName);
227+
newOutputMapping.remove(newDocumentFieldName);
228+
}
229+
}
230+
if (newOutputMapping.size() == 0) {
231+
batchPredictionListener.onResponse(null);
232+
return;
233+
}
169234
}
170235
// when no input mapping is provided, default to read all fields from documents as model input
171236
if (inputMapSize == 0) {
@@ -184,15 +249,30 @@ private void processPredictions(
184249
}
185250
}
186251

187-
ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId());
252+
Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
253+
inputMapKeys.removeAll(modelConfigs.keySet());
254+
255+
Map<String, String> inputMappings = new HashMap<>();
256+
for (String k : inputMapKeys) {
257+
inputMappings.put(k, modelParameters.get(k));
258+
}
259+
ActionRequest request = getMLModelInferenceRequest(
260+
xContentRegistry,
261+
modelParameters,
262+
modelConfigs,
263+
inputMappings,
264+
inferenceProcessorAttributes.getModelId(),
265+
functionName,
266+
modelInput
267+
);
188268

189269
client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() {
190270

191271
@Override
192272
public void onResponse(MLTaskResponse mlTaskResponse) {
193-
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
273+
MLOutput mlOutput = mlTaskResponse.getOutput();
194274
if (processOutputMap == null || processOutputMap.isEmpty()) {
195-
appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
275+
appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
196276
} else {
197277
// outMapping serves as a filter to modelTensorOutput, the fields that are not specified
198278
// in the outputMapping will not write to document
@@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
202282
// document field as key, model field as value
203283
String newDocumentFieldName = entry.getKey();
204284
String modelOutputFieldName = entry.getValue();
205-
if (ingestDocument.hasField(newDocumentFieldName)) {
206-
throw new IllegalArgumentException(
207-
"document already has field name "
208-
+ newDocumentFieldName
209-
+ ". Not allow to overwrite the same field name, please check output_map."
210-
);
285+
if (!newOutputMapping.containsKey(newDocumentFieldName)) {
286+
continue;
211287
}
212-
appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
288+
appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
213289
}
214290
}
215291
batchPredictionListener.onResponse(null);
@@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN
305381
/**
306382
* Appends the model output value to the specified field in the IngestDocument without modifying the source.
307383
*
308-
* @param modelTensorOutput the ModelTensorOutput containing the model output
384+
* @param mlOutput the MLOutput containing the model output
309385
* @param modelOutputFieldName the name of the field in the model output
310386
* @param newDocumentFieldName the name of the field in the IngestDocument to append the value to
311387
* @param ingestDocument the IngestDocument to append the value to
312388
*/
313389
private void appendFieldValue(
314-
ModelTensorOutput modelTensorOutput,
390+
MLOutput mlOutput,
315391
String modelOutputFieldName,
316392
String newDocumentFieldName,
317393
IngestDocument ingestDocument
318394
) {
319-
Object modelOutputValue = null;
320395

321-
if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) {
396+
if (mlOutput == null) {
397+
throw new RuntimeException("model inference output is null");
398+
}
322399

323-
modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing);
400+
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
324401

325-
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
326-
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
327-
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
328-
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
402+
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
403+
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
404+
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
405+
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
329406

330-
if (dotPathsInArray.size() == 1) {
331-
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
407+
if (dotPathsInArray.size() == 1) {
408+
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
409+
TemplateScript.Factory ingestField = ConfigurationUtils
410+
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
411+
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
412+
} else {
413+
if (!(modelOutputValue instanceof List)) {
414+
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
415+
}
416+
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
417+
// check length of the prediction array to be the same of the document array
418+
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
419+
throw new RuntimeException(
420+
"the prediction field: "
421+
+ modelOutputFieldName
422+
+ " is an array in size of "
423+
+ modelOutputValueArray.size()
424+
+ " but the document field array from field "
425+
+ newDocumentFieldName
426+
+ " is in size of "
427+
+ dotPathsInArray.size()
428+
);
429+
}
430+
// Iterate over dotPathInArray
431+
for (int i = 0; i < dotPathsInArray.size(); i++) {
432+
String dotPathInArray = dotPathsInArray.get(i);
433+
Object modelOutputValueInArray = modelOutputValueArray.get(i);
434+
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
332435
TemplateScript.Factory ingestField = ConfigurationUtils
333-
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
436+
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
334437
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
335-
} else {
336-
if (!(modelOutputValue instanceof List)) {
337-
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
338-
}
339-
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
340-
// check length of the prediction array to be the same of the document array
341-
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
342-
throw new RuntimeException(
343-
"the prediction field: "
344-
+ modelOutputFieldName
345-
+ " is an array in size of "
346-
+ modelOutputValueArray.size()
347-
+ " but the document field array from field "
348-
+ newDocumentFieldName
349-
+ " is in size of "
350-
+ dotPathsInArray.size()
351-
);
352-
}
353-
// Iterate over dotPathInArray
354-
for (int i = 0; i < dotPathsInArray.size(); i++) {
355-
String dotPathInArray = dotPathsInArray.get(i);
356-
Object modelOutputValueInArray = modelOutputValueArray.get(i);
357-
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
358-
TemplateScript.Factory ingestField = ConfigurationUtils
359-
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
360-
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
361-
}
362438
}
363-
} else {
364-
throw new RuntimeException("model inference output cannot be null");
365439
}
366440
}
367441

@@ -374,16 +448,18 @@ public static class Factory implements Processor.Factory {
374448

375449
private final ScriptService scriptService;
376450
private final Client client;
451+
private final NamedXContentRegistry xContentRegistry;
377452

378453
/**
379454
* Constructs a new instance of the Factory class.
380455
*
381456
* @param scriptService the ScriptService instance to be used by the Factory
382457
* @param client the Client instance to be used by the Factory
383458
*/
384-
public Factory(ScriptService scriptService, Client client) {
459+
public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
385460
this.scriptService = scriptService;
386461
this.client = client;
462+
this.xContentRegistry = xContentRegistry;
387463
}
388464

389465
/**
@@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create(
410486
int maxPredictionTask = ConfigurationUtils
411487
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
412488
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
489+
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
490+
String functionName = ConfigurationUtils
491+
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
492+
String modelInput = ConfigurationUtils
493+
.readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
494+
boolean defaultValue = !functionName.equalsIgnoreCase("remote");
495+
boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);
496+
413497
boolean ignoreFailure = ConfigurationUtils
414498
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
415499
// convert model config user input data structure to Map<String, String>
@@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create(
440524
processorTag,
441525
description,
442526
ignoreMissing,
527+
functionName,
528+
fullResponsePath,
443529
ignoreFailure,
530+
override,
531+
modelInput,
444532
scriptService,
445-
client
533+
client,
534+
xContentRegistry
446535
);
447536
}
448537
}

0 commit comments

Comments
 (0)