|
24 | 24 | import static org.mockito.Mockito.mock;
|
25 | 25 | import static org.mockito.Mockito.verify;
|
26 | 26 | import static org.mockito.Mockito.when;
|
| 27 | +import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG; |
27 | 28 |
|
28 | 29 | import java.time.Instant;
|
29 | 30 | import java.util.Collections;
|
@@ -646,6 +647,77 @@ public void testProcessResponseNullValueInteractions() throws Exception {
|
646 | 647 | }));
|
647 | 648 | }
|
648 | 649 |
|
| 650 | + public void testProcessResponseIllegalArgumentForNullParams() throws Exception { |
| 651 | + exceptionRule.expect(IllegalArgumentException.class); |
| 652 | + exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG); |
| 653 | + |
| 654 | + Client client = mock(Client.class); |
| 655 | + Map<String, Object> config = new HashMap<>(); |
| 656 | + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); |
| 657 | + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); |
| 658 | + |
| 659 | + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( |
| 660 | + client, |
| 661 | + alwaysOn |
| 662 | + ).create(null, "tag", "desc", true, config, null); |
| 663 | + |
| 664 | + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); |
| 665 | + List<Interaction> chatHistory = List |
| 666 | + .of( |
| 667 | + new Interaction( |
| 668 | + "0", |
| 669 | + Instant.now(), |
| 670 | + "1", |
| 671 | + "question", |
| 672 | + "", |
| 673 | + "answer", |
| 674 | + "foo", |
| 675 | + Collections.singletonMap("meta data", "some meta") |
| 676 | + ) |
| 677 | + ); |
| 678 | + doAnswer(invocation -> { |
| 679 | + ((ActionListener<List<Interaction>>) invocation.getArguments()[2]).onResponse(chatHistory); |
| 680 | + return null; |
| 681 | + }).when(memoryClient).getInteractions(any(), anyInt(), any()); |
| 682 | + processor.setMemoryClient(memoryClient); |
| 683 | + |
| 684 | + SearchRequest request = new SearchRequest(); |
| 685 | + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); |
| 686 | + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); |
| 687 | + extBuilder.setParams(null); |
| 688 | + request.source(sourceBuilder); |
| 689 | + sourceBuilder.ext(List.of(extBuilder)); |
| 690 | + |
| 691 | + int numHits = 10; |
| 692 | + SearchHit[] hitsArray = new SearchHit[numHits]; |
| 693 | + for (int i = 0; i < numHits; i++) { |
| 694 | + XContentBuilder sourceContent = JsonXContent |
| 695 | + .contentBuilder() |
| 696 | + .startObject() |
| 697 | + .field("_id", String.valueOf(i)) |
| 698 | + .field("text", "passage" + i) |
| 699 | + .field("title", "This is the title for document " + i) |
| 700 | + .endObject(); |
| 701 | + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); |
| 702 | + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); |
| 703 | + } |
| 704 | + |
| 705 | + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); |
| 706 | + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); |
| 707 | + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); |
| 708 | + |
| 709 | + Llm llm = mock(Llm.class); |
| 710 | + processor.setLlm(llm); |
| 711 | + |
| 712 | + processor |
| 713 | + .processResponseAsync( |
| 714 | + request, |
| 715 | + response, |
| 716 | + null, |
| 717 | + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) |
| 718 | + ); |
| 719 | + } |
| 720 | + |
649 | 721 | public void testProcessResponseIllegalArgument() throws Exception {
|
650 | 722 | exceptionRule.expect(IllegalArgumentException.class);
|
651 | 723 | exceptionRule.expectMessage("llm_model cannot be null.");
|
|
0 commit comments