Skip to content

Commit 0135cb9

Browse files
authored
Revert "Add support for Bedrock Converse API (Anthropic Messages API, Claude 3.5 Sonnet) (opensearch-project#2851) (opensearch-project#2913)" (opensearch-project#2929)
This reverts commit ed37690.
1 parent 24fc9c3 commit 0135cb9

File tree

18 files changed

+52
-1869
lines changed

18 files changed

+52
-1869
lines changed

plugin/build.gradle

-5
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,3 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) {
572572
dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask")
573573
dependsOn tasks.named("${baseName}#fullRestartClusterTask")
574574
}
575-
576-
forbiddenPatterns {
577-
exclude '**/*.pdf'
578-
exclude '**/*.jpg'
579-
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+17-684
Large diffs are not rendered by default.
Binary file not shown.
Binary file not shown.

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+2-4
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ public void processResponseAsync(
179179
chatHistory,
180180
searchResults,
181181
timeout,
182-
params.getLlmResponseField(),
183-
params.getLlmMessages()
182+
params.getLlmResponseField()
184183
),
185184
null,
186185
llmQuestion,
@@ -203,8 +202,7 @@ public void processResponseAsync(
203202
chatHistory,
204203
searchResults,
205204
timeout,
206-
params.getLlmResponseField(),
207-
params.getLlmMessages()
205+
params.getLlmResponseField()
208206
),
209207
conversationId,
210208
llmQuestion,

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+1-47
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
package org.opensearch.searchpipelines.questionanswering.generative.ext;
1919

2020
import java.io.IOException;
21-
import java.util.ArrayList;
22-
import java.util.List;
2321
import java.util.Objects;
2422

2523
import org.opensearch.core.ParseField;
@@ -32,7 +30,6 @@
3230
import org.opensearch.core.xcontent.XContentBuilder;
3331
import org.opensearch.core.xcontent.XContentParser;
3432
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
35-
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
3633

3734
import com.google.common.base.Preconditions;
3835

@@ -84,8 +81,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
8481
// that contains the chat completion text, i.e. "answer".
8582
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
8683

87-
private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");
88-
8984
public static final int SIZE_NULL_VALUE = -1;
9085

9186
static {
@@ -99,7 +94,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
9994
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
10095
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
10196
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
102-
PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
10397
}
10498

10599
@Setter
@@ -138,10 +132,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
138132
@Getter
139133
private String llmResponseField;
140134

141-
@Setter
142-
@Getter
143-
private List<MessageBlock> llmMessages = new ArrayList<>();
144-
145135
public GenerativeQAParameters(
146136
String conversationId,
147137
String llmModel,
@@ -152,32 +142,6 @@ public GenerativeQAParameters(
152142
Integer interactionSize,
153143
Integer timeout,
154144
String llmResponseField
155-
) {
156-
this(
157-
conversationId,
158-
llmModel,
159-
llmQuestion,
160-
systemPrompt,
161-
userInstructions,
162-
contextSize,
163-
interactionSize,
164-
timeout,
165-
llmResponseField,
166-
null
167-
);
168-
}
169-
170-
public GenerativeQAParameters(
171-
String conversationId,
172-
String llmModel,
173-
String llmQuestion,
174-
String systemPrompt,
175-
String userInstructions,
176-
Integer contextSize,
177-
Integer interactionSize,
178-
Integer timeout,
179-
String llmResponseField,
180-
List<MessageBlock> llmMessages
181145
) {
182146
this.conversationId = conversationId;
183147
this.llmModel = llmModel;
@@ -192,9 +156,6 @@ public GenerativeQAParameters(
192156
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
193157
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
194158
this.llmResponseField = llmResponseField;
195-
if (llmMessages != null) {
196-
this.llmMessages.addAll(llmMessages);
197-
}
198159
}
199160

200161
public GenerativeQAParameters(StreamInput input) throws IOException {
@@ -207,7 +168,6 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
207168
this.interactionSize = input.readInt();
208169
this.timeout = input.readInt();
209170
this.llmResponseField = input.readOptionalString();
210-
this.llmMessages.addAll(input.readList(MessageBlock::new));
211171
}
212172

213173
@Override
@@ -221,8 +181,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
221181
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
222182
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
223183
.field(TIMEOUT.getPreferredName(), this.timeout)
224-
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
225-
.field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
184+
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
226185
}
227186

228187
@Override
@@ -238,7 +197,6 @@ public void writeTo(StreamOutput out) throws IOException {
238197
out.writeInt(interactionSize);
239198
out.writeInt(timeout);
240199
out.writeOptionalString(llmResponseField);
241-
out.writeList(llmMessages);
242200
}
243201

244202
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
@@ -265,8 +223,4 @@ public boolean equals(Object o) {
265223
&& (this.timeout == other.getTimeout())
266224
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
267225
}
268-
269-
public void setMessageBlock(List<MessageBlock> blockList) {
270-
this.llmMessages = blockList;
271-
}
272226
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java

-1
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,4 @@ public class ChatCompletionInput {
4444
private String userInstructions;
4545
private Llm.ModelProvider modelProvider;
4646
private String llmResponseField;
47-
private List<MessageBlock> llmMessages;
4847
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

+4-30
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
7575
* @return
7676
*/
7777
@Override
78+
7879
public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
7980
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
8081
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
@@ -112,15 +113,14 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
112113
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
113114
String messages = PromptUtil
114115
.getChatCompletionPrompt(
115-
chatCompletionInput.getModelProvider(),
116116
chatCompletionInput.getSystemPrompt(),
117117
chatCompletionInput.getUserInstructions(),
118118
chatCompletionInput.getQuestion(),
119119
chatCompletionInput.getChatHistory(),
120-
chatCompletionInput.getContexts(),
121-
chatCompletionInput.getLlmMessages()
120+
chatCompletionInput.getContexts()
122121
);
123122
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
123+
// log.info("Messages to LLM: {}", messages);
124124
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
125125
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
126126
|| chatCompletionInput.getLlmResponseField() != null) {
@@ -136,19 +136,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
136136
chatCompletionInput.getContexts()
137137
)
138138
);
139-
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
140-
// Bedrock Converse API does not include the system prompt as part of the Messages block.
141-
String messages = PromptUtil
142-
.getChatCompletionPrompt(
143-
chatCompletionInput.getModelProvider(),
144-
null,
145-
chatCompletionInput.getUserInstructions(),
146-
chatCompletionInput.getQuestion(),
147-
chatCompletionInput.getChatHistory(),
148-
chatCompletionInput.getContexts(),
149-
chatCompletionInput.getLlmMessages()
150-
);
151-
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
152139
} else {
153140
throw new IllegalArgumentException(
154141
"Unknown/unsupported model provider: "
@@ -157,6 +144,7 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
157144
);
158145
}
159146

