Skip to content

Commit 5c68b12

Browse files
* Expose ML Config API Signed-off-by: Ashish Agrawal <ashisagr@amazon.com> * Add tests for rejected master key Signed-off-by: Ashish Agrawal <ashisagr@amazon.com> --------- Signed-off-by: Ashish Agrawal <ashisagr@amazon.com> (cherry picked from commit 05eb53f) Co-authored-by: Ashish Agrawal <ashish81394@gmail.com>
1 parent ab2fae3 commit 5c68b12

File tree

7 files changed

+148
-0
lines changed

7 files changed

+148
-0
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.opensearch.common.action.ActionFuture;
1717
import org.opensearch.core.action.ActionListener;
1818
import org.opensearch.ml.common.FunctionName;
19+
import org.opensearch.ml.common.MLConfig;
1920
import org.opensearch.ml.common.MLModel;
2021
import org.opensearch.ml.common.MLTask;
2122
import org.opensearch.ml.common.ToolMetadata;
@@ -428,4 +429,20 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
428429
*/
429430
void getTool(String toolName, ActionListener<ToolMetadata> listener);
430431

432+
/**
433+
* Get config
434+
* @param configId ML config id
435+
*/
436+
default ActionFuture<MLConfig> getConfig(String configId) {
437+
PlainActionFuture<MLConfig> actionFuture = PlainActionFuture.newFuture();
438+
getConfig(configId, actionFuture);
439+
return actionFuture;
440+
}
441+
442+
/**
443+
* Get config
444+
* @param configId ML config id
445+
* @param listener a listener to be notified of the result
446+
*/
447+
void getConfig(String configId, ActionListener<MLConfig> listener);
431448
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+22
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.core.action.ActionListener;
2626
import org.opensearch.core.action.ActionResponse;
2727
import org.opensearch.ml.common.FunctionName;
28+
import org.opensearch.ml.common.MLConfig;
2829
import org.opensearch.ml.common.MLModel;
2930
import org.opensearch.ml.common.MLTask;
3031
import org.opensearch.ml.common.ToolMetadata;
@@ -39,6 +40,9 @@
3940
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
4041
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
4142
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
43+
import org.opensearch.ml.common.transport.config.MLConfigGetAction;
44+
import org.opensearch.ml.common.transport.config.MLConfigGetRequest;
45+
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
4246
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
4347
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
4448
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
@@ -309,6 +313,13 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
309313
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
310314
}
311315

316+
@Override
317+
public void getConfig(String configId, ActionListener<MLConfig> listener) {
318+
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();
319+
320+
client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
321+
}
322+
312323
private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
313324
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
314325
listener.onResponse(mlModelListResponse.getToolMetadataList());
@@ -331,6 +342,17 @@ private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(Act
331342
return actionListener;
332343
}
333344

