Commit 83acad3

Migrate RAG pipeline to async processing. (opensearch-project#2345) (opensearch-project#2350)
1 parent 54a01a1 commit 83acad3

7 files changed: +542 -191 lines

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java (+105 -42)

@@ -27,20 +27,23 @@
 import java.util.Map;
 import java.util.function.BooleanSupplier;
 
-import org.opensearch.OpenSearchException;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.client.Client;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.Strings;
 import org.opensearch.ingest.ConfigurationUtils;
 import org.opensearch.ml.common.conversation.Interaction;
 import org.opensearch.ml.common.exception.MLException;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.pipeline.AbstractProcessor;
+import org.opensearch.search.pipeline.PipelineProcessingContext;
 import org.opensearch.search.pipeline.Processor;
 import org.opensearch.search.pipeline.SearchResponseProcessor;
 import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
 import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
 import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
@@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements
 
     private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;
 
-    // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
-
     private final String llmModel;
     private final List<String> contextFields;
 
@@ -106,20 +107,32 @@ protected GenerativeQAResponseProcessor(
     }
 
     @Override
-    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
+    public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) {
+        // Synchronous call is no longer supported because this execution can occur on a transport thread.
+        throw new UnsupportedOperationException();
+    }
 
-        log.info("Entering processResponse.");
+    @Override
+    public void processResponseAsync(
+        SearchRequest request,
+        SearchResponse response,
+        PipelineProcessingContext requestContext,
+        ActionListener<SearchResponse> responseListener
+    ) {
+        log.debug("Entering processResponse.");
 
         if (!this.featureFlagSupplier.getAsBoolean()) {
            throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
        }
 
        GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
 
-        Integer timeout = params.getTimeout();
-        if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
-            timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
+        Integer t = params.getTimeout();
+        if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
+            t = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
        }
+        final int timeout = t;
+        log.debug("Timeout for this request: {} seconds.", timeout);
 
        String llmQuestion = params.getLlmQuestion();
        String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
@@ -128,14 +141,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
        }
        String conversationId = params.getConversationId();
 
+        if (conversationId != null && !Strings.hasText(conversationId)) {
+            throw new IllegalArgumentException("Empty conversation_id is not allowed.");
+        }
        Instant start = Instant.now();
        Integer interactionSize = params.getInteractionSize();
        if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
            interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
        }
-        List<Interaction> chatHistory = (conversationId == null)
-            ? Collections.emptyList()
-            : memoryClient.getInteractions(conversationId, interactionSize);
+        log.debug("Using interaction size of {}", interactionSize);
 
        Integer topN = params.getContextSize();
        if (topN == null) {
@@ -153,10 +167,32 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
            effectiveUserInstructions = params.getUserInstructions();
        }
 
-        start = Instant.now();
-        try {
-            ChatCompletionOutput output = llm
-                .doChatCompletion(
+        final List<Interaction> chatHistory = new ArrayList<>();
+        if (conversationId == null) {
+            doChatCompletion(
+                LlmIOUtil
+                    .createChatCompletionInput(
+                        systemPrompt,
+                        userInstructions,
+                        llmModel,
+                        llmQuestion,
+                        chatHistory,
+                        searchResults,
+                        timeout,
+                        params.getLlmResponseField()
+                    ),
+                null,
+                llmQuestion,
+                searchResults,
+                response,
+                responseListener
+            );
+        } else {
+            final Instant memoryStart = Instant.now();
+            memoryClient.getInteractions(conversationId, interactionSize, ActionListener.wrap(r -> {
+                log.debug("getInteractions complete. ({})", getDuration(memoryStart));
+                chatHistory.addAll(r);
+                doChatCompletion(
                    LlmIOUtil
                        .createChatCompletionInput(
                            systemPrompt,
@@ -167,53 +203,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
                            searchResults,
                            timeout,
                            params.getLlmResponseField()
-                )
+                        ),
+                    conversationId,
+                    llmQuestion,
+                    searchResults,
+                    response,
+                    responseListener
                );
-            log.info("doChatCompletion complete. ({})", getDuration(start));
+            }, responseListener::onFailure));
+        }
+    }
 
-            String answer = null;
-            String errorMessage = null;
-            String interactionId = null;
-            if (output.isErrorOccurred()) {
-                errorMessage = output.getErrors().get(0);
-            } else {
-                answer = (String) output.getAnswers().get(0);
+    private void doChatCompletion(
+        ChatCompletionInput input,
+        String conversationId,
+        String llmQuestion,
+        List<String> searchResults,
+        SearchResponse response,
+        ActionListener<SearchResponse> responseListener
+    ) {
+
+        final Instant chatStart = Instant.now();
+        llm.doChatCompletion(input, new ActionListener<>() {
+            @Override
+            public void onResponse(ChatCompletionOutput output) {
+                log.debug("doChatCompletion complete. ({})", getDuration(chatStart));
+
+                final String answer = getAnswer(output);
+                final String errorMessage = getError(output);
 
                if (conversationId != null) {
-                    start = Instant.now();
-                    interactionId = memoryClient
+                    final Instant memoryStart = Instant.now();
+                    memoryClient
                        .createInteraction(
                            conversationId,
                            llmQuestion,
                            PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
                            answer,
                            GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
-                            Collections.singletonMap("metadata", jsonArrayToString(searchResults))
+                            Collections.singletonMap("metadata", jsonArrayToString(searchResults)),
+                            ActionListener.wrap(r -> {
+                                responseListener.onResponse(insertAnswer(response, answer, errorMessage, r));
+                                log.info("Created a new interaction: {} ({})", r, getDuration(memoryStart));
+                            }, responseListener::onFailure)
                        );
-                    log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
+
+                } else {
+                    responseListener.onResponse(insertAnswer(response, answer, errorMessage, null));
                }
+
            }
 
-            return insertAnswer(response, answer, errorMessage, interactionId);
-        } catch (NullPointerException nullPointerException) {
-            throw new IllegalArgumentException(IllegalArgumentMessage);
-        } catch (Exception e) {
-            throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response");
-        }
-    }
+            @Override
+            public void onFailure(Exception e) {
+                responseListener.onFailure(e);
+            }
 
-    long getDuration(Instant start) {
-        return Duration.between(start, Instant.now()).toMillis();
+            private String getError(ChatCompletionOutput output) {
+                return output.isErrorOccurred() ? output.getErrors().get(0) : null;
+            }
+
+            private String getAnswer(ChatCompletionOutput output) {
+                return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0);
+            }
        });
    }
 
    @Override
    public String getType() {
        return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
    }
 
-    private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
+    private long getDuration(Instant start) {
+        return Duration.between(start, Instant.now()).toMillis();
+    }
 
-        // TODO return the interaction id in the response.
+    private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
 
        return new GenerativeSearchResponse(
            answer,
@@ -240,9 +305,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) {
            for (String contextField : contextFields) {
                Object context = docSourceMap.get(contextField);
                if (context == null) {
-                    log.error("Context " + contextField + " not found in search hit " + hits[i]);
-                    // TODO throw a more meaningful error here?
-                    throw new RuntimeException();
+                    throw new RuntimeException("Context " + contextField + " not found in search hit " + hits[i]);
                }
                searchResults.add(context.toString());
            }
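
With this change the processor implements the asynchronous response-processor contract: the blocking processResponse entry point now throws UnsupportedOperationException, and all work happens in processResponseAsync, which reports success or failure through the supplied ActionListener. Below is a minimal sketch of how a caller (for example a unit test) might drive the new entry point; the RagProcessorCallExample class, the runRagProcessor helper, and the printed messages are illustrative assumptions, not part of this commit.

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.pipeline.PipelineProcessingContext;

final class RagProcessorCallExample {

    // Hypothetical helper: the processor, request, response, and context arguments
    // are assumed to be constructed by the caller (e.g. a test harness).
    static void runRagProcessor(
        GenerativeQAResponseProcessor processor,
        SearchRequest request,
        SearchResponse response,
        PipelineProcessingContext context
    ) {
        processor.processResponseAsync(request, response, context, ActionListener.wrap(augmented -> {
            // Success path: the listener receives the RAG-augmented SearchResponse.
            System.out.println("Generated answer attached to response: " + augmented);
        }, e -> {
            // Failure path: exceptions now flow through the listener instead of being thrown.
            System.err.println("RAG processing failed: " + e.getMessage());
        }));
    }
}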

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java (+47)

@@ -24,6 +24,7 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.opensearch.client.Client;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.ml.common.conversation.Interaction;
 import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
@@ -83,6 +84,33 @@ public String createInteraction(
        return res.getId();
    }
 
+    public void createInteraction(
+        String conversationId,
+        String input,
+        String promptTemplate,
+        String response,
+        String origin,
+        Map<String, String> additionalInfo,
+        ActionListener<String> listener
+    ) {
+        client
+            .execute(
+                CreateInteractionAction.INSTANCE,
+                new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo),
+                new ActionListener<CreateInteractionResponse>() {
+                    @Override
+                    public void onResponse(CreateInteractionResponse createInteractionResponse) {
+                        listener.onResponse(createInteractionResponse.getId());
+                    }
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        listener.onFailure(e);
+                    }
+                }
+            );
+    }
+
    public List<Interaction> getInteractions(String conversationId, int lastN) {
 
        Preconditions.checkArgument(lastN > 0, "lastN must be at least 1.");
@@ -113,4 +141,23 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {
 
        return interactions;
    }
+
+    public void getInteractions(String conversationId, int lastN, ActionListener<List<Interaction>> listener) {
+        client
+            .execute(
+                GetInteractionsAction.INSTANCE,
+                new GetInteractionsRequest(conversationId, lastN, 0),
+                new ActionListener<GetInteractionsResponse>() {
+                    @Override
+                    public void onResponse(GetInteractionsResponse getInteractionsResponse) {
+                        listener.onResponse(getInteractionsResponse.getInteractions());
+                    }
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        listener.onFailure(e);
+                    }
+                }
+            );
+    }
 }
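
Both memory operations now have listener-based overloads alongside the original blocking methods, so callers can fetch chat history and record a new interaction without tying up the calling thread. A minimal sketch under stated assumptions: the memoryClient, conversationId, question, and answer arguments, the history size of 10, and the "prompt-template" and "generative_qa" strings are all placeholders for illustration, not values from this commit.

import java.util.Collections;

import org.opensearch.core.action.ActionListener;

final class MemoryClientUsageExample {

    static void recordTurn(ConversationalMemoryClient memoryClient, String conversationId, String question, String answer) {
        // Fetch the most recent interactions asynchronously.
        memoryClient.getInteractions(conversationId, 10, ActionListener.wrap(history -> {
            System.out.println("Fetched " + history.size() + " prior interactions.");
            // Then persist the new question/answer pair, again through a listener.
            memoryClient
                .createInteraction(
                    conversationId,
                    question,
                    "prompt-template",
                    answer,
                    "generative_qa",
                    Collections.emptyMap(),
                    ActionListener.wrap(
                        interactionId -> System.out.println("Created interaction " + interactionId),
                        e -> System.err.println("createInteraction failed: " + e.getMessage())
                    )
                );
        }, e -> System.err.println("getInteractions failed: " + e.getMessage())));
    }
}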

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java (+1 -1)

@@ -42,7 +42,7 @@ public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
    }
 
    @VisibleForTesting
-    void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
+    public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
 
        MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java (+26 -11)

@@ -25,7 +25,7 @@
 import java.util.Map;
 
 import org.opensearch.client.Client;
-import org.opensearch.common.action.ActionFuture;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.MLInputDataset;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -75,20 +75,35 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
     * @return
     */
    @Override
-    public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {
 
+    public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
        MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
-        ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
-        ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
-
-        // Response from a remote model
-        Map<String, ?> dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
-        // log.info("dataAsMap: {}", dataAsMap.toString());
-
-        // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.
+        mlClient.predict(this.openSearchModelId, mlInput, new ActionListener<>() {
+            @Override
+            public void onResponse(MLOutput mlOutput) {
+                // Response from a remote model
+                Map<String, ?> dataAsMap = ((ModelTensorOutput) mlOutput)
+                    .getMlModelOutputs()
+                    .get(0)
+                    .getMlModelTensors()
+                    .get(0)
+                    .getDataAsMap();
+                listener
+                    .onResponse(
+                        buildChatCompletionOutput(
+                            chatCompletionInput.getModelProvider(),
+                            dataAsMap,
+                            chatCompletionInput.getLlmResponseField()
+                        )
+                    );
+            }
 
-        return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap, chatCompletionInput.getLlmResponseField());
+            @Override
+            public void onFailure(Exception e) {
+                listener.onFailure(e);
+            }
+        });
    }
 
    protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
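
DefaultLlmImpl no longer blocks on an ActionFuture; the remote prediction result is converted to a ChatCompletionOutput and handed to the caller's listener. A minimal sketch of consuming the asynchronous doChatCompletion, assuming the llm and input arguments are built elsewhere (for example via LlmIOUtil.createChatCompletionInput, as in the processor diff above); the LlmCallExample class and printed messages are illustrative only.

import org.opensearch.core.action.ActionListener;

final class LlmCallExample {

    static void complete(Llm llm, ChatCompletionInput input) {
        llm.doChatCompletion(input, ActionListener.wrap(output -> {
            // ChatCompletionOutput still carries either answers or errors, as before.
            if (output.isErrorOccurred()) {
                System.err.println("LLM returned an error: " + output.getErrors().get(0));
            } else {
                System.out.println("Answer: " + output.getAnswers().get(0));
            }
        }, e -> System.err.println("Remote inference call failed: " + e.getMessage())));
    }
}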

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java (+3 -1)

@@ -17,6 +17,8 @@
 */
 package org.opensearch.searchpipelines.questionanswering.generative.llm;
 
+import org.opensearch.core.action.ActionListener;
+
 /**
 * Capabilities of large language models, e.g. completion, embeddings, etc.
 */
@@ -29,5 +31,5 @@ enum ModelProvider {
        COHERE
    }
 
-    ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
+    void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
 }
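
Because the interface method now returns void and takes a listener, every Llm implementation must deliver its result by completing the listener exactly once, via onResponse or onFailure. A minimal sketch of a stub against the new signature, assuming Llm is an interface (as the DefaultLlmImpl naming suggests); the CannedLlmStub name and canned-output behavior are assumptions for illustration, e.g. for unit tests.

import org.opensearch.core.action.ActionListener;

final class CannedLlmStub implements Llm {

    private final ChatCompletionOutput canned;

    CannedLlmStub(ChatCompletionOutput canned) {
        this.canned = canned;
    }

    @Override
    public void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener) {
        // Deliver the pre-built output through the listener instead of returning it.
        listener.onResponse(canned);
    }
}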
