Skip to content

Commit 59193e9

Browse files
Migrate RAG pipeline to async processing.
Signed-off-by: Austin Lee <austin@aryn.ai>
1 parent fc555c0 commit 59193e9

File tree

7 files changed: +546 −190 lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+108-41
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,23 @@
2727
import java.util.Map;
2828
import java.util.function.BooleanSupplier;
2929

30-
import org.opensearch.OpenSearchException;
3130
import org.opensearch.action.search.SearchRequest;
3231
import org.opensearch.action.search.SearchResponse;
3332
import org.opensearch.client.Client;
33+
import org.opensearch.core.action.ActionListener;
34+
import org.opensearch.core.common.Strings;
3435
import org.opensearch.ingest.ConfigurationUtils;
3536
import org.opensearch.ml.common.conversation.Interaction;
3637
import org.opensearch.ml.common.exception.MLException;
3738
import org.opensearch.search.SearchHit;
3839
import org.opensearch.search.pipeline.AbstractProcessor;
40+
import org.opensearch.search.pipeline.PipelineProcessingContext;
3941
import org.opensearch.search.pipeline.Processor;
4042
import org.opensearch.search.pipeline.SearchResponseProcessor;
4143
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
4244
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
4345
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
46+
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
4447
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
4548
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
4649
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
@@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements
6568

6669
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;
6770

68-
// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
69-
7071
private final String llmModel;
7172
private final List<String> contextFields;
7273

@@ -106,8 +107,18 @@ protected GenerativeQAResponseProcessor(
106107
}
107108

108109
@Override
109-
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
110+
public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) {
111+
// Synchronous call is no longer supported because this execution can occur on a transport thread.
112+
throw new UnsupportedOperationException();
113+
}
110114

115+
@Override
116+
public void processResponseAsync(
117+
SearchRequest request,
118+
SearchResponse response,
119+
PipelineProcessingContext requestContext,
120+
ActionListener<SearchResponse> responseListener
121+
) {
111122
log.info("Entering processResponse.");
112123

113124
if (!this.featureFlagSupplier.getAsBoolean()) {
@@ -116,10 +127,12 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
116127

117128
GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
118129

119-
Integer timeout = params.getTimeout();
120-
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
121-
timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
130+
Integer t = params.getTimeout();
131+
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
132+
t = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
122133
}
134+
final int timeout = t;
135+
log.info("Timeout for this request: {} seconds.", timeout);
123136

124137
String llmQuestion = params.getLlmQuestion();
125138
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
@@ -128,14 +141,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
128141
}
129142
String conversationId = params.getConversationId();
130143

144+
if (conversationId != null && !Strings.hasText(conversationId)) {
145+
throw new IllegalArgumentException("Empty conversation_id is not allowed.");
146+
}
147+
// log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
131148
Instant start = Instant.now();
132149
Integer interactionSize = params.getInteractionSize();
133150
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
134151
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
135152
}
136-
List<Interaction> chatHistory = (conversationId == null)
137-
? Collections.emptyList()
138-
: memoryClient.getInteractions(conversationId, interactionSize);
153+
log.info("Using interaction size of {}", interactionSize);
139154

140155
Integer topN = params.getContextSize();
141156
if (topN == null) {
@@ -153,10 +168,35 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
153168
effectiveUserInstructions = params.getUserInstructions();
154169
}
155170

156-
start = Instant.now();
157-
try {
158-
ChatCompletionOutput output = llm
159-
.doChatCompletion(
171+
// log.info("system_prompt: {}", systemPrompt);
172+
// log.info("user_instructions: {}", userInstructions);
173+
174+
final List<Interaction> chatHistory = new ArrayList<>();
175+
if (conversationId == null) {
176+
doChatCompletion(
177+
LlmIOUtil
178+
.createChatCompletionInput(
179+
systemPrompt,
180+
userInstructions,
181+
llmModel,
182+
llmQuestion,
183+
chatHistory,
184+
searchResults,
185+
timeout,
186+
params.getLlmResponseField()
187+
),
188+
null,
189+
llmQuestion,
190+
searchResults,
191+
response,
192+
responseListener
193+
);
194+
} else {
195+
final Instant memoryStart = Instant.now();
196+
memoryClient.getInteractions(conversationId, interactionSize, ActionListener.wrap(r -> {
197+
log.info("getInteractions complete. ({})", getDuration(memoryStart));
198+
chatHistory.addAll(r);
199+
doChatCompletion(
160200
LlmIOUtil
161201
.createChatCompletionInput(
162202
systemPrompt,
@@ -167,53 +207,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
167207
searchResults,
168208
timeout,
169209
params.getLlmResponseField()
170-
)
210+
),
211+
conversationId,
212+
llmQuestion,
213+
searchResults,
214+
response,
215+
responseListener
171216
);
172-
log.info("doChatCompletion complete. ({})", getDuration(start));
217+
}, responseListener::onFailure));
218+
}
219+
}
173220

