Skip to content

Commit 6c9415a

Browse files
[bug-fix] Handle BWC for bedrock converse API (opensearch-project#3173) (opensearch-project#3183)
* fix: bwc check for llm messages Signed-off-by: Pavan Yekbote <mail2pavanyekbote@gmail.com> * fix: add bwc for llmquestion Signed-off-by: Pavan Yekbote <mail2pavanyekbote@gmail.com> * test: adding ut for bwc Signed-off-by: Pavan Yekbote <mail2pavanyekbote@gmail.com> * fix: test case which counts optional occurence Signed-off-by: Pavan Yekbote <mail2pavanyekbote@gmail.com> --------- Signed-off-by: Pavan Yekbote <mail2pavanyekbote@gmail.com> (cherry picked from commit 24d2640) Co-authored-by: Pavan Yekbote <mail2pavanyekbote@gmail.com>
1 parent 158ffba commit 6c9415a

File tree

3 files changed

+86
-7
lines changed

3 files changed

+86
-7
lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+30-4
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
import java.util.List;
2525
import java.util.Objects;
2626

27+
import org.opensearch.Version;
2728
import org.opensearch.core.common.Strings;
2829
import org.opensearch.core.common.io.stream.StreamInput;
2930
import org.opensearch.core.common.io.stream.StreamOutput;
3031
import org.opensearch.core.common.io.stream.Writeable;
3132
import org.opensearch.core.xcontent.ToXContentObject;
3233
import org.opensearch.core.xcontent.XContentBuilder;
3334
import org.opensearch.core.xcontent.XContentParser;
35+
import org.opensearch.ml.common.CommonValue;
3436
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
3537

3638
import com.google.common.base.Preconditions;
@@ -86,6 +88,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
8688

8789
public static final int SIZE_NULL_VALUE = -1;
8890

91+
static final Version MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES = CommonValue.VERSION_2_18_0;
92+
8993
@Setter
9094
@Getter
9195
private String conversationId;
@@ -185,16 +189,27 @@ public GenerativeQAParameters(
185189
}
186190

187191
public GenerativeQAParameters(StreamInput input) throws IOException {
192+
Version version = input.getVersion();
188193
this.conversationId = input.readOptionalString();
189194
this.llmModel = input.readOptionalString();
190-
this.llmQuestion = input.readOptionalString();
195+
196+
// this string was made optional in 2.18
197+
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
198+
this.llmQuestion = input.readOptionalString();
199+
} else {
200+
this.llmQuestion = input.readString();
201+
}
202+
191203
this.systemPrompt = input.readOptionalString();
192204
this.userInstructions = input.readOptionalString();
193205
this.contextSize = input.readInt();
194206
this.interactionSize = input.readInt();
195207
this.timeout = input.readInt();
196208
this.llmResponseField = input.readOptionalString();
197-
this.llmMessages.addAll(input.readList(MessageBlock::new));
209+
210+
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
211+
this.llmMessages.addAll(input.readList(MessageBlock::new));
212+
}
198213
}
199214

200215
@Override
@@ -246,16 +261,27 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
246261

247262
@Override
248263
public void writeTo(StreamOutput out) throws IOException {
264+
Version version = out.getVersion();
249265
out.writeOptionalString(conversationId);
250266
out.writeOptionalString(llmModel);
251-
out.writeOptionalString(llmQuestion);
267+
268+
// this string was made optional in 2.18
269+
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
270+
out.writeOptionalString(llmQuestion);
271+
} else {
272+
out.writeString(llmQuestion);
273+
}
274+
252275
out.writeOptionalString(systemPrompt);
253276
out.writeOptionalString(userInstructions);
254277
out.writeInt(contextSize);
255278
out.writeInt(interactionSize);
256279
out.writeInt(timeout);
257280
out.writeOptionalString(llmResponseField);
258-
out.writeList(llmMessages);
281+
282+
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)) {
283+
out.writeList(llmMessages);
284+
}
259285
}
260286

