Skip to content

Commit 8b5b38e

Browse files
authored
[Backport 2.17] Bug Fix: Fix for rag processor throwing NPE when optional parameters are not provided (opensearch-project#3066)
1 parent 217afd0 commit 8b5b38e

File tree

3 files changed

+161
-47
lines changed

3 files changed

+161
-47
lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+114-39
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@
1717
*/
1818
package org.opensearch.searchpipelines.questionanswering.generative.ext;
1919

20+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
21+
2022
import java.io.IOException;
2123
import java.util.Objects;
2224

23-
import org.opensearch.core.ParseField;
2425
import org.opensearch.core.common.Strings;
2526
import org.opensearch.core.common.io.stream.StreamInput;
2627
import org.opensearch.core.common.io.stream.StreamOutput;
2728
import org.opensearch.core.common.io.stream.Writeable;
28-
import org.opensearch.core.xcontent.ObjectParser;
2929
import org.opensearch.core.xcontent.ToXContentObject;
3030
import org.opensearch.core.xcontent.XContentBuilder;
3131
import org.opensearch.core.xcontent.XContentParser;
32-
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
3332

3433
import com.google.common.base.Preconditions;
3534

35+
import lombok.Builder;
3636
import lombok.Getter;
3737
import lombok.NoArgsConstructor;
3838
import lombok.Setter;
@@ -45,57 +45,42 @@
4545
@NoArgsConstructor
4646
public class GenerativeQAParameters implements Writeable, ToXContentObject {
4747

48-
private static final ObjectParser<GenerativeQAParameters, Void> PARSER;
49-
5048
// Optional parameter; if provided, conversational memory will be used for RAG
5149
// and the current interaction will be saved in the conversation referenced by this id.
52-
private static final ParseField CONVERSATION_ID = new ParseField("memory_id");
50+
private static final String CONVERSATION_ID = "memory_id";
5351

5452
// Optional parameter; if an LLM model is not set at the search pipeline level, one must be
5553
// provided at the search request level.
56-
private static final ParseField LLM_MODEL = new ParseField("llm_model");
54+
private static final String LLM_MODEL = "llm_model";
5755

5856
// Required parameter; this is sent to LLMs as part of the user prompt.
5957
// TODO support question rewriting when chat history is not used (conversation_id is not provided).
60-
private static final ParseField LLM_QUESTION = new ParseField("llm_question");
58+
private static final String LLM_QUESTION = "llm_question";
6159

6260
// Optional parameter; this parameter controls the number of search results ("contexts") to
6361
// include in the user prompt.
64-
private static final ParseField CONTEXT_SIZE = new ParseField("context_size");
62+
private static final String CONTEXT_SIZE = "context_size";
6563

6664
// Optional parameter; this parameter controls the number of the interactions to include
6765
// in the user prompt.
68-
private static final ParseField INTERACTION_SIZE = new ParseField("message_size");
66+
private static final String INTERACTION_SIZE = "message_size";
6967

7068
// Optional parameter; this parameter controls how long the search pipeline waits for a response
7169
// from a remote inference endpoint before timing out the request.
72-
private static final ParseField TIMEOUT = new ParseField("timeout");
70+
private static final String TIMEOUT = "timeout";
7371

7472
// Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
75-
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
73+
private static final String SYSTEM_PROMPT = "system_prompt";
7674

7775
// Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
78-
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
76+
private static final String USER_INSTRUCTIONS = "user_instructions";
7977

8078
// Optional parameter; this parameter indicates the name of the field in the LLM response
8179
// that contains the chat completion text, i.e. "answer".
82-
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
80+
private static final String LLM_RESPONSE_FIELD = "llm_response_field";
8381

8482
public static final int SIZE_NULL_VALUE = -1;
8583

86-
static {
87-
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
88-
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
89-
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
90-
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
91-
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
92-
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
93-
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
94-
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
95-
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
96-
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
97-
}
98-
9984
@Setter
10085
@Getter
10186
private String conversationId;
@@ -132,6 +117,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
132117
@Getter
133118
private String llmResponseField;
134119

120+
@Builder
135121
public GenerativeQAParameters(
136122
String conversationId,
137123
String llmModel,
@@ -148,7 +134,7 @@ public GenerativeQAParameters(
148134

149135
// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
150136
// for question rewriting.
151-
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
137+
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
152138
this.llmQuestion = llmQuestion;
153139
this.systemPrompt = systemPrompt;
154140
this.userInstructions = userInstructions;
@@ -172,16 +158,45 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
172158

173159
@Override
174160
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
175-
return xContentBuilder
176-
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
177-
.field(LLM_MODEL.getPreferredName(), this.llmModel)
178-
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
179-
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
180-
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
181-
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
182-
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
183-
.field(TIMEOUT.getPreferredName(), this.timeout)
184-
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
161+
xContentBuilder.startObject();
162+
if (this.conversationId != null) {
163+
xContentBuilder.field(CONVERSATION_ID, this.conversationId);
164+
}
165+
166+
if (this.llmModel != null) {
167+
xContentBuilder.field(LLM_MODEL, this.llmModel);
168+
}
169+
170+
if (this.llmQuestion != null) {
171+
xContentBuilder.field(LLM_QUESTION, this.llmQuestion);
172+
}
173+
174+
if (this.systemPrompt != null) {
175+
xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt);
176+
}
177+
178+
if (this.userInstructions != null) {
179+
xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions);
180+
}
181+
182+
if (this.contextSize != null) {
183+
xContentBuilder.field(CONTEXT_SIZE, this.contextSize);
184+
}
185+
186+
if (this.interactionSize != null) {
187+
xContentBuilder.field(INTERACTION_SIZE, this.interactionSize);
188+
}
189+
190+
if (this.timeout != null) {
191+
xContentBuilder.field(TIMEOUT, this.timeout);
192+
}
193+
194+
if (this.llmResponseField != null) {
195+
xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField);
196+
}
197+
198+
xContentBuilder.endObject();
199+
return xContentBuilder;
185200
}
186201

