
Commit b7a0d78

Authored by Pavan Yekbote
Bug Fix: Fix for rag processor throwing NPE when optional parameters are not provided (opensearch-project#3057)
* fix (rag npe): optional and empty fields are handled appropriately
* fix: test cases
* fix: format violations
* tests: adding empty params test case
* fix: remove wildcard import

Signed-off-by: Pavan Yekbote <mail2pavanyekbote@gmail.com>
1 parent 4850254 commit b7a0d78

File tree

3 files changed: +175 -50 lines

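In short, the fix makes GenerativeQAParameters emit only the fields that were actually set and replaces the ObjectParser with a hand-rolled parse method, so a request that supplies only the required llm_question no longer hits a NullPointerException. Below is a minimal sketch of the resulting behavior; it is not part of the commit, it assumes the search-processors module is on the classpath, and the class name and question text are purely illustrative.

// Sketch only: exercises the Lombok builder added by this commit and the
// null-safe toXContent. The class/file name and question string are made up.
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;

public class RagNpeFixSketch {
    public static void main(String[] args) throws Exception {
        // Only the required parameter is set; every optional field is left out.
        GenerativeQAParameters params = GenerativeQAParameters.builder().llmQuestion("what is a search pipeline?").build();

        // Before this commit, serializing with unset optional fields could throw an NPE;
        // now unset String and list fields are simply omitted from the generated object.
        XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
        params.toXContent(builder, ToXContent.EMPTY_PARAMS);
        System.out.println(BytesReference.bytes(builder).utf8ToString());
    }
}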

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java (+128 -42)
@@ -17,25 +17,25 @@
  */
 package org.opensearch.searchpipelines.questionanswering.generative.ext;
 
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 
-import org.opensearch.core.ParseField;
 import org.opensearch.core.common.Strings;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.common.io.stream.Writeable;
-import org.opensearch.core.xcontent.ObjectParser;
 import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
-import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
 
 import com.google.common.base.Preconditions;
 
+import lombok.Builder;
 import lombok.Getter;
 import lombok.NoArgsConstructor;
 import lombok.Setter;
@@ -48,60 +48,44 @@
 @NoArgsConstructor
 public class GenerativeQAParameters implements Writeable, ToXContentObject {
 
-    private static final ObjectParser<GenerativeQAParameters, Void> PARSER;
-
     // Optional parameter; if provided, conversational memory will be used for RAG
     // and the current interaction will be saved in the conversation referenced by this id.
-    private static final ParseField CONVERSATION_ID = new ParseField("memory_id");
+    private static final String CONVERSATION_ID = "memory_id";
 
     // Optional parameter; if an LLM model is not set at the search pipeline level, one must be
     // provided at the search request level.
-    private static final ParseField LLM_MODEL = new ParseField("llm_model");
+    private static final String LLM_MODEL = "llm_model";
 
     // Required parameter; this is sent to LLMs as part of the user prompt.
     // TODO support question rewriting when chat history is not used (conversation_id is not provided).
-    private static final ParseField LLM_QUESTION = new ParseField("llm_question");
+    private static final String LLM_QUESTION = "llm_question";
 
     // Optional parameter; this parameter controls the number of search results ("contexts") to
     // include in the user prompt.
-    private static final ParseField CONTEXT_SIZE = new ParseField("context_size");
+    private static final String CONTEXT_SIZE = "context_size";
 
     // Optional parameter; this parameter controls the number of the interactions to include
     // in the user prompt.
-    private static final ParseField INTERACTION_SIZE = new ParseField("message_size");
+    private static final String INTERACTION_SIZE = "message_size";
 
     // Optional parameter; this parameter controls how long the search pipeline waits for a response
     // from a remote inference endpoint before timing out the request.
-    private static final ParseField TIMEOUT = new ParseField("timeout");
+    private static final String TIMEOUT = "timeout";
 
     // Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
-    private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
+    private static final String SYSTEM_PROMPT = "system_prompt";
 
     // Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
-    private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
+    private static final String USER_INSTRUCTIONS = "user_instructions";
 
     // Optional parameter; this parameter indicates the name of the field in the LLM response
     // that contains the chat completion text, i.e. "answer".
-    private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
+    private static final String LLM_RESPONSE_FIELD = "llm_response_field";
 
-    private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");
+    private static final String LLM_MESSAGES_FIELD = "llm_messages";
 
     public static final int SIZE_NULL_VALUE = -1;
 
-    static {
-        PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
-        PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
-        PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
-        PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
-        PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
-        PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
-        PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
-        PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
-        PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
-        PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
-        PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
-    }
-
     @Setter
     @Getter
     private String conversationId;
@@ -167,6 +151,7 @@ public GenerativeQAParameters(
         );
     }
 
+    @Builder(toBuilder = true)
     public GenerativeQAParameters(
         String conversationId,
         String llmModel,
@@ -184,7 +169,7 @@ public GenerativeQAParameters(
 
         // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
         // for question rewriting.
-        Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
+        Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
         this.llmQuestion = llmQuestion;
         this.systemPrompt = systemPrompt;
         this.userInstructions = userInstructions;
@@ -212,17 +197,49 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
-        return xContentBuilder
-            .field(CONVERSATION_ID.getPreferredName(), this.conversationId)
-            .field(LLM_MODEL.getPreferredName(), this.llmModel)
-            .field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
-            .field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
-            .field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
-            .field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
-            .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
-            .field(TIMEOUT.getPreferredName(), this.timeout)
-            .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
-            .field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
+        xContentBuilder.startObject();
+        if (this.conversationId != null) {
+            xContentBuilder.field(CONVERSATION_ID, this.conversationId);
+        }
+
+        if (this.llmModel != null) {
+            xContentBuilder.field(LLM_MODEL, this.llmModel);
+        }
+
+        if (this.llmQuestion != null) {
+            xContentBuilder.field(LLM_QUESTION, this.llmQuestion);
+        }
+
+        if (this.systemPrompt != null) {
+            xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt);
+        }
+
+        if (this.userInstructions != null) {
+            xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions);
+        }
+
+        if (this.contextSize != null) {
+            xContentBuilder.field(CONTEXT_SIZE, this.contextSize);
+        }
+
+        if (this.interactionSize != null) {
+            xContentBuilder.field(INTERACTION_SIZE, this.interactionSize);
+        }
+
+        if (this.timeout != null) {
+            xContentBuilder.field(TIMEOUT, this.timeout);
+        }
+
+        if (this.llmResponseField != null) {
+            xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField);
+        }
+
+        if (this.llmMessages != null && !this.llmMessages.isEmpty()) {
+            xContentBuilder.field(LLM_MESSAGES_FIELD, this.llmMessages);
+        }
+
+        xContentBuilder.endObject();
+        return xContentBuilder;
     }
 
     @Override
@@ -242,7 +259,76 @@ public void writeTo(StreamOutput out) throws IOException {
     }
 
     public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
-        return PARSER.parse(parser, null);
+        String conversationId = null;
+        String llmModel = null;
+        String llmQuestion = null;
+        String systemPrompt = null;
+        String userInstructions = null;
+        Integer contextSize = null;
+        Integer interactionSize = null;
+        Integer timeout = null;
+        String llmResponseField = null;
+        List<MessageBlock> llmMessages = null;
+
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+            String field = parser.currentName();
+            parser.nextToken();
+
+            switch (field) {
+                case CONVERSATION_ID:
+                    conversationId = parser.text();
+                    break;
+                case LLM_MODEL:
+                    llmModel = parser.text();
+                    break;
+                case LLM_QUESTION:
+                    llmQuestion = parser.text();
+                    break;
+                case SYSTEM_PROMPT:
+                    systemPrompt = parser.text();
+                    break;
+                case USER_INSTRUCTIONS:
+                    userInstructions = parser.text();
+                    break;
+                case CONTEXT_SIZE:
+                    contextSize = parser.intValue();
+                    break;
+                case INTERACTION_SIZE:
+                    interactionSize = parser.intValue();
+                    break;
+                case TIMEOUT:
+                    timeout = parser.intValue();
+                    break;
+                case LLM_RESPONSE_FIELD:
+                    llmResponseField = parser.text();
+                    break;
+                case LLM_MESSAGES_FIELD:
+                    llmMessages = new ArrayList<>();
+                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
+                        llmMessages.add(MessageBlock.fromXContent(parser));
+                    }
+                    break;
+                default:
+                    parser.skipChildren();
+                    break;
+            }
+        }
+
+        return GenerativeQAParameters
+            .builder()
+            .conversationId(conversationId)
+            .llmModel(llmModel)
+            .llmQuestion(llmQuestion)
+            .systemPrompt(systemPrompt)
+            .userInstructions(userInstructions)
+            .contextSize(contextSize)
+            .interactionSize(interactionSize)
+            .timeout(timeout)
+            .llmResponseField(llmResponseField)
+            .llmMessages(llmMessages)
+            .build();
     }
 
     @Override
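With the ObjectParser and ParseField constants removed, request-level parsing is now a plain token loop: recognized fields populate the builder, unrecognized fields are skipped via skipChildren, and fields that are absent are simply passed through as null to the builder. A small usage sketch follows, modeled on the updated testParse case; the JSON payload and the surrounding class are illustrative assumptions, not code from the repository.

// Sketch only: parses an illustrative ext payload with the new hand-rolled parse().
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;

public class GenerativeQAParametersParseSketch {
    public static void main(String[] args) throws Exception {
        // "unknown_field" is not one of the recognized parameters and is skipped.
        String json = "{\"llm_question\":\"why is the sky blue\",\"context_size\":3,\"unknown_field\":\"ignored\"}";

        XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, json);
        parser.nextToken(); // position the parser on START_OBJECT, as parse() expects

        GenerativeQAParameters params = GenerativeQAParameters.parse(parser);
        System.out.println(params.getLlmQuestion());  // why is the sky blue
        System.out.println(params.getContextSize());  // 3
    }
}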

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java (+35 -7)
@@ -21,20 +21,25 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
 
 import java.io.EOFException;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
+import org.junit.Assert;
 import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.common.settings.Settings;
 import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.core.common.bytes.BytesReference;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
-import org.opensearch.core.xcontent.XContentHelper;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.search.SearchModule;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
 import org.opensearch.test.OpenSearchTestCase;
 
@@ -121,21 +126,38 @@ public void testMiscMethods() throws IOException {
     }
 
     public void testParse() throws IOException {
-        XContentParser xcParser = mock(XContentParser.class);
-        when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT);
-        GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser);
+        String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}";
+
+        XContentParser parser = XContentType.JSON
+            .xContent()
+            .createParser(
+                new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
+                null,
+                requiredJsonStr
+            );
+
+        parser.nextToken();
+        GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser);
         assertNotNull(builder);
         assertNotNull(builder.getParams());
+        GenerativeQAParameters params = builder.getParams();
+        Assert.assertEquals("this is test llm question", params.getLlmQuestion());
     }
 
     public void testXContentRoundTrip() throws IOException {
         GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null, messageList);
         GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
         extBuilder.setParams(param1);
+
         XContentType xContentType = randomFrom(XContentType.values());
-        BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
+        XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
+        builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
+        BytesReference serialized = BytesReference.bytes(builder);
+
         XContentParser parser = createParser(xContentType.xContent(), serialized);
+        parser.nextToken();
         GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);
+
         assertEquals(extBuilder, deserialized);
         GenerativeQAParameters parameters = deserialized.getParams();
         assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize());
@@ -147,10 +169,16 @@ public void testXContentRoundTripAllValues() throws IOException {
         GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null);
         GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
         extBuilder.setParams(param1);
+
         XContentType xContentType = randomFrom(XContentType.values());
-        BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
+        XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
+        builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
+        BytesReference serialized = BytesReference.bytes(builder);
+
         XContentParser parser = createParser(xContentType.xContent(), serialized);
+        parser.nextToken();
         GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);
+
         assertEquals(extBuilder, deserialized);
     }
 
search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java (+12 -1)
@@ -239,7 +239,18 @@ public void testToXConent() throws IOException {
         assertNotNull(parameters.toXContent(builder, null));
     }
 
-    public void testToXConentAllOptionalParameters() throws IOException {
+    public void testToXContentEmptyParams() throws IOException {
+        GenerativeQAParameters parameters = new GenerativeQAParameters();
+        XContent xc = mock(XContent.class);
+        OutputStream os = mock(OutputStream.class);
+        XContentGenerator generator = mock(XContentGenerator.class);
+        when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
+        XContentBuilder builder = new XContentBuilder(xc, os);
+        parameters.toXContent(builder, null);
+        assertNotNull(parameters.toXContent(builder, null));
+    }
+
+    public void testToXContentAllOptionalParameters() throws IOException {
         String conversationId = "a";
         String llmModel = "b";
         String llmQuestion = "c";
