17
17
*/
18
18
package org .opensearch .searchpipelines .questionanswering .generative .ext ;
19
19
20
+ import static org .opensearch .core .xcontent .XContentParserUtils .ensureExpectedToken ;
21
+
20
22
import java .io .IOException ;
21
23
import java .util .ArrayList ;
22
24
import java .util .List ;
23
25
import java .util .Objects ;
24
26
25
- import org .opensearch .core .ParseField ;
26
27
import org .opensearch .core .common .Strings ;
27
28
import org .opensearch .core .common .io .stream .StreamInput ;
28
29
import org .opensearch .core .common .io .stream .StreamOutput ;
29
30
import org .opensearch .core .common .io .stream .Writeable ;
30
- import org .opensearch .core .xcontent .ObjectParser ;
31
31
import org .opensearch .core .xcontent .ToXContentObject ;
32
32
import org .opensearch .core .xcontent .XContentBuilder ;
33
33
import org .opensearch .core .xcontent .XContentParser ;
34
- import org .opensearch .searchpipelines .questionanswering .generative .GenerativeQAProcessorConstants ;
35
34
import org .opensearch .searchpipelines .questionanswering .generative .llm .MessageBlock ;
36
35
37
36
import com .google .common .base .Preconditions ;
38
37
38
+ import lombok .Builder ;
39
39
import lombok .Getter ;
40
40
import lombok .NoArgsConstructor ;
41
41
import lombok .Setter ;
48
48
@ NoArgsConstructor
49
49
public class GenerativeQAParameters implements Writeable , ToXContentObject {
50
50
51
- private static final ObjectParser <GenerativeQAParameters , Void > PARSER ;
52
-
53
51
// Optional parameter; if provided, conversational memory will be used for RAG
54
52
// and the current interaction will be saved in the conversation referenced by this id.
55
- private static final ParseField CONVERSATION_ID = new ParseField ( "memory_id" ) ;
53
+ private static final String CONVERSATION_ID = "memory_id" ;
56
54
57
55
// Optional parameter; if an LLM model is not set at the search pipeline level, one must be
58
56
// provided at the search request level.
59
- private static final ParseField LLM_MODEL = new ParseField ( "llm_model" ) ;
57
+ private static final String LLM_MODEL = "llm_model" ;
60
58
61
59
// Required parameter; this is sent to LLMs as part of the user prompt.
62
60
// TODO support question rewriting when chat history is not used (conversation_id is not provided).
63
- private static final ParseField LLM_QUESTION = new ParseField ( "llm_question" ) ;
61
+ private static final String LLM_QUESTION = "llm_question" ;
64
62
65
63
// Optional parameter; this parameter controls the number of search results ("contexts") to
66
64
// include in the user prompt.
67
- private static final ParseField CONTEXT_SIZE = new ParseField ( "context_size" ) ;
65
+ private static final String CONTEXT_SIZE = "context_size" ;
68
66
69
67
// Optional parameter; this parameter controls the number of the interactions to include
70
68
// in the user prompt.
71
- private static final ParseField INTERACTION_SIZE = new ParseField ( "message_size" ) ;
69
+ private static final String INTERACTION_SIZE = "message_size" ;
72
70
73
71
// Optional parameter; this parameter controls how long the search pipeline waits for a response
74
72
// from a remote inference endpoint before timing out the request.
75
- private static final ParseField TIMEOUT = new ParseField ( "timeout" ) ;
73
+ private static final String TIMEOUT = "timeout" ;
76
74
77
75
// Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
78
- private static final ParseField SYSTEM_PROMPT = new ParseField ( GenerativeQAProcessorConstants . CONFIG_NAME_SYSTEM_PROMPT ) ;
76
+ private static final String SYSTEM_PROMPT = "system_prompt" ;
79
77
80
78
// Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
81
- private static final ParseField USER_INSTRUCTIONS = new ParseField ( GenerativeQAProcessorConstants . CONFIG_NAME_USER_INSTRUCTIONS ) ;
79
+ private static final String USER_INSTRUCTIONS = "user_instructions" ;
82
80
83
81
// Optional parameter; this parameter indicates the name of the field in the LLM response
84
82
// that contains the chat completion text, i.e. "answer".
85
- private static final ParseField LLM_RESPONSE_FIELD = new ParseField ( "llm_response_field" ) ;
83
+ private static final String LLM_RESPONSE_FIELD = "llm_response_field" ;
86
84
87
- private static final ParseField LLM_MESSAGES_FIELD = new ParseField ( "llm_messages" ) ;
85
+ private static final String LLM_MESSAGES_FIELD = "llm_messages" ;
88
86
89
87
public static final int SIZE_NULL_VALUE = -1 ;
90
88
91
- static {
92
- PARSER = new ObjectParser <>("generative_qa_parameters" , GenerativeQAParameters ::new );
93
- PARSER .declareString (GenerativeQAParameters ::setConversationId , CONVERSATION_ID );
94
- PARSER .declareString (GenerativeQAParameters ::setLlmModel , LLM_MODEL );
95
- PARSER .declareString (GenerativeQAParameters ::setLlmQuestion , LLM_QUESTION );
96
- PARSER .declareStringOrNull (GenerativeQAParameters ::setSystemPrompt , SYSTEM_PROMPT );
97
- PARSER .declareStringOrNull (GenerativeQAParameters ::setUserInstructions , USER_INSTRUCTIONS );
98
- PARSER .declareIntOrNull (GenerativeQAParameters ::setContextSize , SIZE_NULL_VALUE , CONTEXT_SIZE );
99
- PARSER .declareIntOrNull (GenerativeQAParameters ::setInteractionSize , SIZE_NULL_VALUE , INTERACTION_SIZE );
100
- PARSER .declareIntOrNull (GenerativeQAParameters ::setTimeout , SIZE_NULL_VALUE , TIMEOUT );
101
- PARSER .declareStringOrNull (GenerativeQAParameters ::setLlmResponseField , LLM_RESPONSE_FIELD );
102
- PARSER .declareObjectArray (GenerativeQAParameters ::setMessageBlock , (p , c ) -> MessageBlock .fromXContent (p ), LLM_MESSAGES_FIELD );
103
- }
104
-
105
89
@ Setter
106
90
@ Getter
107
91
private String conversationId ;
@@ -167,6 +151,7 @@ public GenerativeQAParameters(
167
151
);
168
152
}
169
153
154
+ @ Builder (toBuilder = true )
170
155
public GenerativeQAParameters (
171
156
String conversationId ,
172
157
String llmModel ,
@@ -184,7 +169,7 @@ public GenerativeQAParameters(
184
169
185
170
// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
186
171
// for question rewriting.
187
- Preconditions .checkArgument (!Strings .isNullOrEmpty (llmQuestion ), LLM_QUESTION . getPreferredName () + " must be provided." );
172
+ Preconditions .checkArgument (!Strings .isNullOrEmpty (llmQuestion ), LLM_QUESTION + " must be provided." );
188
173
this .llmQuestion = llmQuestion ;
189
174
this .systemPrompt = systemPrompt ;
190
175
this .userInstructions = userInstructions ;
@@ -212,17 +197,49 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
212
197
213
198
@ Override
214
199
public XContentBuilder toXContent (XContentBuilder xContentBuilder , Params params ) throws IOException {
215
- return xContentBuilder
216
- .field (CONVERSATION_ID .getPreferredName (), this .conversationId )
217
- .field (LLM_MODEL .getPreferredName (), this .llmModel )
218
- .field (LLM_QUESTION .getPreferredName (), this .llmQuestion )
219
- .field (SYSTEM_PROMPT .getPreferredName (), this .systemPrompt )
220
- .field (USER_INSTRUCTIONS .getPreferredName (), this .userInstructions )
221
- .field (CONTEXT_SIZE .getPreferredName (), this .contextSize )
222
- .field (INTERACTION_SIZE .getPreferredName (), this .interactionSize )
223
- .field (TIMEOUT .getPreferredName (), this .timeout )
224
- .field (LLM_RESPONSE_FIELD .getPreferredName (), this .llmResponseField )
225
- .field (LLM_MESSAGES_FIELD .getPreferredName (), this .llmMessages );
200
+ xContentBuilder .startObject ();
201
+ if (this .conversationId != null ) {
202
+ xContentBuilder .field (CONVERSATION_ID , this .conversationId );
203
+ }
204
+
205
+ if (this .llmModel != null ) {
206
+ xContentBuilder .field (LLM_MODEL , this .llmModel );
207
+ }
208
+
209
+ if (this .llmQuestion != null ) {
210
+ xContentBuilder .field (LLM_QUESTION , this .llmQuestion );
211
+ }
212
+
213
+ if (this .systemPrompt != null ) {
214
+ xContentBuilder .field (SYSTEM_PROMPT , this .systemPrompt );
215
+ }
216
+
217
+ if (this .userInstructions != null ) {
218
+ xContentBuilder .field (USER_INSTRUCTIONS , this .userInstructions );
219
+ }
220
+
221
+ if (this .contextSize != null ) {
222
+ xContentBuilder .field (CONTEXT_SIZE , this .contextSize );
223
+ }
224
+
225
+ if (this .interactionSize != null ) {
226
+ xContentBuilder .field (INTERACTION_SIZE , this .interactionSize );
227
+ }
228
+
229
+ if (this .timeout != null ) {
230
+ xContentBuilder .field (TIMEOUT , this .timeout );
231
+ }
232
+
233
+ if (this .llmResponseField != null ) {
234
+ xContentBuilder .field (LLM_RESPONSE_FIELD , this .llmResponseField );
235
+ }
236
+
237
+ if (this .llmMessages != null && !this .llmMessages .isEmpty ()) {
238
+ xContentBuilder .field (LLM_MESSAGES_FIELD , this .llmMessages );
239
+ }
240
+
241
+ xContentBuilder .endObject ();
242
+ return xContentBuilder ;
226
243
}
227
244
228
245
@ Override
@@ -242,7 +259,76 @@ public void writeTo(StreamOutput out) throws IOException {
242
259
}
243
260
244
261
public static GenerativeQAParameters parse (XContentParser parser ) throws IOException {
245
- return PARSER .parse (parser , null );
262
+ String conversationId = null ;
263
+ String llmModel = null ;
264
+ String llmQuestion = null ;
265
+ String systemPrompt = null ;
266
+ String userInstructions = null ;
267
+ Integer contextSize = null ;
268
+ Integer interactionSize = null ;
269
+ Integer timeout = null ;
270
+ String llmResponseField = null ;
271
+ List <MessageBlock > llmMessages = null ;
272
+
273
+ ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
274
+ while (parser .nextToken () != XContentParser .Token .END_OBJECT ) {
275
+ String field = parser .currentName ();
276
+ parser .nextToken ();
277
+
278
+ switch (field ) {
279
+ case CONVERSATION_ID :
280
+ conversationId = parser .text ();
281
+ break ;
282
+ case LLM_MODEL :
283
+ llmModel = parser .text ();
284
+ break ;
285
+ case LLM_QUESTION :
286
+ llmQuestion = parser .text ();
287
+ break ;
288
+ case SYSTEM_PROMPT :
289
+ systemPrompt = parser .text ();
290
+ break ;
291
+ case USER_INSTRUCTIONS :
292
+ userInstructions = parser .text ();
293
+ break ;
294
+ case CONTEXT_SIZE :
295
+ contextSize = parser .intValue ();
296
+ break ;
297
+ case INTERACTION_SIZE :
298
+ interactionSize = parser .intValue ();
299
+ break ;
300
+ case TIMEOUT :
301
+ timeout = parser .intValue ();
302
+ break ;
303
+ case LLM_RESPONSE_FIELD :
304
+ llmResponseField = parser .text ();
305
+ break ;
306
+ case LLM_MESSAGES_FIELD :
307
+ llmMessages = new ArrayList <>();
308
+ ensureExpectedToken (XContentParser .Token .START_ARRAY , parser .currentToken (), parser );
309
+ while (parser .nextToken () != XContentParser .Token .END_ARRAY ) {
310
+ llmMessages .add (MessageBlock .fromXContent (parser ));
311
+ }
312
+ break ;
313
+ default :
314
+ parser .skipChildren ();
315
+ break ;
316
+ }
317
+ }
318
+
319
+ return GenerativeQAParameters
320
+ .builder ()
321
+ .conversationId (conversationId )
322
+ .llmModel (llmModel )
323
+ .llmQuestion (llmQuestion )
324
+ .systemPrompt (systemPrompt )
325
+ .userInstructions (userInstructions )
326
+ .contextSize (contextSize )
327
+ .interactionSize (interactionSize )
328
+ .timeout (timeout )
329
+ .llmResponseField (llmResponseField )
330
+ .llmMessages (llmMessages )
331
+ .build ();
246
332
}
247
333
248
334
@ Override
0 commit comments