187202
@Override
@@ -200,7 +215,67 @@ public void writeTo(StreamOutput out) throws IOException {
200215
}
201216

202217
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
203-
return PARSER.parse(parser, null);
218+
String conversationId = null;
219+
String llmModel = null;
220+
String llmQuestion = null;
221+
String systemPrompt = null;
222+
String userInstructions = null;
223+
Integer contextSize = null;
224+
Integer interactionSize = null;
225+
Integer timeout = null;
226+
String llmResponseField = null;
227+
228+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
229+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
230+
String field = parser.currentName();
231+
parser.nextToken();
232+
233+
switch (field) {
234+
case CONVERSATION_ID:
235+
conversationId = parser.text();
236+
break;
237+
case LLM_MODEL:
238+
llmModel = parser.text();
239+
break;
240+
case LLM_QUESTION:
241+
llmQuestion = parser.text();
242+
break;
243+
case SYSTEM_PROMPT:
244+
systemPrompt = parser.text();
245+
break;
246+
case USER_INSTRUCTIONS:
247+
userInstructions = parser.text();
248+
break;
249+
case CONTEXT_SIZE:
250+
contextSize = parser.intValue();
251+
break;
252+
case INTERACTION_SIZE:
253+
interactionSize = parser.intValue();
254+
break;
255+
case TIMEOUT:
256+
timeout = parser.intValue();
257+
break;
258+
case LLM_RESPONSE_FIELD:
259+
llmResponseField = parser.text();
260+
break;
261+
default:
262+
parser.skipChildren();
263+
break;
264+
}
265+
}
266+
267+
return GenerativeQAParameters
268+
.builder()
269+
.conversationId(conversationId)
270+
.llmModel(llmModel)
271+
.llmQuestion(llmQuestion)
272+
.systemPrompt(systemPrompt)
273+
.userInstructions(userInstructions)
274+
.contextSize(contextSize)
275+
.interactionSize(interactionSize)
276+
.timeout(timeout)
277+
.llmResponseField(llmResponseField)
278+
.build();
204279
}
205280

