Skip to content

Commit 38bba02

Browse files
committed
add UT for acess denied cases
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 1f6c183 commit 38bba02

File tree

2 files changed

+196
-35
lines changed

2 files changed

+196
-35
lines changed

memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java

+15
Original file line numberDiff line numberDiff line change
@@ -701,4 +701,19 @@ public void testUpdateConversation_ClientFails() {
701701
verify(getListener, times(1)).onFailure(argCaptor.capture());
702702
assert (argCaptor.getValue().getMessage().equals("Client Failure"));
703703
}
704+
705+
public void testUpdateConversation_NoAccess_ThenFail() {
706+
doReturn(true).when(metadata).hasIndex(anyString());
707+
doAnswer(invocation -> {
708+
ActionListener<Boolean> al = invocation.getArgument(1);
709+
al.onResponse(false);
710+
return null;
711+
}).when(conversationMetaIndex).checkAccess(anyString(), any());
712+
713+
ActionListener<UpdateResponse> updateListener = mock(ActionListener.class);
714+
conversationMetaIndex.updateConversation("conversationId", new UpdateRequest(), updateListener);
715+
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
716+
verify(updateListener, times(1)).onFailure(argCaptor.capture());
717+
assert (argCaptor.getValue().getMessage().equals("User [BAD_USER] does not have access to conversation conversationId"));
718+
}
704719
}

memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java

+181-35
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@
3232
import java.time.Instant;
3333
import java.util.Collections;
3434
import java.util.List;
35+
import java.util.Map;
3536

3637
import org.junit.Before;
3738
import org.mockito.ArgumentCaptor;
3839
import org.mockito.Mock;
3940
import org.opensearch.OpenSearchWrapperException;
4041
import org.opensearch.ResourceAlreadyExistsException;
42+
import org.opensearch.action.DocWriteResponse;
4143
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
4244
import org.opensearch.action.admin.indices.refresh.RefreshResponse;
4345
import org.opensearch.action.bulk.BulkResponse;
@@ -47,6 +49,8 @@
4749
import org.opensearch.action.search.SearchResponse;
4850
import org.opensearch.action.search.SearchResponseSections;
4951
import org.opensearch.action.search.ShardSearchFailure;
52+
import org.opensearch.action.update.UpdateRequest;
53+
import org.opensearch.action.update.UpdateResponse;
5054
import org.opensearch.client.AdminClient;
5155
import org.opensearch.client.Client;
5256
import org.opensearch.client.IndicesAdminClient;
@@ -59,6 +63,8 @@
5963
import org.opensearch.commons.ConfigConstants;
6064
import org.opensearch.core.action.ActionListener;
6165
import org.opensearch.core.common.bytes.BytesReference;
66+
import org.opensearch.core.index.Index;
67+
import org.opensearch.core.index.shard.ShardId;
6268
import org.opensearch.core.rest.RestStatus;
6369
import org.opensearch.core.xcontent.XContentBuilder;
6470
import org.opensearch.index.query.MatchAllQueryBuilder;
@@ -441,46 +447,34 @@ public void testGetTraces_NoIndex_ThenEmpty() {
441447
assert (argCaptor.getValue().size() == 0);
442448
}
443449

444-
public void testGetTraces() {
445-
doAnswer(invocation -> {
446-
XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent());
447-
content.startObject();
448-
content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now());
449-
content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs");
450-
content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id");
451-
content.endObject();
450+
public void testInnerGetTraces_success() {
451+
setUpSearchTraceResponse();
452+
doReturn(true).when(metadata).hasIndex(anyString());
453+
@SuppressWarnings("unchecked")
454+
ActionListener<List<Interaction>> getTracesListener = mock(ActionListener.class);
455+
interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener);
456+
@SuppressWarnings("unchecked")
457+
ArgumentCaptor<List<Interaction>> argCaptor = ArgumentCaptor.forClass(List.class);
458+
verify(getTracesListener, times(1)).onResponse(argCaptor.capture());
459+
assert (argCaptor.getValue().size() == 1);
460+
}
452461

