27
27
import java .util .Map ;
28
28
import java .util .function .BooleanSupplier ;
29
29
30
- import org .opensearch .OpenSearchException ;
31
30
import org .opensearch .action .search .SearchRequest ;
32
31
import org .opensearch .action .search .SearchResponse ;
33
32
import org .opensearch .client .Client ;
33
+ import org .opensearch .core .action .ActionListener ;
34
+ import org .opensearch .core .common .Strings ;
34
35
import org .opensearch .ingest .ConfigurationUtils ;
35
36
import org .opensearch .ml .common .conversation .Interaction ;
36
37
import org .opensearch .ml .common .exception .MLException ;
37
38
import org .opensearch .search .SearchHit ;
38
39
import org .opensearch .search .pipeline .AbstractProcessor ;
40
+ import org .opensearch .search .pipeline .PipelineProcessingContext ;
39
41
import org .opensearch .search .pipeline .Processor ;
40
42
import org .opensearch .search .pipeline .SearchResponseProcessor ;
41
43
import org .opensearch .searchpipelines .questionanswering .generative .client .ConversationalMemoryClient ;
42
44
import org .opensearch .searchpipelines .questionanswering .generative .ext .GenerativeQAParamUtil ;
43
45
import org .opensearch .searchpipelines .questionanswering .generative .ext .GenerativeQAParameters ;
46
+ import org .opensearch .searchpipelines .questionanswering .generative .llm .ChatCompletionInput ;
44
47
import org .opensearch .searchpipelines .questionanswering .generative .llm .ChatCompletionOutput ;
45
48
import org .opensearch .searchpipelines .questionanswering .generative .llm .Llm ;
46
49
import org .opensearch .searchpipelines .questionanswering .generative .llm .LlmIOUtil ;
@@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements
65
68
66
69
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30 ;
67
70
68
- // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
69
-
70
71
private final String llmModel ;
71
72
private final List <String > contextFields ;
72
73
@@ -106,20 +107,32 @@ protected GenerativeQAResponseProcessor(
106
107
}
107
108
108
109
@ Override
109
- public SearchResponse processResponse (SearchRequest request , SearchResponse response ) throws Exception {
110
+ public SearchResponse processResponse (SearchRequest searchRequest , SearchResponse searchResponse ) {
111
+ // Synchronous call is no longer supported because this execution can occur on a transport thread.
112
+ throw new UnsupportedOperationException ();
113
+ }
110
114
111
- log .info ("Entering processResponse." );
115
+ @ Override
116
+ public void processResponseAsync (
117
+ SearchRequest request ,
118
+ SearchResponse response ,
119
+ PipelineProcessingContext requestContext ,
120
+ ActionListener <SearchResponse > responseListener
121
+ ) {
122
+ log .debug ("Entering processResponse." );
112
123
113
124
if (!this .featureFlagSupplier .getAsBoolean ()) {
114
125
throw new MLException (GenerativeQAProcessorConstants .FEATURE_NOT_ENABLED_ERROR_MSG );
115
126
}
116
127
117
128
GenerativeQAParameters params = GenerativeQAParamUtil .getGenerativeQAParameters (request );
118
129
119
- Integer timeout = params .getTimeout ();
120
- if (timeout == null || timeout == GenerativeQAParameters .SIZE_NULL_VALUE ) {
121
- timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS ;
130
+ Integer t = params .getTimeout ();
131
+ if (t == null || t == GenerativeQAParameters .SIZE_NULL_VALUE ) {
132
+ t = DEFAULT_PROCESSOR_TIME_IN_SECONDS ;
122
133
}
134
+ final int timeout = t ;
135
+ log .debug ("Timeout for this request: {} seconds." , timeout );
123
136
124
137
String llmQuestion = params .getLlmQuestion ();
125
138
String llmModel = params .getLlmModel () == null ? this .llmModel : params .getLlmModel ();
@@ -128,14 +141,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
128
141
}
129
142
String conversationId = params .getConversationId ();
130
143
144
+ if (conversationId != null && !Strings .hasText (conversationId )) {
145
+ throw new IllegalArgumentException ("Empty conversation_id is not allowed." );
146
+ }
131
147
Instant start = Instant .now ();
132
148
Integer interactionSize = params .getInteractionSize ();
133
149
if (interactionSize == null || interactionSize == GenerativeQAParameters .SIZE_NULL_VALUE ) {
134
150
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW ;
135
151
}
136
- List <Interaction > chatHistory = (conversationId == null )
137
- ? Collections .emptyList ()
138
- : memoryClient .getInteractions (conversationId , interactionSize );
152
+ log .debug ("Using interaction size of {}" , interactionSize );
139
153
140
154
Integer topN = params .getContextSize ();
141
155
if (topN == null ) {
@@ -153,10 +167,32 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
153
167
effectiveUserInstructions = params .getUserInstructions ();
154
168
}
155
169
156
- start = Instant .now ();
157
- try {
158
- ChatCompletionOutput output = llm
159
- .doChatCompletion (
170
+ final List <Interaction > chatHistory = new ArrayList <>();
171
+ if (conversationId == null ) {
172
+ doChatCompletion (
173
+ LlmIOUtil
174
+ .createChatCompletionInput (
175
+ systemPrompt ,
176
+ userInstructions ,
177
+ llmModel ,
178
+ llmQuestion ,
179
+ chatHistory ,
180
+ searchResults ,
181
+ timeout ,
182
+ params .getLlmResponseField ()
183
+ ),
184
+ null ,
185
+ llmQuestion ,
186
+ searchResults ,
187
+ response ,
188
+ responseListener
189
+ );
190
+ } else {
191
+ final Instant memoryStart = Instant .now ();
192
+ memoryClient .getInteractions (conversationId , interactionSize , ActionListener .wrap (r -> {
193
+ log .debug ("getInteractions complete. ({})" , getDuration (memoryStart ));
194
+ chatHistory .addAll (r );
195
+ doChatCompletion (
160
196
LlmIOUtil
161
197
.createChatCompletionInput (
162
198
systemPrompt ,
@@ -167,53 +203,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
167
203
searchResults ,
168
204
timeout ,
169
205
params .getLlmResponseField ()
170
- )
206
+ ),
207
+ conversationId ,
208
+ llmQuestion ,
209
+ searchResults ,
210
+ response ,
211
+ responseListener
171
212
);
172
- log .info ("doChatCompletion complete. ({})" , getDuration (start ));
213
+ }, responseListener ::onFailure ));
214
+ }
215
+ }
173
216
174
- String answer = null ;
175
- String errorMessage = null ;
176
- String interactionId = null ;
177
- if (output .isErrorOccurred ()) {
178
- errorMessage = output .getErrors ().get (0 );
179
- } else {
180
- answer = (String ) output .getAnswers ().get (0 );
217
+ private void doChatCompletion (
218
+ ChatCompletionInput input ,
219
+ String conversationId ,
220
+ String llmQuestion ,
221
+ List <String > searchResults ,
222
+ SearchResponse response ,
223
+ ActionListener <SearchResponse > responseListener
224
+ ) {
225
+
226
+ final Instant chatStart = Instant .now ();
227
+ llm .doChatCompletion (input , new ActionListener <>() {
228
+ @ Override
229
+ public void onResponse (ChatCompletionOutput output ) {
230
+ log .debug ("doChatCompletion complete. ({})" , getDuration (chatStart ));
231
+
232
+ final String answer = getAnswer (output );
233
+ final String errorMessage = getError (output );
181
234
182
235
if (conversationId != null ) {
183
- start = Instant .now ();
184
- interactionId = memoryClient
236
+ final Instant memoryStart = Instant .now ();
237
+ memoryClient
185
238
.createInteraction (
186
239
conversationId ,
187
240
llmQuestion ,
188
241
PromptUtil .getPromptTemplate (systemPrompt , userInstructions ),
189
242
answer ,
190
243
GenerativeQAProcessorConstants .RESPONSE_PROCESSOR_TYPE ,
191
- Collections .singletonMap ("metadata" , jsonArrayToString (searchResults ))
244
+ Collections .singletonMap ("metadata" , jsonArrayToString (searchResults )),
245
+ ActionListener .wrap (r -> {
246
+ responseListener .onResponse (insertAnswer (response , answer , errorMessage , r ));
247
+ log .info ("Created a new interaction: {} ({})" , r , getDuration (memoryStart ));
248
+ }, responseListener ::onFailure )
192
249
);
193
- log .info ("Created a new interaction: {} ({})" , interactionId , getDuration (start ));
250
+
251
+ } else {
252
+ responseListener .onResponse (insertAnswer (response , answer , errorMessage , null ));
194
253
}
254
+
195
255
}
196
256
197
- return insertAnswer (response , answer , errorMessage , interactionId );
198
- } catch (NullPointerException nullPointerException ) {
199
- throw new IllegalArgumentException (IllegalArgumentMessage );
200
- } catch (Exception e ) {
201
- throw new OpenSearchException ("GenerativeQAResponseProcessor failed in precessing response" );
202
- }
203
- }
257
+ @ Override
258
+ public void onFailure (Exception e ) {
259
+ responseListener .onFailure (e );
260
+ }
204
261
205
- long getDuration (Instant start ) {
206
- return Duration .between (start , Instant .now ()).toMillis ();
262
+ private String getError (ChatCompletionOutput output ) {
263
+ return output .isErrorOccurred () ? output .getErrors ().get (0 ) : null ;
264
+ }
265
+
266
+ private String getAnswer (ChatCompletionOutput output ) {
267
+ return output .isErrorOccurred () ? null : (String ) output .getAnswers ().get (0 );
268
+ }
269
+ });
207
270
}
208
271
209
272
@ Override
210
273
public String getType () {
211
274
return GenerativeQAProcessorConstants .RESPONSE_PROCESSOR_TYPE ;
212
275
}
213
276
214
- private SearchResponse insertAnswer (SearchResponse response , String answer , String errorMessage , String interactionId ) {
277
+ private long getDuration (Instant start ) {
278
+ return Duration .between (start , Instant .now ()).toMillis ();
279
+ }
215
280
216
- // TODO return the interaction id in the response.
281
+ private SearchResponse insertAnswer ( SearchResponse response , String answer , String errorMessage , String interactionId ) {
217
282
218
283
return new GenerativeSearchResponse (
219
284
answer ,
@@ -240,9 +305,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) {
240
305
for (String contextField : contextFields ) {
241
306
Object context = docSourceMap .get (contextField );
242
307
if (context == null ) {
243
- log .error ("Context " + contextField + " not found in search hit " + hits [i ]);
244
- // TODO throw a more meaningful error here?
245
- throw new RuntimeException ();
308
+ throw new RuntimeException ("Context " + contextField + " not found in search hit " + hits [i ]);
246
309
}
247
310
searchResults .add (context .toString ());
248
311
}
0 commit comments