Skip to content

Commit ffc758b

Browse files
committed
handle null value exceptions when arguments are missing or Null in caling RAG pipeline
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 456e92d commit ffc758b

File tree

2 files changed

+100
-30
lines changed

2 files changed

+100
-30
lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+46-30
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Map;
2828
import java.util.function.BooleanSupplier;
2929

30+
import org.opensearch.OpenSearchException;
3031
import org.opensearch.action.search.SearchRequest;
3132
import org.opensearch.action.search.SearchResponse;
3233
import org.opensearch.client.Client;
@@ -58,7 +59,8 @@
5859
*/
5960
@Log4j2
6061
public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {
61-
62+
public static String IllegalArgumentMessage =
63+
"Please check the provided generative_qa_parameters are complete and non-null(https://opensearch.org/docs/latest/search-plugins/conversational-search/#rag-pipeline). Messages in the memory can not have Null value for input and response";
6264
private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;
6365

6466
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;
@@ -148,37 +150,51 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
148150
log.info("system_prompt: {}", systemPrompt);
149151
log.info("user_instructions: {}", userInstructions);
150152
start = Instant.now();
151-
ChatCompletionOutput output = llm
152-
.doChatCompletion(
153-
LlmIOUtil
154-
.createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
155-
);
156-
log.info("doChatCompletion complete. ({})", getDuration(start));
157-
158-
String answer = null;
159-
String errorMessage = null;
160-
String interactionId = null;
161-
if (output.isErrorOccurred()) {
162-
errorMessage = output.getErrors().get(0);
163-
} else {
164-
answer = (String) output.getAnswers().get(0);
165-
166-
if (conversationId != null) {
167-
start = Instant.now();
168-
interactionId = memoryClient
169-
.createInteraction(
170-
conversationId,
171-
llmQuestion,
172-
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
173-
answer,
174-
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
175-
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
176-
);
177-
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
153+
try {
154+
ChatCompletionOutput output = llm
155+
.doChatCompletion(
156+
LlmIOUtil
157+
.createChatCompletionInput(
158+
systemPrompt,
159+
userInstructions,
160+
llmModel,
161+
llmQuestion,
162+
chatHistory,
163+
searchResults,
164+
timeout
165+
)
166+
);
167+
log.info("doChatCompletion complete. ({})", getDuration(start));
168+
169+
String answer = null;
170+
String errorMessage = null;
171+
String interactionId = null;
172+
if (output.isErrorOccurred()) {
173+
errorMessage = output.getErrors().get(0);
174+
} else {
175+
answer = (String) output.getAnswers().get(0);
176+
177+
if (conversationId != null) {
178+
start = Instant.now();
179+
interactionId = memoryClient
180+
.createInteraction(
181+
conversationId,
182+
llmQuestion,
183+
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
184+
answer,
185+
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
186+
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
187+
);
188+
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
189+
}
178190
}
179-
}
180191

181-
return insertAnswer(response, answer, errorMessage, interactionId);
192+
return insertAnswer(response, answer, errorMessage, interactionId);
193+
} catch (NullPointerException nullPointerException) {
194+
throw new IllegalArgumentException(IllegalArgumentMessage);
195+
} catch (Exception e) {
196+
throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response");
197+
}
182198
}
183199

184200
long getDuration(Instant start) {

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java

+54
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.mockito.Mockito.mock;
2323
import static org.mockito.Mockito.verify;
2424
import static org.mockito.Mockito.when;
25+
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor.IllegalArgumentMessage;
2526

2627
import java.time.Instant;
2728
import java.util.Collections;
@@ -407,4 +408,57 @@ public void testProcessorFeatureOffOnOff() throws Exception {
407408
}
408409
assertTrue(secondExceptionThrown);
409410
}
411+
412+
public void testProcessResponseNullValueInteractions() throws Exception {
413+
exceptionRule.expect(IllegalArgumentException.class);
414+
exceptionRule.expectMessage(IllegalArgumentMessage);
415+
416+
Client client = mock(Client.class);
417+
Map<String, Object> config = new HashMap<>();
418+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
419+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));
420+
421+
GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
422+
client,
423+
alwaysOn
424+
).create(null, "tag", "desc", true, config, null);
425+
426+
ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
427+
when(memoryClient.getInteractions(any(), anyInt()))
428+
.thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null)));
429+
processor.setMemoryClient(memoryClient);
430+
431+
SearchRequest request = new SearchRequest();
432+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
433+
int contextSize = 5;
434+
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null);
435+
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
436+
extBuilder.setParams(params);
437+
request.source(sourceBuilder);
438+
sourceBuilder.ext(List.of(extBuilder));
439+
440+
int numHits = 10;
441+
SearchHit[] hitsArray = new SearchHit[numHits];
442+
for (int i = 0; i < numHits; i++) {
443+
XContentBuilder sourceContent = JsonXContent
444+
.contentBuilder()
445+
.startObject()
446+
.field("_id", String.valueOf(i))
447+
.field("text", "passage" + i)
448+
.field("title", "This is the title for document " + i)
449+
.endObject();
450+
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
451+
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
452+
}
453+
454+
SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
455+
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
456+
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);
457+
458+
Llm llm = mock(Llm.class);
459+
when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions"));
460+
processor.setLlm(llm);
461+
462+
SearchResponse res = processor.processResponse(request, response);
463+
}
410464
}

0 commit comments

Comments
 (0)