|
32 | 32 | import java.time.Instant;
|
33 | 33 | import java.util.Collections;
|
34 | 34 | import java.util.List;
|
| 35 | +import java.util.Map; |
35 | 36 |
|
36 | 37 | import org.junit.Before;
|
37 | 38 | import org.mockito.ArgumentCaptor;
|
38 | 39 | import org.mockito.Mock;
|
39 | 40 | import org.opensearch.OpenSearchWrapperException;
|
40 | 41 | import org.opensearch.ResourceAlreadyExistsException;
|
| 42 | +import org.opensearch.action.DocWriteResponse; |
41 | 43 | import org.opensearch.action.admin.indices.create.CreateIndexResponse;
|
42 | 44 | import org.opensearch.action.admin.indices.refresh.RefreshResponse;
|
43 | 45 | import org.opensearch.action.bulk.BulkResponse;
|
|
47 | 49 | import org.opensearch.action.search.SearchResponse;
|
48 | 50 | import org.opensearch.action.search.SearchResponseSections;
|
49 | 51 | import org.opensearch.action.search.ShardSearchFailure;
|
| 52 | +import org.opensearch.action.update.UpdateRequest; |
| 53 | +import org.opensearch.action.update.UpdateResponse; |
50 | 54 | import org.opensearch.client.AdminClient;
|
51 | 55 | import org.opensearch.client.Client;
|
52 | 56 | import org.opensearch.client.IndicesAdminClient;
|
|
59 | 63 | import org.opensearch.commons.ConfigConstants;
|
60 | 64 | import org.opensearch.core.action.ActionListener;
|
61 | 65 | import org.opensearch.core.common.bytes.BytesReference;
|
| 66 | +import org.opensearch.core.index.Index; |
| 67 | +import org.opensearch.core.index.shard.ShardId; |
62 | 68 | import org.opensearch.core.rest.RestStatus;
|
63 | 69 | import org.opensearch.core.xcontent.XContentBuilder;
|
64 | 70 | import org.opensearch.index.query.MatchAllQueryBuilder;
|
@@ -441,46 +447,34 @@ public void testGetTraces_NoIndex_ThenEmpty() {
|
441 | 447 | assert (argCaptor.getValue().size() == 0);
|
442 | 448 | }
|
443 | 449 |
|
444 |
| - public void testGetTraces() { |
445 |
| - doAnswer(invocation -> { |
446 |
| - XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); |
447 |
| - content.startObject(); |
448 |
| - content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); |
449 |
| - content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); |
450 |
| - content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); |
451 |
| - content.endObject(); |
| 450 | + public void testInnerGetTraces_success() { |
| 451 | + setUpSearchTraceResponse(); |
| 452 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 453 | + @SuppressWarnings("unchecked") |
| 454 | + ActionListener<List<Interaction>> getTracesListener = mock(ActionListener.class); |
| 455 | + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); |
| 456 | + @SuppressWarnings("unchecked") |
| 457 | + ArgumentCaptor<List<Interaction>> argCaptor = ArgumentCaptor.forClass(List.class); |
| 458 | + verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); |
| 459 | + assert (argCaptor.getValue().size() == 1); |
| 460 | + } |
452 | 461 |
|
453 |
| - SearchHit[] hits = new SearchHit[1]; |
454 |
| - hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); |
455 |
| - SearchHits searchHits = new SearchHits(hits, null, Float.NaN); |
456 |
| - SearchResponseSections searchSections = new SearchResponseSections( |
457 |
| - searchHits, |
458 |
| - InternalAggregations.EMPTY, |
459 |
| - null, |
460 |
| - false, |
461 |
| - false, |
462 |
| - null, |
463 |
| - 1 |
464 |
| - ); |
465 |
| - SearchResponse searchResponse = new SearchResponse( |
466 |
| - searchSections, |
467 |
| - null, |
468 |
| - 1, |
469 |
| - 1, |
470 |
| - 0, |
471 |
| - 11, |
472 |
| - ShardSearchFailure.EMPTY_ARRAY, |
473 |
| - SearchResponse.Clusters.EMPTY |
474 |
| - ); |
475 |
| - ActionListener<SearchResponse> al = invocation.getArgument(1); |
476 |
| - al.onResponse(searchResponse); |
477 |
| - return null; |
478 |
| - }).when(client).search(any(), any()); |
| 462 | + public void testGetTraces_success() { |
| 463 | + setupGrantAccess(); |
| 464 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 465 | + setupRefreshSuccess(); |
479 | 466 |
|
| 467 | + GetResponse response = setUpInteractionResponse("iid"); |
| 468 | + doAnswer(invocation -> { |
| 469 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 470 | + listener.onResponse(response); |
| 471 | + return null; |
| 472 | + }).when(client).get(any(), any()); |
| 473 | + setUpSearchTraceResponse(); |
480 | 474 | doReturn(true).when(metadata).hasIndex(anyString());
|
481 | 475 | @SuppressWarnings("unchecked")
|
482 | 476 | ActionListener<List<Interaction>> getTracesListener = mock(ActionListener.class);
|
483 |
| - interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); |
| 477 | + interactionsIndex.getTraces("iid", 0, 10, getTracesListener); |
484 | 478 | @SuppressWarnings("unchecked")
|
485 | 479 | ArgumentCaptor<List<Interaction>> argCaptor = ArgumentCaptor.forClass(List.class);
|
486 | 480 | verify(getTracesListener, times(1)).onResponse(argCaptor.capture());
|
@@ -800,4 +794,156 @@ public void testGetSg_ClientFails_ThenFail() {
|
800 | 794 | verify(getListener, times(1)).onFailure(argCaptor.capture());
|
801 | 795 | assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get"));
|
802 | 796 | }
|
| 797 | + |
| 798 | + public void testGetSg_NoAccess_ThenFail() { |
| 799 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 800 | + setupDenyAccess("Henry"); |
| 801 | + setupRefreshSuccess(); |
| 802 | + GetResponse response = setUpInteractionResponse("iid"); |
| 803 | + doAnswer(invocation -> { |
| 804 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 805 | + listener.onResponse(response); |
| 806 | + return null; |
| 807 | + }).when(client).get(any(), any()); |
| 808 | + ActionListener<Interaction> getListener = mock(ActionListener.class); |
| 809 | + interactionsIndex.getInteraction("iid", getListener); |
| 810 | + ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
| 811 | + verify(getListener, times(1)).onFailure(argCaptor.capture()); |
| 812 | + assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to interaction iid")); |
| 813 | + } |
| 814 | + |
| 815 | + public void testGetSg_GrantAccess_Success() { |
| 816 | + setupGrantAccess(); |
| 817 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 818 | + setupRefreshSuccess(); |
| 819 | + GetResponse response = setUpInteractionResponse("iid"); |
| 820 | + doAnswer(invocation -> { |
| 821 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 822 | + listener.onResponse(response); |
| 823 | + return null; |
| 824 | + }).when(client).get(any(), any()); |
| 825 | + ActionListener<Interaction> getListener = mock(ActionListener.class); |
| 826 | + interactionsIndex.getInteraction("iid", getListener); |
| 827 | + ArgumentCaptor<Interaction> argCaptor = ArgumentCaptor.forClass(Interaction.class); |
| 828 | + verify(getListener, times(1)).onResponse(argCaptor.capture()); |
| 829 | + assert (argCaptor.getValue().getId().equals("iid")); |
| 830 | + assert (argCaptor.getValue().getConversationId().equals("conversation test 1")); |
| 831 | + } |
| 832 | + |
| 833 | + public void testGetTraces_NoAccess_ThenFail() { |
| 834 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 835 | + setupRefreshSuccess(); |
| 836 | + setupDenyAccess("Xun"); |
| 837 | + GetResponse response = setUpInteractionResponse("iid"); |
| 838 | + doAnswer(invocation -> { |
| 839 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 840 | + listener.onResponse(response); |
| 841 | + return null; |
| 842 | + }).when(client).get(any(), any()); |
| 843 | + |
| 844 | + ActionListener<List<Interaction>> getListener = mock(ActionListener.class); |
| 845 | + interactionsIndex.getTraces("iid", 0, 10, getListener); |
| 846 | + ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
| 847 | + verify(getListener, times(1)).onFailure(argCaptor.capture()); |
| 848 | + assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid")); |
| 849 | + } |
| 850 | + |
| 851 | + public void testUpdateInteraction_NoAccess_ThenFail() { |
| 852 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 853 | + setupRefreshSuccess(); |
| 854 | + setupDenyAccess("Xun"); |
| 855 | + GetResponse response = setUpInteractionResponse("iid"); |
| 856 | + doAnswer(invocation -> { |
| 857 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 858 | + listener.onResponse(response); |
| 859 | + return null; |
| 860 | + }).when(client).get(any(), any()); |
| 861 | + |
| 862 | + ActionListener<UpdateResponse> updateListener = mock(ActionListener.class); |
| 863 | + interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener); |
| 864 | + ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
| 865 | + verify(updateListener, times(1)).onFailure(argCaptor.capture()); |
| 866 | + assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid")); |
| 867 | + } |
| 868 | + |
| 869 | + public void testUpdateInteraction_Success() { |
| 870 | + doReturn(true).when(metadata).hasIndex(anyString()); |
| 871 | + setupRefreshSuccess(); |
| 872 | + setupGrantAccess(); |
| 873 | + GetResponse response = setUpInteractionResponse("iid"); |
| 874 | + doAnswer(invocation -> { |
| 875 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 876 | + listener.onResponse(response); |
| 877 | + return null; |
| 878 | + }).when(client).get(any(), any()); |
| 879 | + |
| 880 | + doAnswer(invocation -> { |
| 881 | + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); |
| 882 | + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); |
| 883 | + ActionListener<UpdateResponse> listener = invocation.getArgument(1); |
| 884 | + listener.onResponse(updateResponse); |
| 885 | + return null; |
| 886 | + }).when(client).update(any(), any()); |
| 887 | + |
| 888 | + ActionListener<UpdateResponse> updateListener = mock(ActionListener.class); |
| 889 | + interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener); |
| 890 | + ArgumentCaptor<UpdateResponse> argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); |
| 891 | + verify(updateListener, times(1)).onResponse(argCaptor.capture()); |
| 892 | + } |
| 893 | + |
| 894 | + private GetResponse setUpInteractionResponse(String interactionId) { |
| 895 | + @SuppressWarnings("unchecked") |
| 896 | + GetResponse response = mock(GetResponse.class); |
| 897 | + doReturn(true).when(response).isExists(); |
| 898 | + doReturn(interactionId).when(response).getId(); |
| 899 | + doReturn( |
| 900 | + Map |
| 901 | + .of( |
| 902 | + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, |
| 903 | + Instant.now().toString(), |
| 904 | + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, |
| 905 | + "conversation test 1", |
| 906 | + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, |
| 907 | + "answer1" |
| 908 | + ) |
| 909 | + ).when(response).getSourceAsMap(); |
| 910 | + return response; |
| 911 | + } |
| 912 | + |
| 913 | + private void setUpSearchTraceResponse() { |
| 914 | + doAnswer(invocation -> { |
| 915 | + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); |
| 916 | + content.startObject(); |
| 917 | + content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); |
| 918 | + content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); |
| 919 | + content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); |
| 920 | + content.endObject(); |
| 921 | + |
| 922 | + SearchHit[] hits = new SearchHit[1]; |
| 923 | + hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); |
| 924 | + SearchHits searchHits = new SearchHits(hits, null, Float.NaN); |
| 925 | + SearchResponseSections searchSections = new SearchResponseSections( |
| 926 | + searchHits, |
| 927 | + InternalAggregations.EMPTY, |
| 928 | + null, |
| 929 | + false, |
| 930 | + false, |
| 931 | + null, |
| 932 | + 1 |
| 933 | + ); |
| 934 | + SearchResponse searchResponse = new SearchResponse( |
| 935 | + searchSections, |
| 936 | + null, |
| 937 | + 1, |
| 938 | + 1, |
| 939 | + 0, |
| 940 | + 11, |
| 941 | + ShardSearchFailure.EMPTY_ARRAY, |
| 942 | + SearchResponse.Clusters.EMPTY |
| 943 | + ); |
| 944 | + ActionListener<SearchResponse> al = invocation.getArgument(1); |
| 945 | + al.onResponse(searchResponse); |
| 946 | + return null; |
| 947 | + }).when(client).search(any(), any()); |
| 948 | + } |
803 | 949 | }
|
0 commit comments