345+
private ActionListener<MLConfigGetResponse> getMlGetConfigResponseActionListener(ActionListener<MLConfig> listener) {
346+
ActionListener<MLConfigGetResponse> internalListener = ActionListener.wrap(mlConfigGetResponse -> {
347+
listener.onResponse(mlConfigGetResponse.getMlConfig());
348+
}, listener::onFailure);
349+
ActionListener<MLConfigGetResponse> actionListener = wrapActionListener(internalListener, res -> {
350+
MLConfigGetResponse getResponse = MLConfigGetResponse.fromActionResponse(res);
351+
return getResponse;
352+
});
353+
return actionListener;
354+
}
355+
334356
private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
335357
ActionListener<MLRegisterAgentResponse> listener
336358
) {

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+26
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static org.opensearch.ml.common.input.Constants.KMEANS;
1313
import static org.opensearch.ml.common.input.Constants.TRAIN;
1414

15+
import java.time.Instant;
1516
import java.util.ArrayList;
1617
import java.util.Arrays;
1718
import java.util.HashMap;
@@ -28,8 +29,10 @@
2829
import org.opensearch.action.search.SearchResponse;
2930
import org.opensearch.core.action.ActionListener;
3031
import org.opensearch.ml.common.AccessMode;
32+
import org.opensearch.ml.common.Configuration;
3133
import org.opensearch.ml.common.FunctionName;
3234
import org.opensearch.ml.common.MLAgentType;
35+
import org.opensearch.ml.common.MLConfig;
3336
import org.opensearch.ml.common.MLModel;
3437
import org.opensearch.ml.common.MLTask;
3538
import org.opensearch.ml.common.ToolMetadata;
@@ -46,6 +49,7 @@
4649
import org.opensearch.ml.common.output.MLOutput;
4750
import org.opensearch.ml.common.output.MLTrainingOutput;
4851
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
52+
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
4953
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
5054
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
5155
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
@@ -99,9 +103,13 @@ public class MachineLearningClientTest {
99103
@Mock
100104
MLRegisterAgentResponse registerAgentResponse;
101105

106+
@Mock
107+
MLConfigGetResponse configGetResponse;
108+
102109
private String modekId = "test_model_id";
103110
private MLModel mlModel;
104111
private MLTask mlTask;
112+
private MLConfig mlConfig;
105113
private ToolMetadata toolMetadata;
106114
private List<ToolMetadata> toolsList = new ArrayList<>();
107115

@@ -124,6 +132,14 @@ public void setUp() {
124132
.build();
125133
toolsList.add(toolMetadata);
126134

135+
mlConfig = MLConfig
136+
.builder()
137+
.type("dummyType")
138+
.configuration(Configuration.builder().agentId("agentId").build())
139+
.createTime(Instant.now())
140+
.lastUpdateTime(Instant.now())
141+
.build();
142+
127143
machineLearningClient = new MachineLearningClient() {
128144
@Override
129145
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
@@ -231,6 +247,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
231247
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
232248
listener.onResponse(deleteResponse);
233249
}
250+
251+
@Override
252+
public void getConfig(String configId, ActionListener<MLConfig> listener) {
253+
listener.onResponse(mlConfig);
254+
}
234255
};
235256
}
236257

@@ -503,4 +524,9 @@ public void getTool() {
503524
public void listTools() {
504525
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
505526
}
527+
528+
@Test
529+
public void getConfig() {
530+
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
531+
}
506532
}

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+48
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.mockito.ArgumentMatchers.isA;
1515
import static org.mockito.Mockito.doAnswer;
1616
import static org.mockito.Mockito.verify;
17+
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1718
import static org.opensearch.ml.common.input.Constants.ACTION;
1819
import static org.opensearch.ml.common.input.Constants.ALGORITHM;
1920
import static org.opensearch.ml.common.input.Constants.KMEANS;
@@ -40,6 +41,7 @@
4041
import org.mockito.InjectMocks;
4142
import org.mockito.Mock;
4243
import org.mockito.MockitoAnnotations;
44+
import org.opensearch.OpenSearchStatusException;
4345
import org.opensearch.action.delete.DeleteResponse;
4446
import org.opensearch.action.search.SearchRequest;
4547
import org.opensearch.action.search.SearchResponse;
@@ -51,12 +53,15 @@
5153
import org.opensearch.core.common.bytes.BytesReference;
5254
import org.opensearch.core.index.Index;
5355
import org.opensearch.core.index.shard.ShardId;
56+
import org.opensearch.core.rest.RestStatus;
5457
import org.opensearch.core.xcontent.ToXContent;
5558
import org.opensearch.core.xcontent.ToXContentObject;
5659
import org.opensearch.core.xcontent.XContentBuilder;
5760
import org.opensearch.ml.common.AccessMode;
61+
import org.opensearch.ml.common.Configuration;
5862
import org.opensearch.ml.common.FunctionName;
5963
import org.opensearch.ml.common.MLAgentType;
64+
import org.opensearch.ml.common.MLConfig;
6065
import org.opensearch.ml.common.MLModel;
6166
import org.opensearch.ml.common.MLTask;
6267
import org.opensearch.ml.common.MLTaskState;
@@ -84,6 +89,9 @@
8489
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
8590
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
8691
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
92+
import org.opensearch.ml.common.transport.config.MLConfigGetAction;
93+
import org.opensearch.ml.common.transport.config.MLConfigGetRequest;
94+
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
8795
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
8896
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
8997
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
@@ -206,6 +214,9 @@ public class MachineLearningNodeClientTest {
206214
@Mock
207215
ActionListener<ToolMetadata> getToolActionListener;
208216

217+
@Mock
218+
ActionListener<MLConfig> getMlConfigListener;
219+
209220
@InjectMocks
210221
MachineLearningNodeClient machineLearningNodeClient;
211222

@@ -951,6 +962,43 @@ public void listTools() {
951962
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
952963
}
953964

965+
@Test
966+
public void getConfig() {
967+
MLConfig mlConfig = MLConfig.builder().type("type").configuration(Configuration.builder().agentId("agentId").build()).build();
968+
969+
doAnswer(invocation -> {
970+
ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2);
971+
MLConfigGetResponse output = MLConfigGetResponse.builder().mlConfig(mlConfig).build();
972+
actionListener.onResponse(output);
973+
return null;
974+
}).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any());
975+
976+
ArgumentCaptor<MLConfig> argumentCaptor = ArgumentCaptor.forClass(MLConfig.class);
977+
machineLearningNodeClient.getConfig("agentId", getMlConfigListener);
978+
979+
verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any());
980+
verify(getMlConfigListener).onResponse(argumentCaptor.capture());
981+
assertEquals("agentId", argumentCaptor.getValue().getConfiguration().getAgentId());
982+
assertEquals("type", argumentCaptor.getValue().getType());
983+
}
984+
985+
@Test
986+
public void getConfigRejectedMasterKey() {
987+
doAnswer(invocation -> {
988+
ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2);
989+
actionListener.onFailure(new OpenSearchStatusException("You are not allowed to access this config doc", RestStatus.FORBIDDEN));
990+
return null;
991+
}).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any());
992+
993+
ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
994+
machineLearningNodeClient.getConfig(MASTER_KEY, getMlConfigListener);
995+
996+
verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any());
997+
verify(getMlConfigListener).onFailure(argumentCaptor.capture());
998+
assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status());
999+
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
1000+
}
1001+
9541002
private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
9551003
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
9561004

