Skip to content

Commit 5427338

Browse files
committed
handle null value exceptions when arguments are missing or Null in caling RAG pipeline
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 456e92d commit 5427338

File tree

2 files changed

+104
-29
lines changed

2 files changed

+104
-29
lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+36-29
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Map;
2828
import java.util.function.BooleanSupplier;
2929

30+
import org.opensearch.OpenSearchException;
3031
import org.opensearch.action.search.SearchRequest;
3132
import org.opensearch.action.search.SearchResponse;
3233
import org.opensearch.client.Client;
@@ -58,7 +59,7 @@
5859
*/
5960
@Log4j2
6061
public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {
61-
62+
public static String IllegalArgumentMessage = "Please check the provided generative_qa_parameters are complete and non-null(https://opensearch.org/docs/latest/search-plugins/conversational-search/#rag-pipeline). Messages in the memory can not have Null value for input and response";
6263
private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;
6364

6465
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;
@@ -148,37 +149,43 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
148149
log.info("system_prompt: {}", systemPrompt);
149150
log.info("user_instructions: {}", userInstructions);
150151
start = Instant.now();
151-
ChatCompletionOutput output = llm
152-
.doChatCompletion(
153-
LlmIOUtil
154-
.createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
155-
);
156-
log.info("doChatCompletion complete. ({})", getDuration(start));
157-
158-
String answer = null;
159-
String errorMessage = null;
160-
String interactionId = null;
161-
if (output.isErrorOccurred()) {
162-
errorMessage = output.getErrors().get(0);
163-
} else {
164-
answer = (String) output.getAnswers().get(0);
165-
166-
if (conversationId != null) {
167-
start = Instant.now();
168-
interactionId = memoryClient
169-
.createInteraction(
170-
conversationId,
171-
llmQuestion,
172-
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
173-
answer,
174-
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
175-
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
152+
try {
153+
ChatCompletionOutput output = llm
154+
.doChatCompletion(
155+
LlmIOUtil
156+
.createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
176157
);
177-
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
158+
log.info("doChatCompletion complete. ({})", getDuration(start));
159+
160+
String answer = null;
161+
String errorMessage = null;
162+
String interactionId = null;
163+
if (output.isErrorOccurred()) {
164+
errorMessage = output.getErrors().get(0);
165+
} else {
166+
answer = (String) output.getAnswers().get(0);
167+
168+
if (conversationId != null) {
169+
start = Instant.now();
170+
interactionId = memoryClient
171+
.createInteraction(
172+
conversationId,
173+
llmQuestion,
174+
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
175+
answer,
176+
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
177+
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
178+
);
179+
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
180+
}
178181
}
179-
}
180182

181-
return insertAnswer(response, answer, errorMessage, interactionId);
183+
return insertAnswer(response, answer, errorMessage, interactionId);
184+
} catch (NullPointerException nullPointerException) {
185+
throw new IllegalArgumentException(IllegalArgumentMessage);
186+
} catch (Exception e) {
187+
throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response");
188+
}
182189
}
183190

184191
long getDuration(Instant start) {

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java

+68
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.mockito.Mockito.mock;
2323
import static org.mockito.Mockito.verify;
2424
import static org.mockito.Mockito.when;
25+
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor.IllegalArgumentMessage;
2526

2627
import java.time.Instant;
2728
import java.util.Collections;
@@ -407,4 +408,71 @@ public void testProcessorFeatureOffOnOff() throws Exception {
407408
}
408409
assertTrue(secondExceptionThrown);
409410
}
411+
412+
public void testProcessResponseNullValueInteractions() throws Exception {
413+
exceptionRule.expect(IllegalArgumentException.class);
414+
exceptionRule.expectMessage(IllegalArgumentMessage);
415+
416+
Client client = mock(Client.class);
417+
Map<String, Object> config = new HashMap<>();
418+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
419+
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));
420+
421+
GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
422+
client,
423+
alwaysOn
424+
).create(null, "tag", "desc", true, config, null);
425+
426+
ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
427+
when(memoryClient.getInteractions(any(), anyInt()))
428+
.thenReturn(
429+
List
430+
.of(
431+
new Interaction(
432+
"0",
433+
Instant.now(),
434+
"1",
435+
null,
436+
null,
437+
null,
438+
null,
439+
null
440+
)
441+
)
442+
);
443+
processor.setMemoryClient(memoryClient);
444+
445+
SearchRequest request = new SearchRequest();
446+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
447+
int contextSize = 5;
448+
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null);
449+
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
450+
extBuilder.setParams(params);
451+
request.source(sourceBuilder);
452+
sourceBuilder.ext(List.of(extBuilder));
453+
454+
int numHits = 10;
455+
SearchHit[] hitsArray = new SearchHit[numHits];
456+
for (int i = 0; i < numHits; i++) {
457+
XContentBuilder sourceContent = JsonXContent
458+
.contentBuilder()
459+
.startObject()
460+
.field("_id", String.valueOf(i))
461+
.field("text", "passage" + i)
462+
.field("title", "This is the title for document " + i)
463+
.endObject();
464+
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
465+
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
466+
}
467+
468+
SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
469+
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
470+
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);
471+
472+
Llm llm = mock(Llm.class);
473+
when(llm.doChatCompletion(any())).thenThrow(new NullPointerException("Null Pointer in Interactions"));
474+
processor.setLlm(llm);
475+
476+
SearchResponse res = processor.processResponse(request, response);
477+
}
410478
}

0 commit comments

Comments
 (0)