Skip to content

Commit 17e81ae

Browse files
Add support for Bedrock Converse API (Anthropic Messages API, Claude 3.5 Sonnet) (opensearch-project#2851)
* Add support for Anthropic Messages API (Issue 2826) Signed-off-by: Austin Lee <austin@aryn.ai> * Fix a bug. Signed-off-by: Austin Lee <austin@aryn.ai> * Add unit tests, improve coverage, clean up code. Signed-off-by: Austin Lee <austin@aryn.ai> * Allow pdf and jpg files for IT tests for multimodal conversation API testing. Signed-off-by: Austin Lee <austin@aryn.ai> * Fix spotless check issues. Signed-off-by: Austin Lee <austin@aryn.ai> * Update IT to work with session tokens. Signed-off-by: Austin Lee <austin@aryn.ai> * Fix MLRAGSearchProcessorIT not to extend RestMLRemoteInferenceIT. Signed-off-by: Austin Lee <austin@aryn.ai> * Use suite-specific model group name. Signed-off-by: Austin Lee <austin@aryn.ai> * Disable tests that require further investigation. Signed-off-by: Austin Lee <austin@aryn.ai> * Skip two additional tests with time-outs. Signed-off-by: Austin Lee <austin@aryn.ai> * Restore a change from RestMLRemoteInferenceIT. Signed-off-by: Austin Lee <austin@aryn.ai> --------- Signed-off-by: Austin Lee <austin@aryn.ai>
1 parent eca963f commit 17e81ae

File tree

18 files changed

+1869
-52
lines changed

18 files changed

+1869
-52
lines changed

plugin/build.gradle

+5
Original file line number | Diff line number | Diff line change
@@ -578,3 +578,8 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) {
578578
dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask")
579579
dependsOn tasks.named("${baseName}#fullRestartClusterTask")
580580
}
581+
582+
forbiddenPatterns {
583+
exclude '**/*.pdf'
584+
exclude '**/*.jpg'
585+
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+684-17
Large diffs are not rendered by default.
Binary file not shown.
Loading

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+4-2
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,8 @@ public void processResponseAsync(
179179
chatHistory,
180180
searchResults,
181181
timeout,
182-
params.getLlmResponseField()
182+
params.getLlmResponseField(),
183+
params.getLlmMessages()
183184
),
184185
null,
185186
llmQuestion,
@@ -202,7 +203,8 @@ public void processResponseAsync(
202203
chatHistory,
203204
searchResults,
204205
timeout,
205-
params.getLlmResponseField()
206+
params.getLlmResponseField(),
207+
params.getLlmMessages()
206208
),
207209
conversationId,
208210
llmQuestion,

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+47-1
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@
1818
package org.opensearch.searchpipelines.questionanswering.generative.ext;
1919

2020
import java.io.IOException;
21+
import java.util.ArrayList;
22+
import java.util.List;
2123
import java.util.Objects;
2224

2325
import org.opensearch.core.ParseField;
@@ -30,6 +32,7 @@
3032
import org.opensearch.core.xcontent.XContentBuilder;
3133
import org.opensearch.core.xcontent.XContentParser;
3234
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
35+
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
3336

3437
import com.google.common.base.Preconditions;
3538

@@ -81,6 +84,8 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
8184
// that contains the chat completion text, i.e. "answer".
8285
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
8386

87+
private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");
88+
8489
public static final int SIZE_NULL_VALUE = -1;
8590

8691
static {
@@ -94,6 +99,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
9499
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
95100
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
96101
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
102+
PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
97103
}
98104

99105
@Setter
@@ -132,6 +138,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
132138
@Getter
133139
private String llmResponseField;
134140

141+
@Setter
142+
@Getter
143+
private List<MessageBlock> llmMessages = new ArrayList<>();
144+
135145
public GenerativeQAParameters(
136146
String conversationId,
137147
String llmModel,
@@ -142,6 +152,32 @@ public GenerativeQAParameters(
142152
Integer interactionSize,
143153
Integer timeout,
144154
String llmResponseField
155+
) {
156+
this(
157+
conversationId,
158+
llmModel,
159+
llmQuestion,
160+
systemPrompt,
161+
userInstructions,
162+
contextSize,
163+
interactionSize,
164+
timeout,
165+
llmResponseField,
166+
null
167+
);
168+
}
169+
170+
public GenerativeQAParameters(
171+
String conversationId,
172+
String llmModel,
173+
String llmQuestion,
174+
String systemPrompt,
175+
String userInstructions,
176+
Integer contextSize,
177+
Integer interactionSize,
178+
Integer timeout,
179+
String llmResponseField,
180+
List<MessageBlock> llmMessages
145181
) {
146182
this.conversationId = conversationId;
147183
this.llmModel = llmModel;
@@ -156,6 +192,9 @@ public GenerativeQAParameters(
156192
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
157193
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
158194
this.llmResponseField = llmResponseField;
195+
if (llmMessages != null) {
196+
this.llmMessages.addAll(llmMessages);
197+
}
159198
}
160199

161200
public GenerativeQAParameters(StreamInput input) throws IOException {
@@ -168,6 +207,7 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
168207
this.interactionSize = input.readInt();
169208
this.timeout = input.readInt();
170209
this.llmResponseField = input.readOptionalString();
210+
this.llmMessages.addAll(input.readList(MessageBlock::new));
171211
}
172212

173213
@Override
@@ -181,7 +221,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
181221
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
182222
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
183223
.field(TIMEOUT.getPreferredName(), this.timeout)
184-
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
224+
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
225+
.field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
185226
}
186227

187228
@Override
@@ -197,6 +238,7 @@ public void writeTo(StreamOutput out) throws IOException {
197238
out.writeInt(interactionSize);
198239
out.writeInt(timeout);
199240
out.writeOptionalString(llmResponseField);
241+
out.writeList(llmMessages);
200242
}
201243

202244
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
@@ -223,4 +265,8 @@ public boolean equals(Object o) {
223265
&& (this.timeout == other.getTimeout())
224266
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
225267
}
268+
269+
public void setMessageBlock(List<MessageBlock> blockList) {
270+
this.llmMessages = blockList;
271+
}
226272
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java

+1
Original file line number | Diff line number | Diff line change
@@ -44,4 +44,5 @@ public class ChatCompletionInput {
4444
private String userInstructions;
4545
private Llm.ModelProvider modelProvider;
4646
private String llmResponseField;
47+
private List<MessageBlock> llmMessages;
4748
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

+30-4
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,6 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
7575
* @return
7676
*/
7777
@Override
78-
7978
public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
8079
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
8180
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
@@ -113,14 +112,15 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
113112
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
114113
String messages = PromptUtil
115114
.getChatCompletionPrompt(
115+
chatCompletionInput.getModelProvider(),
116116
chatCompletionInput.getSystemPrompt(),
117117
chatCompletionInput.getUserInstructions(),
118118
chatCompletionInput.getQuestion(),
119119
chatCompletionInput.getChatHistory(),
120-
chatCompletionInput.getContexts()
120+
chatCompletionInput.getContexts(),
121+
chatCompletionInput.getLlmMessages()
121122
);
122123
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
123-
// log.info("Messages to LLM: {}", messages);
124124
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
125125
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
126126
|| chatCompletionInput.getLlmResponseField() != null) {
@@ -136,6 +136,19 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
136136
chatCompletionInput.getContexts()
137137
)
138138
);
139+
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
140+
// Bedrock Converse API does not include the system prompt as part of the Messages block.
141+
String messages = PromptUtil
142+
.getChatCompletionPrompt(
143+
chatCompletionInput.getModelProvider(),
144+
null,
145+
chatCompletionInput.getUserInstructions(),
146+
chatCompletionInput.getQuestion(),
147+
chatCompletionInput.getChatHistory(),
148+
chatCompletionInput.getContexts(),
149+
chatCompletionInput.getLlmMessages()
150+
);
151+
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
139152
} else {
140153
throw new IllegalArgumentException(
141154
"Unknown/unsupported model provider: "
@@ -144,7 +157,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
144157
);
145158
}
146159

