Skip to content

Commit c984f62

Browse files
Add a new IT test that uses both an image and a documenet in RAG.
Signed-off-by: Austin Lee <austin@aryn.ai>
1 parent 74c211e commit c984f62

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+69
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,72 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
891891
assertNotNull(answer);
892892
}
893893

894+
public void testBM25WithBedrockConverseUsingLlmMessagesForImageAndDocument() throws Exception {
895+
// Skip test if key is null
896+
if (AWS_ACCESS_KEY_ID == null) {
897+
return;
898+
}
899+
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
900+
Map responseMap = parseResponseToMap(response);
901+
String connectorId = (String) responseMap.get("connector_id");
902+
response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId);
903+
responseMap = parseResponseToMap(response);
904+
String taskId = (String) responseMap.get("task_id");
905+
waitForTask(taskId, MLTaskState.COMPLETED);
906+
response = RestMLRemoteInferenceIT.getTask(taskId);
907+
responseMap = parseResponseToMap(response);
908+
String modelId = (String) responseMap.get("model_id");
909+
response = deployRemoteModel(modelId);
910+
responseMap = parseResponseToMap(response);
911+
taskId = (String) responseMap.get("task_id");
912+
waitForTask(taskId, MLTaskState.COMPLETED);
913+
914+
PipelineParameters pipelineParameters = new PipelineParameters();
915+
pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat";
916+
pipelineParameters.description = "desc";
917+
pipelineParameters.modelId = modelId;
918+
// pipelineParameters.systemPrompt = "You are a helpful assistant";
919+
pipelineParameters.userInstructions = "none";
920+
pipelineParameters.context_field = "text";
921+
Response response1 = createSearchPipeline2("pipeline_test", pipelineParameters);
922+
assertEquals(200, response1.getStatusLine().getStatusCode());
923+
924+
byte[] rawImage = FileUtils
925+
.readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "openai_boardwalk.jpg").toURI()).toFile());
926+
String imageContent = Base64.getEncoder().encodeToString(rawImage);
927+
928+
byte[] docBytes = FileUtils.readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "lincoln.pdf").toURI()).toFile());
929+
String docContent = Base64.getEncoder().encodeToString(docBytes);
930+
931+
SearchRequestParameters requestParameters;
932+
requestParameters = new SearchRequestParameters();
933+
requestParameters.source = "text";
934+
requestParameters.match = "president";
935+
requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE;
936+
requestParameters.llmQuestion = "use the information from the attached document to tell me something interesting about lincoln";
937+
requestParameters.contextSize = 5;
938+
requestParameters.interactionSize = 5;
939+
requestParameters.timeout = 60;
940+
requestParameters.imageFormat = "jpeg";
941+
requestParameters.imageType = "data"; // Bedrock does not support URLs
942+
requestParameters.imageData = imageContent;
943+
requestParameters.documentFormat = "pdf";
944+
requestParameters.documentName = "lincoln";
945+
requestParameters.documentData = docContent;
946+
Response response3 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
947+
assertEquals(200, response3.getStatusLine().getStatusCode());
948+
949+
Map responseMap3 = parseResponseToMap(response3);
950+
Map ext = (Map) responseMap3.get("ext");
951+
assertNotNull(ext);
952+
Map rag = (Map) ext.get("retrieval_augmented_generation");
953+
assertNotNull(rag);
954+
955+
// TODO handle errors such as throttling
956+
String answer = (String) rag.get("answer");
957+
assertNotNull(answer);
958+
}
959+
894960
public void testBM25WithOpenAIWithConversation() throws Exception {
895961
// Skip test if key is null
896962
if (OPENAI_KEY == null) {
@@ -1352,6 +1418,9 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
13521418
requestParameters.interactionSize,
13531419
requestParameters.timeout
13541420
);
1421+
1422+
System.out.println(httpEntity);
1423+
13551424
return makeRequest(
13561425
client(),
13571426
"POST",

0 commit comments

Comments
 (0)