27
27
import java .util .Map ;
28
28
import java .util .function .BooleanSupplier ;
29
29
30
- import org .opensearch .OpenSearchException ;
31
30
import org .opensearch .action .search .SearchRequest ;
32
31
import org .opensearch .action .search .SearchResponse ;
33
32
import org .opensearch .client .Client ;
33
+ import org .opensearch .core .action .ActionListener ;
34
+ import org .opensearch .core .common .Strings ;
34
35
import org .opensearch .ingest .ConfigurationUtils ;
35
36
import org .opensearch .ml .common .conversation .Interaction ;
36
37
import org .opensearch .ml .common .exception .MLException ;
37
38
import org .opensearch .search .SearchHit ;
38
39
import org .opensearch .search .pipeline .AbstractProcessor ;
40
+ import org .opensearch .search .pipeline .PipelineProcessingContext ;
39
41
import org .opensearch .search .pipeline .Processor ;
40
42
import org .opensearch .search .pipeline .SearchResponseProcessor ;
41
43
import org .opensearch .searchpipelines .questionanswering .generative .client .ConversationalMemoryClient ;
42
44
import org .opensearch .searchpipelines .questionanswering .generative .ext .GenerativeQAParamUtil ;
43
45
import org .opensearch .searchpipelines .questionanswering .generative .ext .GenerativeQAParameters ;
46
+ import org .opensearch .searchpipelines .questionanswering .generative .llm .ChatCompletionInput ;
44
47
import org .opensearch .searchpipelines .questionanswering .generative .llm .ChatCompletionOutput ;
45
48
import org .opensearch .searchpipelines .questionanswering .generative .llm .Llm ;
46
49
import org .opensearch .searchpipelines .questionanswering .generative .llm .LlmIOUtil ;
@@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements
65
68
66
69
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30 ;
67
70
68
- // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
69
-
70
71
private final String llmModel ;
71
72
private final List <String > contextFields ;
72
73
@@ -106,8 +107,18 @@ protected GenerativeQAResponseProcessor(
106
107
}
107
108
108
109
@ Override
109
- public SearchResponse processResponse (SearchRequest request , SearchResponse response ) throws Exception {
110
+ public SearchResponse processResponse (SearchRequest searchRequest , SearchResponse searchResponse ) {
111
+ // Synchronous call is no longer supported because this execution can occur on a transport thread.
112
+ throw new UnsupportedOperationException ();
113
+ }
110
114
115
+ @ Override
116
+ public void processResponseAsync (
117
+ SearchRequest request ,
118
+ SearchResponse response ,
119
+ PipelineProcessingContext requestContext ,
120
+ ActionListener <SearchResponse > responseListener
121
+ ) {
111
122
log .info ("Entering processResponse." );
112
123
113
124
if (!this .featureFlagSupplier .getAsBoolean ()) {
@@ -116,10 +127,12 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
116
127
117
128
GenerativeQAParameters params = GenerativeQAParamUtil .getGenerativeQAParameters (request );
118
129
119
- Integer timeout = params .getTimeout ();
120
- if (timeout == null || timeout == GenerativeQAParameters .SIZE_NULL_VALUE ) {
121
- timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS ;
130
+ Integer t = params .getTimeout ();
131
+ if (t == null || t == GenerativeQAParameters .SIZE_NULL_VALUE ) {
132
+ t = DEFAULT_PROCESSOR_TIME_IN_SECONDS ;
122
133
}
134
+ final int timeout = t ;
135
+ log .info ("Timeout for this request: {} seconds." , timeout );
123
136
124
137
String llmQuestion = params .getLlmQuestion ();
125
138
String llmModel = params .getLlmModel () == null ? this .llmModel : params .getLlmModel ();
@@ -128,14 +141,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
128
141
}
129
142
String conversationId = params .getConversationId ();
130
143
144
+ if (conversationId != null && !Strings .hasText (conversationId )) {
145
+ throw new IllegalArgumentException ("Empty conversation_id is not allowed." );
146
+ }
147
+ // log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
131
148
Instant start = Instant .now ();
132
149
Integer interactionSize = params .getInteractionSize ();
133
150
if (interactionSize == null || interactionSize == GenerativeQAParameters .SIZE_NULL_VALUE ) {
134
151
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW ;
135
152
}
136
- List <Interaction > chatHistory = (conversationId == null )
137
- ? Collections .emptyList ()
138
- : memoryClient .getInteractions (conversationId , interactionSize );
153
+ log .info ("Using interaction size of {}" , interactionSize );
139
154
140
155
Integer topN = params .getContextSize ();
141
156
if (topN == null ) {
@@ -153,10 +168,35 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
153
168
effectiveUserInstructions = params .getUserInstructions ();
154
169
}
155
170
156
- start = Instant .now ();
157
- try {
158
- ChatCompletionOutput output = llm
159
- .doChatCompletion (
171
+ // log.info("system_prompt: {}", systemPrompt);
172
+ // log.info("user_instructions: {}", userInstructions);
173
+
174
+ final List <Interaction > chatHistory = new ArrayList <>();
175
+ if (conversationId == null ) {
176
+ doChatCompletion (
177
+ LlmIOUtil
178
+ .createChatCompletionInput (
179
+ systemPrompt ,
180
+ userInstructions ,
181
+ llmModel ,
182
+ llmQuestion ,
183
+ chatHistory ,
184
+ searchResults ,
185
+ timeout ,
186
+ params .getLlmResponseField ()
187
+ ),
188
+ null ,
189
+ llmQuestion ,
190
+ searchResults ,
191
+ response ,
192
+ responseListener
193
+ );
194
+ } else {
195
+ final Instant memoryStart = Instant .now ();
196
+ memoryClient .getInteractions (conversationId , interactionSize , ActionListener .wrap (r -> {
197
+ log .info ("getInteractions complete. ({})" , getDuration (memoryStart ));
198
+ chatHistory .addAll (r );
199
+ doChatCompletion (
160
200
LlmIOUtil
161
201
.createChatCompletionInput (
162
202
systemPrompt ,
@@ -167,53 +207,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
167
207
searchResults ,
168
208
timeout ,
169
209
params .getLlmResponseField ()
170
- )
210
+ ),
211
+ conversationId ,
212
+ llmQuestion ,
213
+ searchResults ,
214
+ response ,
215
+ responseListener
171
216
);
172
- log .info ("doChatCompletion complete. ({})" , getDuration (start ));
217
+ }, responseListener ::onFailure ));
218
+ }
219
+ }
173
220
174
- String answer = null ;
175
- String errorMessage = null ;
176
- String interactionId = null ;
177
- if (output .isErrorOccurred ()) {
178
- errorMessage = output .getErrors ().get (0 );
179
- } else {
180
- answer = (String ) output .getAnswers ().get (0 );
221
+ private void doChatCompletion (
222
+ ChatCompletionInput input ,
223
+ String conversationId ,
224
+ String llmQuestion ,
225
+ List <String > searchResults ,
226
+ SearchResponse response ,
227
+ ActionListener <SearchResponse > responseListener
228
+ ) {
229
+
230
+ final Instant chatStart = Instant .now ();
231
+ llm .doChatCompletion (input , new ActionListener <>() {
232
+ @ Override
233
+ public void onResponse (ChatCompletionOutput output ) {
234
+ log .info ("doChatCompletion complete. ({})" , getDuration (chatStart ));
235
+
236
+ final String answer = getAnswer (output );
237
+ final String errorMessage = getError (output );
181
238
182
239
if (conversationId != null ) {
183
- start = Instant .now ();
184
- interactionId = memoryClient
240
+ final Instant memoryStart = Instant .now ();
241
+ memoryClient
185
242
.createInteraction (
186
243
conversationId ,
187
244
llmQuestion ,
188
245
PromptUtil .getPromptTemplate (systemPrompt , userInstructions ),
189
246
answer ,
190
247
GenerativeQAProcessorConstants .RESPONSE_PROCESSOR_TYPE ,
191
- Collections .singletonMap ("metadata" , jsonArrayToString (searchResults ))
248
+ Collections .singletonMap ("metadata" , jsonArrayToString (searchResults )),
249
+ ActionListener .wrap (r -> {
250
+ responseListener .onResponse (insertAnswer (response , answer , errorMessage , r ));
251
+ log .info ("Created a new interaction: {} ({})" , r , getDuration (memoryStart ));
252
+ }, responseListener ::onFailure )
192
253
);
193
- log .info ("Created a new interaction: {} ({})" , interactionId , getDuration (start ));
254
+
255
+ } else {
256
+ responseListener .onResponse (insertAnswer (response , answer , errorMessage , null ));
194
257
}
258
+
195
259
}
196
260
197
- return insertAnswer (response , answer , errorMessage , interactionId );
198
- } catch (NullPointerException nullPointerException ) {
199
- throw new IllegalArgumentException (IllegalArgumentMessage );
200
- } catch (Exception e ) {
201
- throw new OpenSearchException ("GenerativeQAResponseProcessor failed in precessing response" );
202
- }
203
- }
261
+ @ Override
262
+ public void onFailure (Exception e ) {
263
+ responseListener .onFailure (e );
264
+ }
204
265
205
- long getDuration (Instant start ) {
206
- return Duration .between (start , Instant .now ()).toMillis ();
266
+ private String getError (ChatCompletionOutput output ) {
267
+ return output .isErrorOccurred () ? output .getErrors ().get (0 ) : null ;
268
+ }
269
+
270
+ private String getAnswer (ChatCompletionOutput output ) {
271
+ return output .isErrorOccurred () ? null : (String ) output .getAnswers ().get (0 );
272
+ }
273
+ });
207
274
}
208
275
209
276
@ Override
210
277
public String getType () {
211
278
return GenerativeQAProcessorConstants .RESPONSE_PROCESSOR_TYPE ;
212
279
}
213
280
214
- private SearchResponse insertAnswer (SearchResponse response , String answer , String errorMessage , String interactionId ) {
281
+ private long getDuration (Instant start ) {
282
+ return Duration .between (start , Instant .now ()).toMillis ();
283
+ }
215
284
216
- // TODO return the interaction id in the response.
285
+ private SearchResponse insertAnswer ( SearchResponse response , String answer , String errorMessage , String interactionId ) {
217
286
218
287
return new GenerativeSearchResponse (
219
288
answer ,
@@ -240,9 +309,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) {
240
309
for (String contextField : contextFields ) {
241
310
Object context = docSourceMap .get (contextField );
242
311
if (context == null ) {
243
- log .error ("Context " + contextField + " not found in search hit " + hits [i ]);
244
- // TODO throw a more meaningful error here?
245
- throw new RuntimeException ();
312
+ throw new RuntimeException ("Context " + contextField + " not found in search hit " + hits [i ]);
246
313
}
247
314
searchResults .add (context .toString ());
248
315
}
0 commit comments