174-
String answer = null;
175-
String errorMessage = null;
176-
String interactionId = null;
177-
if (output.isErrorOccurred()) {
178-
errorMessage = output.getErrors().get(0);
179-
} else {
180-
answer = (String) output.getAnswers().get(0);
221+
private void doChatCompletion(
222+
ChatCompletionInput input,
223+
String conversationId,
224+
String llmQuestion,
225+
List<String> searchResults,
226+
SearchResponse response,
227+
ActionListener<SearchResponse> responseListener
228+
) {
229+
230+
final Instant chatStart = Instant.now();
231+
llm.doChatCompletion(input, new ActionListener<>() {
232+
@Override
233+
public void onResponse(ChatCompletionOutput output) {
234+
log.info("doChatCompletion complete. ({})", getDuration(chatStart));
235+
236+
final String answer = getAnswer(output);
237+
final String errorMessage = getError(output);
181238

182239
if (conversationId != null) {
183-
start = Instant.now();
184-
interactionId = memoryClient
240+
final Instant memoryStart = Instant.now();
241+
memoryClient
185242
.createInteraction(
186243
conversationId,
187244
llmQuestion,
188245
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
189246
answer,
190247
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
191-
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
248+
Collections.singletonMap("metadata", jsonArrayToString(searchResults)),
249+
ActionListener.wrap(r -> {
250+
responseListener.onResponse(insertAnswer(response, answer, errorMessage, r));
251+
log.info("Created a new interaction: {} ({})", r, getDuration(memoryStart));
252+
}, responseListener::onFailure)
192253
);
193-
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
254+
255+
} else {
256+
responseListener.onResponse(insertAnswer(response, answer, errorMessage, null));
194257
}
258+
195259
}
196260

197-
return insertAnswer(response, answer, errorMessage, interactionId);
198-
} catch (NullPointerException nullPointerException) {
199-
throw new IllegalArgumentException(IllegalArgumentMessage);
200-
} catch (Exception e) {
201-
throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response");
202-
}
203-
}
261+
@Override
262+
public void onFailure(Exception e) {
263+
responseListener.onFailure(e);
264+
}
204265

205-
long getDuration(Instant start) {
206-
return Duration.between(start, Instant.now()).toMillis();
266+
private String getError(ChatCompletionOutput output) {
267+
return output.isErrorOccurred() ? output.getErrors().get(0) : null;
268+
}
269+
270+
private String getAnswer(ChatCompletionOutput output) {
271+
return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0);
272+
}
273+
});
207274
}
208275

209276
@Override
210277
public String getType() {
211278
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
212279
}
213280

214-
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
281+
private long getDuration(Instant start) {
282+
return Duration.between(start, Instant.now()).toMillis();
283+
}
215284

216-
// TODO return the interaction id in the response.
285+
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
217286

