Skip to content

Commit 95466fd

Browse files
authored
Merge branch 'opensearch-project:main' into main
2 parents 86ae88b + 51bd8cb commit 95466fd

File tree

4 files changed

+167
-145
lines changed

4 files changed

+167
-145
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+35-31
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ public void registerMLRemoteModel(
464464
mlRegisterModelInput.getTenantId(),
465465
new MLResourceNotFoundException("Failed to get model group due to index missing")
466466
);
467-
listener.onFailure(e);
467+
listener.onFailure(new OpenSearchStatusException("Model group not found", RestStatus.NOT_FOUND));
468468
} else {
469469
log.error("Failed to get model group", e);
470470
handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), mlRegisterModelInput.getTenantId(), e);
@@ -2077,40 +2077,44 @@ public void getConnector(String connectorId, String tenantId, ActionListener<Con
20772077
.tenantId(tenantId)
20782078
.build();
20792079

2080-
sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> {
2081-
log.debug("Completed Get Connector Request, id:{}", connectorId);
2082-
if (throwable != null) {
2083-
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
2084-
if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) {
2085-
log.error("Failed to get connector index", cause);
2086-
listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND));
2080+
try (ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) {
2081+
sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> {
2082+
log.debug("Completed Get Connector Request, id:{}", connectorId);
2083+
ctx.restore();
2084+
if (throwable != null) {
2085+
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
2086+
if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) {
2087+
log.error("Failed to get connector index", cause);
2088+
listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND));
2089+
} else {
2090+
log.error("Failed to get ML connector {}", connectorId, cause);
2091+
listener.onFailure(cause);
2092+
}
20872093
} else {
2088-
log.error("Failed to get ML connector {}", connectorId, cause);
2089-
listener.onFailure(cause);
2090-
}
2091-
} else {
2092-
try {
2093-
GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser());
2094-
if (gr != null && gr.isExists()) {
2095-
try (
2096-
XContentParser parser = MLNodeUtils
2097-
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, gr.getSourceAsBytesRef())
2098-
) {
2099-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
2100-
Connector connector = Connector.createConnector(parser);
2101-
listener.onResponse(connector);
2102-
} catch (Exception e) {
2103-
log.error("Failed to parse connector:{}", connectorId);
2104-
listener.onFailure(e);
2094+
try {
2095+
GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser());
2096+
if (gr != null && gr.isExists()) {
2097+
try (
2098+
XContentParser parser = MLNodeUtils
2099+
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, gr.getSourceAsBytesRef())
2100+
) {
2101+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
2102+
Connector connector = Connector.createConnector(parser);
2103+
listener.onResponse(connector);
2104+
} catch (Exception e) {
2105+
log.error("Failed to parse connector:{}", connectorId);
2106+
listener.onFailure(e);
2107+
}
2108+
} else {
2109+
listener
2110+
.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND));
21052111
}
2106-
} else {
2107-
listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND));
2112+
} catch (Exception e) {
2113+
listener.onFailure(e);
21082114
}
2109-
} catch (Exception e) {
2110-
listener.onFailure(e);
21112115
}
2112-
}
2113-
});
2116+
});
2117+
}
21142118
}
21152119

