
Commit 54774c1

junqiu-lei authored and akolarkunnu committed

Support sentence highlighting QA model (opensearch-project#3600)

Signed-off-by: Junqiu Lei <junqiu@amazon.com>

1 parent e2c19e3 · commit 54774c1

13 files changed: +1723 -10 lines changed

common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java (+2 -2)

@@ -27,10 +27,10 @@ public class QuestionAnsweringInputDataSet extends MLInputDataset {
     @Builder(toBuilder = true)
     public QuestionAnsweringInputDataSet(String question, String context) {
         super(MLInputDataType.QUESTION_ANSWERING);
-        if (question == null) {
+        if (question == null || question.isEmpty()) {
             throw new IllegalArgumentException("Question is not provided");
         }
-        if (context == null) {
+        if (context == null || context.isEmpty()) {
             throw new IllegalArgumentException("Context is not provided");
         }
         this.question = question;
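
For reference, a minimal usage sketch of the tightened validation. The builder comes from the @Builder annotation already on the class; the empty-string calls are illustrative, not from this commit:

import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;

// Valid construction via the Lombok builder:
QuestionAnsweringInputDataSet dataSet = QuestionAnsweringInputDataSet.builder()
    .question("How is the weather?")
    .context("The weather is nice, it is beautiful day.")
    .build();

// With this change, empty strings are rejected the same way null was before:
// new QuestionAnsweringInputDataSet("", "some context");  // IllegalArgumentException: Question is not provided
// new QuestionAnsweringInputDataSet("question", "");      // IllegalArgumentException: Context is not provided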
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QAConstants.java (new file, +34)

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.engine.algorithms.question_answering;

/**
 * Constants for Question Answering models and related functionality.
 */
public final class QAConstants {
    // Model types
    public static final String SENTENCE_HIGHLIGHTING_TYPE = "sentence_highlighting";

    // Output field names
    public static final String FIELD_HIGHLIGHTS = "highlights";
    public static final String FIELD_ERROR = "error";
    public static final String FIELD_TEXT = "text";
    public static final String FIELD_POSITION = "position";
    public static final String FIELD_START = "start";
    public static final String FIELD_END = "end";

    // Context keys
    public static final String KEY_SENTENCES = "sentences";

    // Model input names
    public static final String INPUT_IDS = "input_ids";
    public static final String ATTENTION_MASK = "attention_mask";
    public static final String TOKEN_TYPE_IDS = "token_type_ids";

    // Default values for warm-up
    public static final String DEFAULT_WARMUP_QUESTION = "How is the weather?";
    public static final String DEFAULT_WARMUP_CONTEXT = "The weather is nice, it is beautiful day. The sun is shining. The sky is blue.";
}
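
For orientation, the FIELD_* constants name the keys of a single highlight entry as later assembled in SentenceHighlightingQATranslator; a short sketch of that shape (the values are invented for illustration):

import java.util.HashMap;
import java.util.Map;

Map<String, Object> highlight = new HashMap<>();
highlight.put(QAConstants.FIELD_TEXT, "The sun is shining.");  // sentence text
highlight.put(QAConstants.FIELD_POSITION, 1);                  // sentence index within the context
highlight.put(QAConstants.FIELD_START, 43);                    // character offset where the sentence starts
highlight.put(QAConstants.FIELD_END, 62);                      // character offset where the sentence ends
// The translator collects entries like this under the FIELD_HIGHLIGHTS ("highlights") key.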

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModel.java (+34 -7)

@@ -5,7 +5,9 @@

 package org.opensearch.ml.engine.algorithms.question_answering;

-import static org.opensearch.ml.engine.ModelHelper.*;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_CONTEXT;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_QUESTION;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.SENTENCE_HIGHLIGHTING_TYPE;

 import java.util.ArrayList;
 import java.util.List;

@@ -28,22 +30,45 @@
 import ai.djl.translate.TranslatorFactory;
 import lombok.extern.log4j.Log4j2;

+/**
+ * Question answering model implementation that supports both standard QA and
+ * sentence highlighting.
+ */
 @Log4j2
 @Function(FunctionName.QUESTION_ANSWERING)
 public class QuestionAnsweringModel extends DLModel {

     @Override
     public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
-        String question = "How is the weather?";
-        String context = "The weather is nice, it is beautiful day.";
+        if (predictor == null) {
+            throw new IllegalArgumentException("predictor is null");
+        }
+        if (modelId == null) {
+            throw new IllegalArgumentException("model id is null");
+        }
+
+        // Create input for the predictor
         Input input = new Input();
-        input.add(question);
-        input.add(context);
+        input.add(DEFAULT_WARMUP_QUESTION);
+        input.add(DEFAULT_WARMUP_CONTEXT);

-        // First request takes longer time. Predict once to warm up model.
+        // Run prediction to warm up the model
         predictor.predict(input);
     }

+    /**
+     * Checks if the model is configured for sentence highlighting.
+     *
+     * @param modelConfig The model configuration
+     * @return true if the model is configured for sentence highlighting, false otherwise
+     */
+    private boolean isSentenceHighlightingType(MLModelConfig modelConfig) {
+        if (modelConfig != null) {
+            return SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
+        }
+        return false;
+    }
+
     @Override
     public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
         MLInputDataset inputDataSet = mlInput.getInputDataset();

@@ -62,12 +87,14 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {

     @Override
     public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
+        if (isSentenceHighlightingType(modelConfig)) {
+            return SentenceHighlightingQATranslator.createDefault();
+        }
         return new QuestionAnsweringTranslator();
     }

     @Override
     public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
         return null;
     }
-
 }
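
Translator selection hinges on the model type string, compared case-insensitively against SENTENCE_HIGHLIGHTING_TYPE. A standalone sketch of that check (the helper class below is illustrative, not part of this commit):

import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.SENTENCE_HIGHLIGHTING_TYPE;

final class TranslatorSelectionSketch {
    // Mirrors isSentenceHighlightingType: null-safe, case-insensitive match.
    static boolean usesSentenceHighlighting(String modelType) {
        return modelType != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelType);
    }

    public static void main(String[] args) {
        System.out.println(usesSentenceHighlighting("sentence_highlighting")); // true  -> SentenceHighlightingQATranslator
        System.out.println(usesSentenceHighlighting("SENTENCE_HIGHLIGHTING")); // true  -> comparison is case-insensitive
        System.out.println(usesSentenceHighlighting(null));                    // false -> QuestionAnsweringTranslator
    }
}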
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/SentenceHighlightingQATranslator.java (new file, +241)

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.engine.algorithms.question_answering;

import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.ATTENTION_MASK;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_END;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_ERROR;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_HIGHLIGHTS;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_POSITION;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_START;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_TEXT;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.INPUT_IDS;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.KEY_SENTENCES;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.TOKEN_TYPE_IDS;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.jetbrains.annotations.NotNull;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.question_answering.sentence.DefaultSentenceSegmenter;
import org.opensearch.ml.engine.algorithms.question_answering.sentence.Sentence;
import org.opensearch.ml.engine.algorithms.question_answering.sentence.SentenceSegmenter;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ServingTranslator;
import ai.djl.translate.TranslatorContext;
import lombok.Builder;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;

/**
 * Translator for sentence highlighting question answering model.
 *
 * <p>Expected model output format:
 * The model should output binary predictions for each sentence, where:
 * - 1 indicates a relevant sentence (one that answers the question)
 * - 0 indicates a non-relevant sentence
 *
 * This format can be customized by overriding the isRelevantPrediction method.
 */
@Log4j2
@Getter
@Builder
public class SentenceHighlightingQATranslator implements ServingTranslator {
    /**
     * Default relevance value that indicates a sentence is relevant.
     * By default, 1 means relevant and 0 means not relevant.
     * The method specifically checks for equality with RELEVANT_VALUE (1) to determine relevance.
     */
    private static final long RELEVANT_VALUE = 1L;

    /**
     * Determines if a prediction value indicates a relevant sentence.
     *
     * @param predictionValue The prediction value from the model
     * @return true if the prediction indicates a relevant sentence, false otherwise
     */
    protected boolean isRelevantPrediction(long predictionValue) {
        return predictionValue == RELEVANT_VALUE;
    }

    @Builder.Default
    private final SentenceSegmenter segmenter = new DefaultSentenceSegmenter();

    private HuggingFaceTokenizer tokenizer;

    /**
     * Creates a new translator with default settings.
     *
     * @return A new SentenceHighlightingQATranslator instance
     */
    public static SentenceHighlightingQATranslator createDefault() {
        return builder().build();
    }

    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        Path path = ctx.getModel().getModelPath();
        tokenizer = HuggingFaceTokenizer.builder().optPadding(true).optTokenizerPath(path.resolve("tokenizer.json")).build();
    }

    @Override
    public void setArguments(Map<String, ?> arguments) {
        // No arguments needed for this translator
    }

    @Override
    public NDList processInput(TranslatorContext ctx, Input input) {
        try {
            NDManager manager = ctx.getNDManager();
            String question = input.getAsString(0);
            String context = input.getAsString(1);

            List<Sentence> sentences = segmenter.segment(context);
            ctx.setAttachment(KEY_SENTENCES, sentences);
            ctx.setAttachment(MLInput.QUESTION_FIELD, question);

            Encoding encodings = tokenizer.encode(question, context);

            NDArray indicesArray = manager.create(encodings.getIds());
            indicesArray.setName(INPUT_IDS);

            NDArray attentionMaskArray = manager.create(encodings.getAttentionMask());
            if (attentionMaskArray.isEmpty()) {
                throw new IllegalArgumentException("Attention mask is empty in sentence highlighting QA model input");
            }
            attentionMaskArray.setName(ATTENTION_MASK);

            NDArray tokenTypeIdsArray = manager.create(encodings.getTypeIds());
            tokenTypeIdsArray.setName(TOKEN_TYPE_IDS);

            return new NDList(indicesArray, attentionMaskArray, tokenTypeIdsArray);
        } catch (Exception e) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Error processing input: %s", e.getMessage()), e);
        }
    }

    @Override
    public Output processOutput(TranslatorContext ctx, NDList list) {
        try {
            Output output = new Output(200, "OK");

            @SuppressWarnings("unchecked")
            List<Sentence> sentences = (List<Sentence>) ctx.getAttachment(KEY_SENTENCES);
            boolean[] isRelevant = new boolean[sentences.size()];

            // Check if we have valid output from the model
            if (list == null || list.isEmpty()) {
                return createErrorOutput("Model returned empty or null output");
            }

            // The model returns a tensor where 1 means relevant, 0 means not relevant
            NDArray binaryPreds = list.getFirst();

            // Validate prediction shape
            if (binaryPreds.getShape().dimension() == 0 || binaryPreds.getShape().get(0) == 0) {
                return createErrorOutput(String.format("Invalid prediction shape: %s", binaryPreds.getShape()));
            }

            // Convert to boolean array
            for (int i = 0; i < Math.min(sentences.size(), binaryPreds.getShape().get(0)); i++) {
                try {
                    long predValue = binaryPreds.getLong(i);
                    isRelevant[i] = isRelevantPrediction(predValue);
                } catch (Exception e) {
                    log.warn(String.format("Error processing prediction for sentence %d: %s", i, e.getMessage()));
                    isRelevant[i] = false;
                }
            }

            // Create sentence data objects
            List<SentenceData> sentenceDataList = new ArrayList<>();
            for (int i = 0; i < sentences.size(); i++) {
                Sentence sentence = sentences.get(i);
                boolean relevant = isRelevant[i];
                sentenceDataList.add(new SentenceData(sentence.getText(), relevant, sentence.getPosition()));
            }

            // Prepare output list for relevant sentences
            List<Map<String, Object>> relevantSentenceDetails = getRelevantSentenceDetails(sentenceDataList, sentences);
            log.info("Relevant sentence details: {}", relevantSentenceDetails);

            // Create a map to hold our data
            Map<String, Object> dataMap = new HashMap<>();
            dataMap.put(FIELD_HIGHLIGHTS, relevantSentenceDetails);

            // Create the ModelTensor using the builder pattern
            ModelTensor tensor = ModelTensor.builder().name(FIELD_HIGHLIGHTS).dataAsMap(dataMap).build();

            // Wrap in ModelTensors and convert to bytes
            ModelTensors modelTensorOutput = new ModelTensors(List.of(tensor));
            output.add(modelTensorOutput.toBytes());
            return output;
        } catch (Exception e) {
            return createErrorOutput(String.format("Error processing output: %s", e.getMessage()));
        }
    }

    private static @NotNull List<Map<String, Object>> getRelevantSentenceDetails(
        List<SentenceData> sentenceDataList,
        List<Sentence> sentences
    ) {
        List<Map<String, Object>> relevantSentenceDetails = new ArrayList<>();

        for (SentenceData data : sentenceDataList) {
            if (data.isRelevant) {
                // Find the corresponding sentence to get start and end indices
                for (Sentence sentence : sentences) {
                    if (sentence.getPosition() == data.position) {
                        Map<String, Object> sentenceDetail = new HashMap<>();
                        sentenceDetail.put(FIELD_TEXT, data.text);
                        sentenceDetail.put(FIELD_POSITION, data.position);
                        sentenceDetail.put(FIELD_START, sentence.getStartIndex());
                        sentenceDetail.put(FIELD_END, sentence.getEndIndex());
                        relevantSentenceDetails.add(sentenceDetail);
                        break;
                    }
                }
            }
        }
        return relevantSentenceDetails;
    }

    private Output createErrorOutput(String errorMessage) {
        Output output = new Output(400, "Bad Request");

        // Create a map to hold our error data
        Map<String, Object> errorData = new HashMap<>();
        errorData.put(FIELD_ERROR, errorMessage);
        errorData.put(FIELD_HIGHLIGHTS, new ArrayList<>());

        // Create the ModelTensor using the builder pattern
        ModelTensor tensor = ModelTensor.builder().name(FIELD_ERROR).dataAsMap(errorData).build();

        // Wrap in ModelTensors and convert to bytes
        ModelTensors modelTensorOutput = new ModelTensors(List.of(tensor));
        output.add(modelTensorOutput.toBytes());
        return output;
    }

    // Helper class to store sentence data
    private record SentenceData(String text, boolean isRelevant, int position) {
    }
}
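
On the consumer side, a hedged sketch of how the highlights could be read back from the ModelTensor that processOutput produces. It assumes a `tensor` variable already extracted from the prediction response, and that ModelTensor exposes getDataAsMap() for the dataAsMap set above:

@SuppressWarnings("unchecked")
List<Map<String, Object>> highlights =
    (List<Map<String, Object>>) tensor.getDataAsMap().get(QAConstants.FIELD_HIGHLIGHTS);

for (Map<String, Object> h : highlights) {
    System.out.printf(
        "sentence %s [%s, %s): %s%n",
        h.get(QAConstants.FIELD_POSITION),
        h.get(QAConstants.FIELD_START),
        h.get(QAConstants.FIELD_END),
        h.get(QAConstants.FIELD_TEXT)
    );
}
// An error response instead carries a tensor named "error" with an error message
// and an empty "highlights" list (see createErrorOutput above).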