147+
// log.info("LLM input parameters: {}", inputParameters.toString());
160148
return inputParameters;
161149
}
162150

@@ -196,20 +184,6 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
196184
} else if (provider == ModelProvider.COHERE) {
197185
answerField = "text";
198186
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
199-
} else if (provider == ModelProvider.BEDROCK_CONVERSE) {
200-
Map output = (Map) dataAsMap.get("output");
201-
Map message = (Map) output.get("message");
202-
if (message != null) {
203-
List content = (List) message.get("content");
204-
String answer = (String) ((Map) content.get(0)).get("text");
205-
answers.add(answer);
206-
} else {
207-
Map error = (Map) output.get("error");
208-
if (error == null) {
209-
throw new RuntimeException("Unexpected output: " + output);
210-
}
211-
errors.add((String) error.get("message"));
212-
}
213187
} else {
214188
throw new IllegalArgumentException(
215189
"Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field."

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ public interface Llm {
2828
enum ModelProvider {
2929
OPENAI,
3030
BEDROCK,
31-
COHERE,
32-
BEDROCK_CONVERSE
31+
COHERE
3332
}
3433

3534
void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ public class LlmIOUtil {
2929

3030
public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
3131
public static final String COHERE_PROVIDER_PREFIX = "cohere/";
32-
public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/";
3332

3433
public static ChatCompletionInput createChatCompletionInput(
3534
String llmModel,
@@ -50,8 +49,7 @@ public static ChatCompletionInput createChatCompletionInput(
5049
chatHistory,
5150
contexts,
5251
timeoutInSeconds,
53-
llmResponseField,
54-
null
52+
llmResponseField
5553
);
5654
}
5755

@@ -63,8 +61,7 @@ public static ChatCompletionInput createChatCompletionInput(
6361
List<Interaction> chatHistory,
6462
List<String> contexts,
6563
int timeoutInSeconds,
66-
String llmResponseField,
67-
List<MessageBlock> llmMessages
64+
String llmResponseField
6865
) {
6966
Llm.ModelProvider provider = null;
7067
if (llmResponseField == null) {
@@ -74,8 +71,6 @@ public static ChatCompletionInput createChatCompletionInput(
7471
provider = Llm.ModelProvider.BEDROCK;
7572
} else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) {
7673
provider = Llm.ModelProvider.COHERE;
77-
} else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) {
78-
provider = Llm.ModelProvider.BEDROCK_CONVERSE;
7974
}
8075
}
8176
}
@@ -88,8 +83,7 @@ public static ChatCompletionInput createChatCompletionInput(
8883
systemPrompt,
8984
userInstructions,
9085
provider,
91-
llmResponseField,
92-
llmMessages
86+
llmResponseField
9387
);
9488
}
9589
}

0 commit comments

Comments
 (0)