261287
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java

+11-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import static org.mockito.Mockito.mock;
2222
import static org.mockito.Mockito.times;
2323
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
2425
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
2526

2627
import java.io.EOFException;
@@ -30,6 +31,7 @@
3031
import java.util.Map;
3132

3233
import org.junit.Assert;
34+
import org.opensearch.Version;
3335
import org.opensearch.common.io.stream.BytesStreamOutput;
3436
import org.opensearch.common.settings.Settings;
3537
import org.opensearch.common.xcontent.XContentType;
@@ -119,9 +121,15 @@ public void testMiscMethods() throws IOException {
119121
assertNotEquals(builder1, builder2);
120122
assertNotEquals(builder1.hashCode(), builder2.hashCode());
121123

122-
StreamOutput so = mock(StreamOutput.class);
123-
builder1.writeTo(so);
124-
verify(so, times(6)).writeOptionalString(any());
124+
StreamOutput so1 = mock(StreamOutput.class);
125+
when(so1.getVersion()).thenReturn(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
126+
builder1.writeTo(so1);
127+
verify(so1, times(6)).writeOptionalString(any());
128+
129+
StreamOutput so2 = mock(StreamOutput.class);
130+
when(so2.getVersion()).thenReturn(Version.V_2_17_0);
131+
builder1.writeTo(so2);
132+
verify(so2, times(5)).writeOptionalString(any());
125133
}
126134

127135
public void testParse() throws IOException {

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java

+45
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
import java.util.List;
2727
import java.util.Map;
2828

29+
import org.opensearch.Version;
2930
import org.opensearch.action.search.SearchRequest;
31+
import org.opensearch.common.io.stream.BytesStreamOutput;
32+
import org.opensearch.core.common.io.stream.StreamInput;
3033
import org.opensearch.core.common.io.stream.StreamOutput;
3134
import org.opensearch.core.xcontent.XContent;
3235
import org.opensearch.core.xcontent.XContentBuilder;
@@ -179,6 +182,48 @@ public void testWriteTo() throws IOException {
179182
assertTrue(timeout == intValues.get(2));
180183
}
181184

185+
public void testWriteToBwcBedrockConverse() throws IOException {
186+
String conversationId = "a";
187+
String llmModel = "b";
188+
String llmQuestion = "c";
189+
String systemPrompt = "s";
190+
String userInstructions = "u";
191+
int contextSize = 1;
192+
int interactionSize = 2;
193+
int timeout = 10;
194+
String llmResponseField = "text";
195+
GenerativeQAParameters expected = new GenerativeQAParameters(
196+
conversationId,
197+
llmModel,
198+
llmQuestion,
199+
systemPrompt,
200+
userInstructions,
201+
contextSize,
202+
interactionSize,
203+
timeout,
204+
llmResponseField,
205+
messageList
206+
);
207+
208+
// Version.2_18_0 (MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES)
209+
BytesStreamOutput output = new BytesStreamOutput();
210+
output.setVersion(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
211+
expected.writeTo(output);
212+
StreamInput input = output.bytes().streamInput();
213+
input.setVersion(GenerativeQAParameters.MINIMAL_SUPPORTED_VERSION_FOR_BEDROCK_CONVERSE_LLM_MESSAGES);
214+
GenerativeQAParameters actual = new GenerativeQAParameters(input);
215+
assertEquals(expected, actual);
216+
217+
// Version.2_17_0 (LlmMessages should be empty list)
218+
output = new BytesStreamOutput();
219+
output.setVersion(Version.V_2_17_0);
220+
expected.writeTo(output);
221+
input = output.bytes().streamInput();
222+
input.setVersion(Version.V_2_17_0);
223+
actual = new GenerativeQAParameters(input);
224+
assertTrue(actual.getLlmMessages().isEmpty());
225+
}
226+
182227
public void testMisc() {
183228
String conversationId = "a";
184229
String llmModel = "b";

0 commit comments

Comments
 (0)