Skip to content

Commit 138def6

Browse files
Allow llmQuestion to be optional when llmMessages is used. (Issue opensearch-project#3067)
Signed-off-by: Austin Lee <austin@aryn.ai>
1 parent 74c211e commit 138def6

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+18-6
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
359359
+ " \"ext\": {\n"
360360
+ " \"generative_qa_parameters\": {\n"
361361
+ " \"llm_model\": \"%s\",\n"
362-
+ " \"llm_question\": \"%s\",\n"
362+
// + " \"llm_question\": \"%s\",\n"
363363
+ " \"system_prompt\": \"%s\",\n"
364364
+ " \"user_instructions\": \"%s\",\n"
365365
+ " \"context_size\": %d,\n"
@@ -378,7 +378,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
378378
+ " \"ext\": {\n"
379379
+ " \"generative_qa_parameters\": {\n"
380380
+ " \"llm_model\": \"%s\",\n"
381-
+ " \"llm_question\": \"%s\",\n"
381+
// + " \"llm_question\": \"%s\",\n"
382382
// + " \"system_prompt\": \"%s\",\n"
383383
+ " \"user_instructions\": \"%s\",\n"
384384
+ " \"context_size\": %d,\n"
@@ -723,8 +723,12 @@ public void testBM25WithBedrock() throws Exception {
723723
public void testBM25WithBedrockConverse() throws Exception {
724724
// Skip test if key is null
725725
if (AWS_ACCESS_KEY_ID == null) {
726+
System.out.println("Skipping testBM25WithBedrockConverse because AWS_ACCESS_KEY_ID is null");
726727
return;
727728
}
729+
730+
System.out.println("Running testBM25WithBedrockConverse");
731+
728732
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
729733
Map responseMap = parseResponseToMap(response);
730734
String connectorId = (String) responseMap.get("connector_id");
@@ -775,8 +779,11 @@ public void testBM25WithBedrockConverse() throws Exception {
775779
public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
776780
// Skip test if key is null
777781
if (AWS_ACCESS_KEY_ID == null) {
782+
System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessages because AWS_ACCESS_KEY_ID is null");
778783
return;
779784
}
785+
System.out.println("Running testBM25WithBedrockConverseUsingLlmMessages");
786+
780787
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
781788
Map responseMap = parseResponseToMap(response);
782789
String connectorId = (String) responseMap.get("connector_id");
@@ -835,8 +842,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
835842
public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception {
836843
// Skip test if key is null
837844
if (AWS_ACCESS_KEY_ID == null) {
845+
System.out.println("Skipping testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat because AWS_ACCESS_KEY_ID is null");
838846
return;
839847
}
848+
849+
System.out.println("Running testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat");
840850
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
841851
Map responseMap = parseResponseToMap(response);
842852
String connectorId = (String) responseMap.get("connector_id");
@@ -894,8 +904,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
894904
public void testBM25WithOpenAIWithConversation() throws Exception {
895905
// Skip test if key is null
896906
if (OPENAI_KEY == null) {
907+
System.out.println("Skipping testBM25WithOpenAIWithConversation because OPENAI_KEY is null");
897908
return;
898909
}
910+
System.out.println("Running testBM25WithOpenAIWithConversation");
911+
899912
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
900913
Map responseMap = parseResponseToMap(response);
901914
String connectorId = (String) responseMap.get("connector_id");
@@ -951,8 +964,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
951964
public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
952965
// Skip test if key is null
953966
if (OPENAI_KEY == null) {
967+
System.out.println("Skipping testBM25WithOpenAIWithConversationAndImage because OPENAI_KEY is null");
954968
return;
955969
}
970+
System.out.println("Running testBM25WithOpenAIWithConversationAndImage");
971+
956972
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
957973
Map responseMap = parseResponseToMap(response);
958974
String connectorId = (String) responseMap.get("connector_id");
@@ -1245,7 +1261,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
12451261
requestParameters.source,
12461262
requestParameters.match,
12471263
requestParameters.llmModel,
1248-
requestParameters.llmQuestion,
12491264
requestParameters.systemPrompt,
12501265
requestParameters.userInstructions,
12511266
requestParameters.contextSize,
@@ -1268,8 +1283,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
12681283
requestParameters.source,
12691284
requestParameters.match,
12701285
requestParameters.llmModel,
1271-
requestParameters.llmQuestion,
1272-
// requestParameters.systemPrompt,
12731286
requestParameters.userInstructions,
12741287
requestParameters.contextSize,
12751288
requestParameters.interactionSize,
@@ -1309,7 +1322,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
13091322
requestParameters.source,
13101323
requestParameters.match,
13111324
requestParameters.llmModel,
1312-
requestParameters.llmQuestion,
13131325
requestParameters.systemPrompt,
13141326
requestParameters.userInstructions,
13151327
requestParameters.contextSize,

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,11 @@ public GenerativeQAParameters(
167167
this.conversationId = conversationId;
168168
this.llmModel = llmModel;
169169

170-
// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
171-
// for question rewriting.
172-
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
170+
Preconditions
171+
.checkArgument(
172+
!(Strings.isNullOrEmpty(llmQuestion) && (llmMessages == null || llmMessages.isEmpty())),
173+
"At least one of " + LLM_QUESTION + " or " + LLM_MESSAGES_FIELD + " must be provided."
174+
);
173175
this.llmQuestion = llmQuestion;
174176
this.systemPrompt = systemPrompt;
175177
this.userInstructions = userInstructions;
@@ -185,7 +187,7 @@ public GenerativeQAParameters(
185187
public GenerativeQAParameters(StreamInput input) throws IOException {
186188
this.conversationId = input.readOptionalString();
187189
this.llmModel = input.readOptionalString();
188-
this.llmQuestion = input.readString();
190+
this.llmQuestion = input.readOptionalString();
189191
this.systemPrompt = input.readOptionalString();
190192
this.userInstructions = input.readOptionalString();
191193
this.contextSize = input.readInt();
@@ -246,9 +248,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
246248
public void writeTo(StreamOutput out) throws IOException {
247249
out.writeOptionalString(conversationId);
248250
out.writeOptionalString(llmModel);
249-
250-
Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
251-
out.writeString(llmQuestion);
251+
out.writeOptionalString(llmQuestion);
252252
out.writeOptionalString(systemPrompt);
253253
out.writeOptionalString(userInstructions);
254254
out.writeInt(contextSize);

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ public void testMiscMethods() throws IOException {
121121

122122
StreamOutput so = mock(StreamOutput.class);
123123
builder1.writeTo(so);
124-
verify(so, times(5)).writeOptionalString(any());
125-
verify(so, times(1)).writeString(any());
124+
verify(so, times(6)).writeOptionalString(any());
126125
}
127126

128127
public void testParse() throws IOException {

0 commit comments

Comments
 (0)