|
6 | 6 | package org.opensearch.ml.rest;
|
7 | 7 |
|
8 | 8 | import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;
|
| 9 | +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_HASH_VALUE; |
9 | 10 | import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL;
|
10 | 11 | import static org.opensearch.ml.utils.TestHelper.makeRequest;
|
11 | 12 |
|
|
25 | 26 | import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
|
26 | 27 | import org.opensearch.ml.utils.TestHelper;
|
27 | 28 |
|
| 29 | +import com.jayway.jsonpath.DocumentContext; |
28 | 30 | import com.jayway.jsonpath.JsonPath;
|
29 | 31 |
|
30 | 32 | public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
|
@@ -431,6 +433,110 @@ public void testMLInferenceProcessorLocalModelObjectField() throws Exception {
|
431 | 433 | Assert.assertEquals(0.49191704, (Double) embedding2.get(0), 0.005);
|
432 | 434 | }
|
433 | 435 |
|
| 436 | + public void testMLInferenceIngestProcessor_simulatesIngestPipelineSuccessfully_withAsymmetricEmbeddingModelUsingPassageContentType() |
| 437 | + throws Exception { |
| 438 | + String taskId = registerModel(TestHelper.toJsonString(registerAsymmetricEmbeddingModelInput())); |
| 439 | + waitForTask(taskId, MLTaskState.COMPLETED); |
| 440 | + getTask(client(), taskId, response -> { |
| 441 | + assertNotNull(response.get(MODEL_ID_FIELD)); |
| 442 | + this.localModelId = (String) response.get(MODEL_ID_FIELD); |
| 443 | + try { |
| 444 | + String deployTaskID = deployModel(this.localModelId); |
| 445 | + waitForTask(deployTaskID, MLTaskState.COMPLETED); |
| 446 | + |
| 447 | + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); |
| 448 | + } catch (IOException | InterruptedException e) { |
| 449 | + throw new RuntimeException(e); |
| 450 | + } |
| 451 | + }); |
| 452 | + |
| 453 | + String asymmetricPipelineName = "asymmetric_embedding_pipeline"; |
| 454 | + String createPipelineRequestBody = "{\n" |
| 455 | + + " \"description\": \"ingest PASSAGE text and generate a embedding using an asymmetric model\",\n" |
| 456 | + + " \"processors\": [\n" |
| 457 | + + " {\n" |
| 458 | + + " \"ml_inference\": {\n" |
| 459 | + + "\n" |
| 460 | + + " \"model_input\": \"{\\\"text_docs\\\":[\\\"${input_map.text_docs}\\\"],\\\"target_response\\\":[\\\"sentence_embedding\\\"],\\\"parameters\\\":{\\\"content_type\\\":\\\"passage\\\"}}\",\n" |
| 461 | + + " \"function_name\": \"text_embedding\",\n" |
| 462 | + + " \"model_id\": \"" |
| 463 | + + this.localModelId |
| 464 | + + "\",\n" |
| 465 | + + " \"input_map\": [\n" |
| 466 | + + " {\n" |
| 467 | + + " \"text_docs\": \"description\"\n" |
| 468 | + + " }\n" |
| 469 | + + " ],\n" |
| 470 | + + " \"output_map\": [\n" |
| 471 | + + " {\n" |
| 472 | + + "\n" |
| 473 | + + " " |
| 474 | + + " \"fact_embedding\": \"$.inference_results[0].output[0].data\"\n" |
| 475 | + + " }\n" |
| 476 | + + " ]\n" |
| 477 | + + " }\n" |
| 478 | + + " }\n" |
| 479 | + + " ]\n" |
| 480 | + + "}"; |
| 481 | + |
| 482 | + createPipelineProcessor(createPipelineRequestBody, asymmetricPipelineName); |
| 483 | + String sampleDocuments = "{\n" |
| 484 | + + " \"docs\": [\n" |
| 485 | + + " {\n" |
| 486 | + + " \"_index\": \"my-index\",\n" |
| 487 | + + " \"_id\": \"1\",\n" |
| 488 | + + " \"_source\": {\n" |
| 489 | + + " \"title\": \"Central Park\",\n" |
| 490 | + + " \"description\": \"A large public park in the heart of New York City, offering a wide range of recreational activities.\"\n" |
| 491 | + + " }\n" |
| 492 | + + " },\n" |
| 493 | + + " {\n" |
| 494 | + + " \"_index\": \"my-index\",\n" |
| 495 | + + " \"_id\": \"2\",\n" |
| 496 | + + " \"_source\": {\n" |
| 497 | + + " \"title\": \"Empire State Building\",\n" |
| 498 | + + " \"description\": \"An iconic skyscraper in New York City offering breathtaking views from its observation deck.\"\n" |
| 499 | + + " }\n" |
| 500 | + + " }\n" |
| 501 | + + " ]\n" |
| 502 | + + "}\n"; |
| 503 | + |
| 504 | + Map simulateResponseDocuments = simulateIngestPipeline(asymmetricPipelineName, sampleDocuments); |
| 505 | + |
| 506 | + DocumentContext documents = JsonPath.parse(simulateResponseDocuments); |
| 507 | + |
| 508 | + List centralParkFactEmbedding = documents.read("docs.[0].*._source.fact_embedding.*"); |
| 509 | + assertEquals(768, centralParkFactEmbedding.size()); |
| 510 | + Assert.assertEquals(0.5137818, (Double) centralParkFactEmbedding.get(0), 0.005); |
| 511 | + |
| 512 | + List empireStateBuildingFactEmbedding = documents.read("docs.[1].*._source.fact_embedding.*"); |
| 513 | + assertEquals(768, empireStateBuildingFactEmbedding.size()); |
| 514 | + Assert.assertEquals(0.4493039, (Double) empireStateBuildingFactEmbedding.get(0), 0.005); |
| 515 | + } |
| 516 | + |
| 517 | + private MLRegisterModelInput registerAsymmetricEmbeddingModelInput() { |
| 518 | + MLModelConfig modelConfig = TextEmbeddingModelConfig |
| 519 | + .builder() |
| 520 | + .modelType("bert") |
| 521 | + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) |
| 522 | + .embeddingDimension(768) |
| 523 | + .queryPrefix("query >>") |
| 524 | + .passagePrefix("passage >> ") |
| 525 | + .build(); |
| 526 | + |
| 527 | + return MLRegisterModelInput |
| 528 | + .builder() |
| 529 | + .modelName("test_model_name") |
| 530 | + .version("1.0.0") |
| 531 | + .functionName(FunctionName.TEXT_EMBEDDING) |
| 532 | + .modelFormat(MLModelFormat.TORCH_SCRIPT) |
| 533 | + .modelConfig(modelConfig) |
| 534 | + .url(SENTENCE_TRANSFORMER_MODEL_URL) |
| 535 | + .deployModel(false) |
| 536 | + .hashValue(SENTENCE_TRANSFORMER_MODEL_HASH_VALUE) |
| 537 | + .build(); |
| 538 | + } |
| 539 | + |
434 | 540 | // TODO: add tests for other local model types such as sparse/cross encoders
|
435 | 541 | public void testMLInferenceProcessorLocalModelNestedField() throws Exception {
|
436 | 542 |
|
@@ -550,6 +656,14 @@ protected void createPipelineProcessor(String requestBody, final String pipeline
|
550 | 656 |
|
551 | 657 | }
|
552 | 658 |
|
| 659 | + protected Map simulateIngestPipeline(String pipelineName, String sampleDocuments) throws IOException { |
| 660 | + Response ingestionResponse = TestHelper |
| 661 | + .makeRequest(client(), "POST", "/_ingest/pipeline/" + pipelineName + "/_simulate", null, sampleDocuments, null); |
| 662 | + assertEquals(200, ingestionResponse.getStatusLine().getStatusCode()); |
| 663 | + |
| 664 | + return parseResponseToMap(ingestionResponse); |
| 665 | + } |
| 666 | + |
553 | 667 | protected void createIndex(String indexName, String requestBody) throws Exception {
|
554 | 668 | Response response = makeRequest(client(), "PUT", indexName, null, requestBody, null);
|
555 | 669 | assertEquals(200, response.getStatusLine().getStatusCode());
|
@@ -585,7 +699,7 @@ protected MLRegisterModelInput registerModelInput() throws IOException, Interrup
|
585 | 699 | .modelConfig(modelConfig)
|
586 | 700 | .url(SENTENCE_TRANSFORMER_MODEL_URL)
|
587 | 701 | .deployModel(false)
|
588 |
| - .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021") |
| 702 | + .hashValue(SENTENCE_TRANSFORMER_MODEL_HASH_VALUE) |
589 | 703 | .build();
|
590 | 704 | }
|
591 | 705 |
|
|
0 commit comments