common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java

+2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import org.opensearch.ml.common.MLConfig;
2121

2222
import lombok.Builder;
23+
import lombok.Getter;
2324

25+
@Getter
2426
public class MLConfigGetResponse extends ActionResponse implements ToXContentObject {
2527
MLConfig mlConfig;
2628

plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.action.config;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
910
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
1011
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
1112

@@ -58,6 +59,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConf
5859
String configId = mlConfigGetRequest.getConfigId();
5960
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(configId);
6061

62+
if (configId.equals(MASTER_KEY)) {
63+
actionListener.onFailure(new OpenSearchStatusException("You are not allowed to access this config doc", RestStatus.FORBIDDEN));
64+
return;
65+
}
66+
6167
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
6268
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
6369
log.debug("Completed Get Agent Request, id:{}", configId);

plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java

+27
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
package org.opensearch.ml.action.config;
77

8+
import static org.junit.Assert.assertEquals;
89
import static org.mockito.ArgumentMatchers.any;
910
import static org.mockito.Mockito.doAnswer;
1011
import static org.mockito.Mockito.mock;
1112
import static org.mockito.Mockito.spy;
1213
import static org.mockito.Mockito.verify;
1314
import static org.mockito.Mockito.when;
15+
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1416

1517
import java.io.IOException;
1618
import java.time.Instant;
@@ -22,6 +24,7 @@
2224
import org.mockito.ArgumentCaptor;
2325
import org.mockito.Mock;
2426
import org.mockito.MockitoAnnotations;
27+
import org.opensearch.OpenSearchStatusException;
2528
import org.opensearch.action.get.GetResponse;
2629
import org.opensearch.action.support.ActionFilters;
2730
import org.opensearch.client.Client;
@@ -30,6 +33,7 @@
3033
import org.opensearch.common.xcontent.XContentFactory;
3134
import org.opensearch.core.action.ActionListener;
3235
import org.opensearch.core.common.bytes.BytesReference;
36+
import org.opensearch.core.rest.RestStatus;
3337
import org.opensearch.core.xcontent.NamedXContentRegistry;
3438
import org.opensearch.core.xcontent.ToXContent;
3539
import org.opensearch.core.xcontent.XContentBuilder;
@@ -168,4 +172,27 @@ public GetResponse prepareMLConfig(String configID) throws IOException {
168172
GetResponse getResponse = new GetResponse(getResult);
169173
return getResponse;
170174
}
175+
176+
@Test
177+
public void testDoExecute_Rejected_MASTER_KEY() throws IOException {
178+
String configID = MASTER_KEY;
179+
GetResponse getResponse = prepareMLConfig(configID);
180+
ActionListener<MLConfigGetResponse> actionListener = mock(ActionListener.class);
181+
MLConfigGetRequest request = new MLConfigGetRequest(configID);
182+
Task task = mock(Task.class);
183+
184+
doAnswer(invocation -> {
185+
ActionListener<GetResponse> listener = invocation.getArgument(1);
186+
listener.onResponse(getResponse);
187+
return null;
188+
}).when(client).get(any(), any());
189+
190+
ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
191+
192+
getConfigTransportAction.doExecute(task, request, actionListener);
193+
verify(actionListener).onFailure(argumentCaptor.capture());
194+
assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status());
195+
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
196+
197+
}
171198
}

0 commit comments

Comments
 (0)