21162120
/**

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+43
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
import org.mockito.ArgumentCaptor;
7676
import org.mockito.Mock;
7777
import org.mockito.MockitoAnnotations;
78+
import org.opensearch.OpenSearchStatusException;
7879
import org.opensearch.action.get.GetRequest;
7980
import org.opensearch.action.get.GetResponse;
8081
import org.opensearch.action.index.IndexResponse;
@@ -92,9 +93,11 @@
9293
import org.opensearch.core.common.breaker.CircuitBreakingException;
9394
import org.opensearch.core.common.bytes.BytesReference;
9495
import org.opensearch.core.index.shard.ShardId;
96+
import org.opensearch.core.rest.RestStatus;
9597
import org.opensearch.core.xcontent.NamedXContentRegistry;
9698
import org.opensearch.core.xcontent.ToXContent;
9799
import org.opensearch.core.xcontent.XContentBuilder;
100+
import org.opensearch.index.IndexNotFoundException;
98101
import org.opensearch.index.get.GetResult;
99102
import org.opensearch.ml.breaker.MLCircuitBreakerService;
100103
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
@@ -492,6 +495,46 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOExce
492495
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
493496
}
494497

498+
@Test
499+
public void testRegisterMLRemoteModelModelGroupNotFoundException() throws PrivilegedActionException, IOException {
500+
// Create listener and capture the failure
501+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
502+
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
503+
504+
// Setup mocks
505+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
506+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
507+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
508+
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
509+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
510+
511+
// Create test inputs
512+
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
513+
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
514+
515+
// Mock index handler
516+
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
517+
518+
// Mock client.get() to throw IndexNotFoundException
519+
doAnswer(invocation -> {
520+
ActionListener<GetResponse> getModelGroupListener = invocation.getArgument(1);
521+
getModelGroupListener.onFailure(new IndexNotFoundException("Test", "test"));
522+
return null;
523+
}).when(client).get(any(), any());
524+
525+
// Execute method under test
526+
modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener);
527+
528+
// Verify the listener's onFailure was called with correct exception
529+
verify(listener).onFailure(exceptionCaptor.capture());
530+
Exception exception = exceptionCaptor.getValue();
531+
532+
// Verify exception type and message
533+
assertTrue(exception instanceof OpenSearchStatusException);
534+
assertEquals("Model group not found", exception.getMessage());
535+
assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) exception).status());
536+
}
537+
495538
public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException {
496539
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
497540
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

-114
Original file line numberDiff line numberDiff line change
@@ -740,120 +740,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
740740
assertFalse(((String) responseMap.get("text")).isEmpty());
741741
}
742742

743-
public void testCohereClassifyModel() throws IOException, InterruptedException {
744-
// Skip test if key is null
745-
if (COHERE_KEY == null) {
746-
return;
747-
}
748-
String entity = "{\n"
749-
+ " \"name\": \"Cohere classify model Connector\",\n"
750-
+ " \"description\": \"The connector to public Cohere classify model service\",\n"
751-
+ " \"version\": 1,\n"
752-
+ " \"client_config\": {\n"
753-
+ " \"max_connection\": 20,\n"
754-
+ " \"connection_timeout\": 50000,\n"
755-
+ " \"read_timeout\": 50000\n"
756-
+ " },\n"
757-
+ " \"protocol\": \"http\",\n"
758-
+ " \"parameters\": {\n"
759-
+ " \"endpoint\": \"api.cohere.ai\",\n"
760-
+ " \"auth\": \"API_Key\",\n"
761-
+ " \"content_type\": \"application/json\",\n"
762-
+ " \"max_tokens\": \"20\"\n"
763-
+ " },\n"
764-
+ " \"credential\": {\n"
765-
+ " \"cohere_key\": \""
766-
+ COHERE_KEY
767-
+ "\"\n"
768-
+ " },\n"
769-
+ " \"actions\": [\n"
770-
+ " {\n"
771-
+ " \"action_type\": \"predict\",\n"
772-
+ " \"method\": \"POST\",\n"
773-
+ " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n"
774-
+ " \"headers\": { \n"
775-
+ " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
776-
+ " },\n"
777-
+ " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n"
778-
+ " }\n"
779-
+ " ]\n"
780-
+ "}";
781-
Response response = createConnector(entity);
782-
Map responseMap = parseResponseToMap(response);
783-
String connectorId = (String) responseMap.get("connector_id");
784-
response = registerRemoteModel("cohere classify model", connectorId);
785-
responseMap = parseResponseToMap(response);
786-
String taskId = (String) responseMap.get("task_id");
787-
waitForTask(taskId, MLTaskState.COMPLETED);
788-
response = getTask(taskId);
789-
responseMap = parseResponseToMap(response);
790-
String modelId = (String) responseMap.get("model_id");
791-
response = deployRemoteModel(modelId);
792-
responseMap = parseResponseToMap(response);
793-
taskId = (String) responseMap.get("task_id");
794-
waitForTask(taskId, MLTaskState.COMPLETED);
795-
String predictInput = "{\n"
796-
+ " \"parameters\": {\n"
797-
+ " \"inputs\": [\n"
798-
+ " \"Confirm your email address\",\n"
799-
+ " \"hey i need u to send some $\"\n"
800-
+ " ],\n"
801-
+ " \"examples\": [\n"
802-
+ " {\n"
803-
+ " \"text\": \"Dermatologists don't like her!\",\n"
804-
+ " \"label\": \"Spam\"\n"
805-
+ " },\n"
806-
+ " {\n"
807-
+ " \"text\": \"Hello, open to this?\",\n"
808-
+ " \"label\": \"Spam\"\n"
809-
+ " },\n"
810-
+ " {\n"
811-
+ " \"text\": \"I need help please wire me $1000 right now\",\n"
812-
+ " \"label\": \"Spam\"\n"
813-
+ " },\n"
814-
+ " {\n"
815-
+ " \"text\": \"Nice to know you ;)\",\n"
816-
+ " \"label\": \"Spam\"\n"
817-
+ " },\n"
818-
+ " {\n"
819-
+ " \"text\": \"Please help me?\",\n"
820-
+ " \"label\": \"Spam\"\n"
821-
+ " },\n"
822-
+ " {\n"
823-
+ " \"text\": \"Your parcel will be delivered today\",\n"
824-
+ " \"label\": \"Not spam\"\n"
825-
+ " },\n"
826-
+ " {\n"
827-
+ " \"text\": \"Review changes to our Terms and Conditions\",\n"
828-
+ " \"label\": \"Not spam\"\n"
829-
+ " },\n"
830-
+ " {\n"
831-
+ " \"text\": \"Weekly sync notes\",\n"
832-
+ " \"label\": \"Not spam\"\n"
833-
+ " },\n"
834-
+ " {\n"
835-
+ " \"text\": \"Re: Follow up from todays meeting\",\n"
836-
+ " \"label\": \"Not spam\"\n"
837-
+ " },\n"
838-
+ " {\n"
839-
+ " \"text\": \"Pre-read for tomorrow\",\n"
840-
+ " \"label\": \"Not spam\"\n"
841-
+ " }\n"
842-
+ " ]\n"
843-
+ " }\n"
844-
+ "}";
845-
846-
response = predictRemoteModel(modelId, predictInput);
847-
responseMap = parseResponseToMap(response);
848-
List responseList = (List) responseMap.get("inference_results");
849-
responseMap = (Map) responseList.get(0);
850-
responseList = (List) responseMap.get("output");
851-
responseMap = (Map) responseList.get(0);
852-
responseMap = (Map) responseMap.get("dataAsMap");
853-
responseList = (List) responseMap.get("classifications");
854-
assertFalse(responseList.isEmpty());
855-
}
856-
857743
public static Response createConnector(String input) throws IOException {
858744
try {
859745
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);

0 commit comments

Comments
 (0)