|
14 | 14 | import static org.mockito.ArgumentMatchers.isA;
|
15 | 15 | import static org.mockito.Mockito.doAnswer;
|
16 | 16 | import static org.mockito.Mockito.verify;
|
| 17 | +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; |
17 | 18 | import static org.opensearch.ml.common.input.Constants.ACTION;
|
18 | 19 | import static org.opensearch.ml.common.input.Constants.ALGORITHM;
|
19 | 20 | import static org.opensearch.ml.common.input.Constants.KMEANS;
|
|
40 | 41 | import org.mockito.InjectMocks;
|
41 | 42 | import org.mockito.Mock;
|
42 | 43 | import org.mockito.MockitoAnnotations;
|
| 44 | +import org.opensearch.OpenSearchStatusException; |
43 | 45 | import org.opensearch.action.delete.DeleteResponse;
|
44 | 46 | import org.opensearch.action.search.SearchRequest;
|
45 | 47 | import org.opensearch.action.search.SearchResponse;
|
|
51 | 53 | import org.opensearch.core.common.bytes.BytesReference;
|
52 | 54 | import org.opensearch.core.index.Index;
|
53 | 55 | import org.opensearch.core.index.shard.ShardId;
|
| 56 | +import org.opensearch.core.rest.RestStatus; |
54 | 57 | import org.opensearch.core.xcontent.ToXContent;
|
55 | 58 | import org.opensearch.core.xcontent.ToXContentObject;
|
56 | 59 | import org.opensearch.core.xcontent.XContentBuilder;
|
57 | 60 | import org.opensearch.ml.common.AccessMode;
|
| 61 | +import org.opensearch.ml.common.Configuration; |
58 | 62 | import org.opensearch.ml.common.FunctionName;
|
59 | 63 | import org.opensearch.ml.common.MLAgentType;
|
| 64 | +import org.opensearch.ml.common.MLConfig; |
60 | 65 | import org.opensearch.ml.common.MLModel;
|
61 | 66 | import org.opensearch.ml.common.MLTask;
|
62 | 67 | import org.opensearch.ml.common.MLTaskState;
|
|
84 | 89 | import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
|
85 | 90 | import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
|
86 | 91 | import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
|
| 92 | +import org.opensearch.ml.common.transport.config.MLConfigGetAction; |
| 93 | +import org.opensearch.ml.common.transport.config.MLConfigGetRequest; |
| 94 | +import org.opensearch.ml.common.transport.config.MLConfigGetResponse; |
87 | 95 | import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
|
88 | 96 | import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
|
89 | 97 | import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
|
@@ -206,6 +214,9 @@ public class MachineLearningNodeClientTest {
|
206 | 214 | @Mock
|
207 | 215 | ActionListener<ToolMetadata> getToolActionListener;
|
208 | 216 |
|
| 217 | + @Mock |
| 218 | + ActionListener<MLConfig> getMlConfigListener; |
| 219 | + |
209 | 220 | @InjectMocks
|
210 | 221 | MachineLearningNodeClient machineLearningNodeClient;
|
211 | 222 |
|
@@ -951,6 +962,43 @@ public void listTools() {
|
951 | 962 | assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
|
952 | 963 | }
|
953 | 964 |
|
| 965 | + @Test |
| 966 | + public void getConfig() { |
| 967 | + MLConfig mlConfig = MLConfig.builder().type("type").configuration(Configuration.builder().agentId("agentId").build()).build(); |
| 968 | + |
| 969 | + doAnswer(invocation -> { |
| 970 | + ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2); |
| 971 | + MLConfigGetResponse output = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); |
| 972 | + actionListener.onResponse(output); |
| 973 | + return null; |
| 974 | + }).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any()); |
| 975 | + |
| 976 | + ArgumentCaptor<MLConfig> argumentCaptor = ArgumentCaptor.forClass(MLConfig.class); |
| 977 | + machineLearningNodeClient.getConfig("agentId", getMlConfigListener); |
| 978 | + |
| 979 | + verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any()); |
| 980 | + verify(getMlConfigListener).onResponse(argumentCaptor.capture()); |
| 981 | + assertEquals("agentId", argumentCaptor.getValue().getConfiguration().getAgentId()); |
| 982 | + assertEquals("type", argumentCaptor.getValue().getType()); |
| 983 | + } |
| 984 | + |
| 985 | + @Test |
| 986 | + public void getConfigRejectedMasterKey() { |
| 987 | + doAnswer(invocation -> { |
| 988 | + ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2); |
| 989 | + actionListener.onFailure(new OpenSearchStatusException("You are not allowed to access this config doc", RestStatus.FORBIDDEN)); |
| 990 | + return null; |
| 991 | + }).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any()); |
| 992 | + |
| 993 | + ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); |
| 994 | + machineLearningNodeClient.getConfig(MASTER_KEY, getMlConfigListener); |
| 995 | + |
| 996 | + verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any()); |
| 997 | + verify(getMlConfigListener).onFailure(argumentCaptor.capture()); |
| 998 | + assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status()); |
| 999 | + assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage()); |
| 1000 | + } |
| 1001 | + |
954 | 1002 | private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
|
955 | 1003 | XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
|
956 | 1004 |
|
|
0 commit comments