218287
return new GenerativeSearchResponse(
219288
answer,
@@ -240,9 +309,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) {
240309
for (String contextField : contextFields) {
241310
Object context = docSourceMap.get(contextField);
242311
if (context == null) {
243-
log.error("Context " + contextField + " not found in search hit " + hits[i]);
244-
// TODO throw a more meaningful error here?
245-
throw new RuntimeException();
312+
throw new RuntimeException("Context " + contextField + " not found in search hit " + hits[i]);
246313
}
247314
searchResults.add(context.toString());
248315
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java

+47
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.logging.log4j.LogManager;
2525
import org.apache.logging.log4j.Logger;
2626
import org.opensearch.client.Client;
27+
import org.opensearch.core.action.ActionListener;
2728
import org.opensearch.core.common.util.CollectionUtils;
2829
import org.opensearch.ml.common.conversation.Interaction;
2930
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
@@ -83,6 +84,33 @@ public String createInteraction(
8384
return res.getId();
8485
}
8586

87+
public void createInteraction(
88+
String conversationId,
89+
String input,
90+
String promptTemplate,
91+
String response,
92+
String origin,
93+
Map<String, String> additionalInfo,
94+
ActionListener<String> listener
95+
) {
96+
client
97+
.execute(
98+
CreateInteractionAction.INSTANCE,
99+
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo),
100+
new ActionListener<CreateInteractionResponse>() {
101+
@Override
102+
public void onResponse(CreateInteractionResponse createInteractionResponse) {
103+
listener.onResponse(createInteractionResponse.getId());
104+
}
105+
106+
@Override
107+
public void onFailure(Exception e) {
108+
listener.onFailure(e);
109+
}
110+
}
111+
);
112+
}
113+
86114
public List<Interaction> getInteractions(String conversationId, int lastN) {
87115

88116
Preconditions.checkArgument(lastN > 0, "lastN must be at least 1.");
@@ -113,4 +141,23 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {
113141

114142
return interactions;
115143
}
144+
145+
public void getInteractions(String conversationId, int lastN, ActionListener<List<Interaction>> listener) {
146+
client
147+
.execute(
148+
GetInteractionsAction.INSTANCE,
149+
new GetInteractionsRequest(conversationId, lastN, 0),
150+
new ActionListener<GetInteractionsResponse>() {
151+
@Override
152+
public void onResponse(GetInteractionsResponse getInteractionsResponse) {
153+
listener.onResponse(getInteractionsResponse.getInteractions());
154+
}
155+
156+
@Override
157+
public void onFailure(Exception e) {
158+
listener.onFailure(e);
159+
}
160+
}
161+
);
162+
}
116163
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
4242
}
4343

4444
@VisibleForTesting
45-
void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
45+
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
4646
validateMLInput(mlInput, true);
4747

4848
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

+27-11
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import java.util.Map;
2626

2727
import org.opensearch.client.Client;
28-
import org.opensearch.common.action.ActionFuture;
28+
import org.opensearch.core.action.ActionListener;
2929
import org.opensearch.ml.common.FunctionName;
3030
import org.opensearch.ml.common.dataset.MLInputDataset;
3131
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -75,20 +75,36 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
7575
* @return
7676
*/
7777
@Override
78-
public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {
7978

79+
public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
8080
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
8181
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
82-
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
83-
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
84-
85-
// Response from a remote model
86-
Map<String, ?> dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
87-
// log.info("dataAsMap: {}", dataAsMap.toString());
88-
89-
// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.
82+
mlClient.predict(this.openSearchModelId, mlInput, new ActionListener<>() {
83+
@Override
84+
public void onResponse(MLOutput mlOutput) {
85+
// Response from a remote model
86+
Map<String, ?> dataAsMap = ((ModelTensorOutput) mlOutput)
87+
.getMlModelOutputs()
88+
.get(0)
89+
.getMlModelTensors()
90+
.get(0)
91+
.getDataAsMap();
92+
// log.info("dataAsMap: {}", dataAsMap.toString());
93+
listener
94+
.onResponse(
95+
buildChatCompletionOutput(
96+
chatCompletionInput.getModelProvider(),
97+
dataAsMap,
98+
chatCompletionInput.getLlmResponseField()
99+
)
100+
);
101+
}
90102

91-
return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap, chatCompletionInput.getLlmResponseField());
103+
@Override
104+
public void onFailure(Exception e) {
105+
listener.onFailure(e);
106+
}
107+
});
92108
}
93109

94110
protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
*/
1818
package org.opensearch.searchpipelines.questionanswering.generative.llm;
1919

20+
import org.opensearch.core.action.ActionListener;
21+
2022
/**
2123
* Capabilities of large language models, e.g. completion, embeddings, etc.
2224
*/
@@ -29,5 +31,5 @@ enum ModelProvider {
2931
COHERE
3032
}
3133

32-
ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
34+
void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
3335
}

Comments (0)

0 commit comments