206281
@Override

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java

+35-7
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,23 @@
2121
import static org.mockito.Mockito.mock;
2222
import static org.mockito.Mockito.times;
2323
import static org.mockito.Mockito.verify;
24-
import static org.mockito.Mockito.when;
24+
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
2525

2626
import java.io.EOFException;
2727
import java.io.IOException;
28+
import java.util.Collections;
2829

30+
import org.junit.Assert;
2931
import org.opensearch.common.io.stream.BytesStreamOutput;
32+
import org.opensearch.common.settings.Settings;
3033
import org.opensearch.common.xcontent.XContentType;
3134
import org.opensearch.core.common.bytes.BytesReference;
3235
import org.opensearch.core.common.io.stream.StreamInput;
3336
import org.opensearch.core.common.io.stream.StreamOutput;
34-
import org.opensearch.core.xcontent.XContentHelper;
37+
import org.opensearch.core.xcontent.NamedXContentRegistry;
38+
import org.opensearch.core.xcontent.XContentBuilder;
3539
import org.opensearch.core.xcontent.XContentParser;
40+
import org.opensearch.search.SearchModule;
3641
import org.opensearch.test.OpenSearchTestCase;
3742

3843
public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase {
@@ -107,21 +112,38 @@ public void testMiscMethods() throws IOException {
107112
}
108113

109114
public void testParse() throws IOException {
110-
XContentParser xcParser = mock(XContentParser.class);
111-
when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT);
112-
GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser);
115+
String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}";
116+
117+
XContentParser parser = XContentType.JSON
118+
.xContent()
119+
.createParser(
120+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
121+
null,
122+
requiredJsonStr
123+
);
124+
125+
parser.nextToken();
126+
GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser);
113127
assertNotNull(builder);
114128
assertNotNull(builder.getParams());
129+
GenerativeQAParameters params = builder.getParams();
130+
Assert.assertEquals("this is test llm question", params.getLlmQuestion());
115131
}
116132

117133
public void testXContentRoundTrip() throws IOException {
118134
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null);
119135
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
120136
extBuilder.setParams(param1);
137+
121138
XContentType xContentType = randomFrom(XContentType.values());
122-
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
139+
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
140+
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
141+
BytesReference serialized = BytesReference.bytes(builder);
142+
123143
XContentParser parser = createParser(xContentType.xContent(), serialized);
144+
parser.nextToken();
124145
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);
146+
125147
assertEquals(extBuilder, deserialized);
126148
GenerativeQAParameters parameters = deserialized.getParams();
127149
assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize());
@@ -133,10 +155,16 @@ public void testXContentRoundTripAllValues() throws IOException {
133155
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null);
134156
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
135157
extBuilder.setParams(param1);
158+
136159
XContentType xContentType = randomFrom(XContentType.values());
137-
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
160+
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
161+
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
162+
BytesReference serialized = BytesReference.bytes(builder);
163+
138164
XContentParser parser = createParser(xContentType.xContent(), serialized);
165+
parser.nextToken();
139166
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);
167+
140168
assertEquals(extBuilder, deserialized);
141169
}
142170

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,18 @@ public void testToXConent() throws IOException {
200200
assertNotNull(parameters.toXContent(builder, null));
201201
}
202202

203-
public void testToXConentAllOptionalParameters() throws IOException {
203+
public void testToXContentEmptyParams() throws IOException {
204+
GenerativeQAParameters parameters = new GenerativeQAParameters();
205+
XContent xc = mock(XContent.class);
206+
OutputStream os = mock(OutputStream.class);
207+
XContentGenerator generator = mock(XContentGenerator.class);
208+
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
209+
XContentBuilder builder = new XContentBuilder(xc, os);
210+
parameters.toXContent(builder, null);
211+
assertNotNull(parameters.toXContent(builder, null));
212+
}
213+
214+
public void testToXContentAllOptionalParameters() throws IOException {
204215
String conversationId = "a";
205216
String llmModel = "b";
206217
String llmQuestion = "c";

0 commit comments

Comments (0)