147-
// log.info("LLM input parameters: {}", inputParameters.toString());
148160
return inputParameters;
149161
}
150162

@@ -184,6 +196,20 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
184196
} else if (provider == ModelProvider.COHERE) {
185197
answerField = "text";
186198
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
199+
} else if (provider == ModelProvider.BEDROCK_CONVERSE) {
200+
Map output = (Map) dataAsMap.get("output");
201+
Map message = (Map) output.get("message");
202+
if (message != null) {
203+
List content = (List) message.get("content");
204+
String answer = (String) ((Map) content.get(0)).get("text");
205+
answers.add(answer);
206+
} else {
207+
Map error = (Map) output.get("error");
208+
if (error == null) {
209+
throw new RuntimeException("Unexpected output: " + output);
210+
}
211+
errors.add((String) error.get("message"));
212+
}
187213
} else {
188214
throw new IllegalArgumentException(
189215
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java

+2-1
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,8 @@ public interface Llm {
2828
enum ModelProvider {
2929
OPENAI,
3030
BEDROCK,
31-
COHERE
31+
COHERE,
32+
BEDROCK_CONVERSE
3233
}
3334

3435
void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java

+9-3
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ public class LlmIOUtil {
2929

3030
public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
3131
public static final String COHERE_PROVIDER_PREFIX = "cohere/";
32+
public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/";
3233

3334
public static ChatCompletionInput createChatCompletionInput(
3435
String llmModel,
@@ -49,7 +50,8 @@ public static ChatCompletionInput createChatCompletionInput(
4950
chatHistory,
5051
contexts,
5152
timeoutInSeconds,
52-
llmResponseField
53+
llmResponseField,
54+
null
5355
);
5456
}
5557

@@ -61,7 +63,8 @@ public static ChatCompletionInput createChatCompletionInput(
6163
List<Interaction> chatHistory,
6264
List<String> contexts,
6365
int timeoutInSeconds,
64-
String llmResponseField
66+
String llmResponseField,
67+
List<MessageBlock> llmMessages
6568
) {
6669
Llm.ModelProvider provider = null;
6770
if (llmResponseField == null) {
@@ -71,6 +74,8 @@ public static ChatCompletionInput createChatCompletionInput(
7174
provider = Llm.ModelProvider.BEDROCK;
7275
} else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) {
7376
provider = Llm.ModelProvider.COHERE;
77+
} else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) {
78+
provider = Llm.ModelProvider.BEDROCK_CONVERSE;
7479
}
7580
}
7681
}
@@ -83,7 +88,8 @@ public static ChatCompletionInput createChatCompletionInput(
8388
systemPrompt,
8489
userInstructions,
8590
provider,
86-
llmResponseField
91+
llmResponseField,
92+
llmMessages
8793
);
8894
}
8995
}

0 commit comments

Comments
 (0)