Skip to content

Commit e25134c

Browse files
committed
add beckrock url in the allowed list and more UTs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 57051bd commit e25134c

File tree

4 files changed

+78
-5
lines changed

4 files changed

+78
-5
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java

+29
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,35 @@ public void testFilterFieldMapping_MatchingPrefix() {
204204
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
205205
}
206206

207+
@Test
208+
public void testFilterFieldMappingSoleSource_MatchingPrefix() {
209+
// Arrange
210+
Map<String, Object> fieldMap = new HashMap<>();
211+
fieldMap.put("question", "source[0].$.body.input[0]");
212+
fieldMap.put("question_embedding", "source[0].$.response.body.data[0].embedding");
213+
fieldMap.put("answer", "source[0].$.body.input[1]");
214+
fieldMap.put("answer_embedding", "$.response.body.data[1].embedding");
215+
fieldMap.put("_id", Arrays.asList("$.custom_id", "source[1].$.custom_id"));
216+
217+
MLBatchIngestionInput mlBatchIngestionInput = new MLBatchIngestionInput(
218+
"indexName",
219+
fieldMap,
220+
ingestFields,
221+
new HashMap<>(),
222+
new HashMap<>()
223+
);
224+
225+
// Act
226+
Map<String, Object> result = s3DataIngestion.filterFieldMappingSoleSource(mlBatchIngestionInput);
227+
228+
// Assert
229+
assertEquals(6, result.size());
230+
231+
assertEquals("$.body.input[0]", result.get("question"));
232+
assertEquals("$.response.body.data[0].embedding", result.get("question_embedding"));
233+
assertEquals(Arrays.asList("$.custom_id"), result.get("_id"));
234+
}
235+
207236
@Test
208237
public void testProcessFieldMapping_FromSM() {
209238
String jsonStr =

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ private MLCommonsSettings() {}
146146
"^https://api\\.openai\\.com/.*$",
147147
"^https://api\\.cohere\\.ai/.*$",
148148
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
149-
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
149+
"^https://bedrock-agent-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
150+
"^https://bedrock\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"
150151
),
151152
Function.identity(),
152153
Setting.Property.NodeScope,

plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java

+44
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,29 @@
66
package org.opensearch.ml.action.batch;
77

88
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.ArgumentMatchers.anyBoolean;
10+
import static org.mockito.ArgumentMatchers.anyLong;
11+
import static org.mockito.ArgumentMatchers.anyString;
912
import static org.mockito.ArgumentMatchers.isA;
1013
import static org.mockito.Mockito.doAnswer;
14+
import static org.mockito.Mockito.doReturn;
1115
import static org.mockito.Mockito.doThrow;
16+
import static org.mockito.Mockito.never;
1217
import static org.mockito.Mockito.verify;
1318
import static org.mockito.Mockito.when;
1419
import static org.opensearch.ml.common.MLTask.ERROR_FIELD;
1520
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1621
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
1722
import static org.opensearch.ml.common.MLTaskState.FAILED;
1823
import static org.opensearch.ml.engine.ingest.S3DataIngestion.SOURCE;
24+
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
1925
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
2026

2127
import java.util.ArrayList;
2228
import java.util.Arrays;
2329
import java.util.HashMap;
2430
import java.util.Map;
31+
import java.util.concurrent.ExecutorService;
2532

2633
import org.junit.Before;
2734
import org.mockito.ArgumentCaptor;
@@ -45,6 +52,8 @@
4552
import org.opensearch.threadpool.ThreadPool;
4653
import org.opensearch.transport.TransportService;
4754

55+
import com.jayway.jsonpath.PathNotFoundException;
56+
4857
public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
4958
@Mock
5059
private Client client;
@@ -62,6 +71,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
6271
ActionListener<MLBatchIngestionResponse> actionListener;
6372
@Mock
6473
ThreadPool threadPool;
74+
@Mock
75+
ExecutorService executorService;
6576

6677
private TransportBatchIngestionAction batchAction;
6778
private MLBatchIngestionInput batchInput;
@@ -105,9 +116,42 @@ public void test_doExecute_success() {
105116
listener.onResponse(indexResponse);
106117
return null;
107118
}).when(mlTaskManager).createMLTask(isA(MLTask.class), isA(ActionListener.class));
119+
doReturn(executorService).when(threadPool).executor(INGEST_THREAD_POOL);
120+
doAnswer(invocation -> {
121+
Runnable runnable = invocation.getArgument(0);
122+
runnable.run();
123+
return null;
124+
}).when(executorService).execute(any(Runnable.class));
125+
108126
batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);
109127

110128
verify(actionListener).onResponse(any(MLBatchIngestionResponse.class));
129+
verify(threadPool).executor(INGEST_THREAD_POOL);
130+
}
131+
132+
public void test_doExecute_ExecuteWithNoErrorHandling() {
133+
batchAction.executeWithErrorHandling(() -> {}, "taskId");
134+
135+
verify(mlTaskManager, never()).updateMLTask(anyString(), isA(Map.class), anyLong(), anyBoolean());
136+
}
137+
138+
public void test_doExecute_ExecuteWithPathNotFoundException() {
139+
batchAction.executeWithErrorHandling(() -> { throw new PathNotFoundException("jsonPath not found!"); }, "taskId");
140+
141+
verify(mlTaskManager)
142+
.updateMLTask("taskId", Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "jsonPath not found!"), TASK_SEMAPHORE_TIMEOUT, true);
143+
}
144+
145+
public void test_doExecute_RuntimeException() {
146+
batchAction.executeWithErrorHandling(() -> { throw new RuntimeException("runtime exception in the ingestion!"); }, "taskId");
147+
148+
verify(mlTaskManager)
149+
.updateMLTask(
150+
"taskId",
151+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, "runtime exception in the ingestion!"),
152+
TASK_SEMAPHORE_TIMEOUT,
153+
true
154+
);
111155
}
112156

113157
public void test_doExecute_handleSuccessRate100() {

plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -447,10 +447,9 @@ public void testValidateBatchPredictionSuccess() throws IOException {
447447
"output",
448448
"{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}"
449449
);
450-
ModelTensorOutput modelTensorOutput = ModelTensorOutput
451-
.builder()
452-
.mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build()))
453-
.build();
450+
ModelTensors modelTensors = ModelTensors.builder().statusCode(200).mlModelTensors(List.of(modelTensor)).statusCode(200).build();
451+
modelTensors.setStatusCode(200);
452+
ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build();
454453
doAnswer(invocation -> {
455454
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(1);
456455
actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build());

0 commit comments

Comments
 (0)