453-
SearchHit[] hits = new SearchHit[1];
454-
hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content));
455-
SearchHits searchHits = new SearchHits(hits, null, Float.NaN);
456-
SearchResponseSections searchSections = new SearchResponseSections(
457-
searchHits,
458-
InternalAggregations.EMPTY,
459-
null,
460-
false,
461-
false,
462-
null,
463-
1
464-
);
465-
SearchResponse searchResponse = new SearchResponse(
466-
searchSections,
467-
null,
468-
1,
469-
1,
470-
0,
471-
11,
472-
ShardSearchFailure.EMPTY_ARRAY,
473-
SearchResponse.Clusters.EMPTY
474-
);
475-
ActionListener<SearchResponse> al = invocation.getArgument(1);
476-
al.onResponse(searchResponse);
477-
return null;
478-
}).when(client).search(any(), any());
462+
public void testGetTraces_success() {
463+
setupGrantAccess();
464+
doReturn(true).when(metadata).hasIndex(anyString());
465+
setupRefreshSuccess();
479466

467+
GetResponse response = setUpInteractionResponse("iid");
468+
doAnswer(invocation -> {
469+
ActionListener<GetResponse> listener = invocation.getArgument(1);
470+
listener.onResponse(response);
471+
return null;
472+
}).when(client).get(any(), any());
473+
setUpSearchTraceResponse();
480474
doReturn(true).when(metadata).hasIndex(anyString());
481475
@SuppressWarnings("unchecked")
482476
ActionListener<List<Interaction>> getTracesListener = mock(ActionListener.class);
483-
interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener);
477+
interactionsIndex.getTraces("iid", 0, 10, getTracesListener);
484478
@SuppressWarnings("unchecked")
485479
ArgumentCaptor<List<Interaction>> argCaptor = ArgumentCaptor.forClass(List.class);
486480
verify(getTracesListener, times(1)).onResponse(argCaptor.capture());
@@ -800,4 +794,156 @@ public void testGetSg_ClientFails_ThenFail() {
800794
verify(getListener, times(1)).onFailure(argCaptor.capture());
801795
assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get"));
802796
}
797+
798+
public void testGetSg_NoAccess_ThenFail() {
799+
doReturn(true).when(metadata).hasIndex(anyString());
800+
setupDenyAccess("Henry");
801+
setupRefreshSuccess();
802+
GetResponse response = setUpInteractionResponse("iid");
803+
doAnswer(invocation -> {
804+
ActionListener<GetResponse> listener = invocation.getArgument(1);
805+
listener.onResponse(response);
806+
return null;
807+
}).when(client).get(any(), any());
808+
ActionListener<Interaction> getListener = mock(ActionListener.class);
809+
interactionsIndex.getInteraction("iid", getListener);
810+
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
811+
verify(getListener, times(1)).onFailure(argCaptor.capture());
812+
assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to interaction iid"));
813+
}
814+
815+
public void testGetSg_GrantAccess_Success() {
816+
setupGrantAccess();
817+
doReturn(true).when(metadata).hasIndex(anyString());
818+
setupRefreshSuccess();
819+
GetResponse response = setUpInteractionResponse("iid");
820+
doAnswer(invocation -> {
821+
ActionListener<GetResponse> listener = invocation.getArgument(1);
822+
listener.onResponse(response);
823+
return null;
824+
}).when(client).get(any(), any());
825+
ActionListener<Interaction> getListener = mock(ActionListener.class);
826+
interactionsIndex.getInteraction("iid", getListener);
827+
ArgumentCaptor<Interaction> argCaptor = ArgumentCaptor.forClass(Interaction.class);
828+
verify(getListener, times(1)).onResponse(argCaptor.capture());
829+
assert (argCaptor.getValue().getId().equals("iid"));
830+
assert (argCaptor.getValue().getConversationId().equals("conversation test 1"));
831+
}
832+
833+
public void testGetTraces_NoAccess_ThenFail() {
834+
doReturn(true).when(metadata).hasIndex(anyString());
835+
setupRefreshSuccess();
836+
setupDenyAccess("Xun");
837+
GetResponse response = setUpInteractionResponse("iid");
838+
doAnswer(invocation -> {
839+
ActionListener<GetResponse> listener = invocation.getArgument(1);
840+
listener.onResponse(response);
841+
return null;
842+
}).when(client).get(any(), any());
843+
844+
ActionListener<List<Interaction>> getListener = mock(ActionListener.class);
845+
interactionsIndex.getTraces("iid", 0, 10, getListener);
846+
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
847+
verify(getListener, times(1)).onFailure(argCaptor.capture());
848+
assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid"));
849+
}
850+
851+
public void testUpdateInteraction_NoAccess_ThenFail() {
852+
doReturn(true).when(metadata).hasIndex(anyString());
853+
setupRefreshSuccess();
854+
setupDenyAccess("Xun");
855+
GetResponse response = setUpInteractionResponse("iid");
856+
doAnswer(invocation -> {
857+
ActionListener<GetResponse> listener = invocation.getArgument(1);
858+
listener.onResponse(response);
859+
return null;
860+
}).when(client).get(any(), any());
861+
862+
ActionListener<UpdateResponse> updateListener = mock(ActionListener.class);
863+
interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener);
864+
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
865+
verify(updateListener, times(1)).onFailure(argCaptor.capture());
866+
assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid"));
867+
}
868+
869+
public void testUpdateInteraction_Success() {
870+
doReturn(true).when(metadata).hasIndex(anyString());
871+
setupRefreshSuccess();
872+
setupGrantAccess();
873+
GetResponse response = setUpInteractionResponse("iid");
874+
doAnswer(invocation -> {
875+
ActionListener<GetResponse> listener = invocation.getArgument(1);
876+
listener.onResponse(response);
877+
return null;
878+
}).when(client).get(any(), any());
879+
880+
doAnswer(invocation -> {
881+
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
882+
UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED);
883+
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
884+
listener.onResponse(updateResponse);
885+
return null;
886+
}).when(client).update(any(), any());
887+
888+
ActionListener<UpdateResponse> updateListener = mock(ActionListener.class);
889+
interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener);
890+
ArgumentCaptor<UpdateResponse> argCaptor = ArgumentCaptor.forClass(UpdateResponse.class);
891+
verify(updateListener, times(1)).onResponse(argCaptor.capture());
892+
}
893+
894+
private GetResponse setUpInteractionResponse(String interactionId) {
895+
@SuppressWarnings("unchecked")
896+
GetResponse response = mock(GetResponse.class);
897+
doReturn(true).when(response).isExists();
898+
doReturn(interactionId).when(response).getId();
899+
doReturn(
900+
Map
901+
.of(
902+
ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD,
903+
Instant.now().toString(),
904+
ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD,
905+
"conversation test 1",
906+
ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD,
907+
"answer1"
908+
)
909+
).when(response).getSourceAsMap();
910+
return response;
911+
}
912+
913+
private void setUpSearchTraceResponse() {
914+
doAnswer(invocation -> {
915+
XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent());
916+
content.startObject();
917+
content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now());
918+
content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs");
919+
content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id");
920+
content.endObject();
921+
922+
SearchHit[] hits = new SearchHit[1];
923+
hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content));
924+
SearchHits searchHits = new SearchHits(hits, null, Float.NaN);
925+
SearchResponseSections searchSections = new SearchResponseSections(
926+
searchHits,
927+
InternalAggregations.EMPTY,
928+
null,
929+
false,
930+
false,
931+
null,
932+
1
933+
);
934+
SearchResponse searchResponse = new SearchResponse(
935+
searchSections,
936+
null,
937+
1,
938+
1,
939+
0,
940+
11,
941+
ShardSearchFailure.EMPTY_ARRAY,
942+
SearchResponse.Clusters.EMPTY
943+
);
944+
ActionListener<SearchResponse> al = invocation.getArgument(1);
945+
al.onResponse(searchResponse);
946+
return null;
947+
}).when(client).search(any(), any());
948+
}
803949
}

0 commit comments

Comments
 (0)