Skip to content

Commit 96e6c62

Browse files
Add unit tests, improve coverage, clean up code.
Signed-off-by: Austin Lee <austin@aryn.ai>
1 parent 961d54d commit 96e6c62

File tree

14 files changed

+745
-122
lines changed

14 files changed

+745
-122
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+369-7
Large diffs are not rendered by default.
Binary file not shown.
Loading

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
4848
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
4949
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
50-
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
5150
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
5251
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;
5352

@@ -241,7 +240,7 @@ public void onResponse(ChatCompletionOutput output) {
241240
.createInteraction(
242241
conversationId,
243242
llmQuestion,
244-
PromptUtil.getPromptTemplate(input.getModelProvider(), systemPrompt, userInstructions),
243+
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
245244
answer,
246245
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
247246
Collections.singletonMap("metadata", jsonArrayToString(searchResults)),

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+13-12
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@
2020
import java.io.IOException;
2121
import java.util.ArrayList;
2222
import java.util.List;
23-
import java.util.Map;
2423
import java.util.Objects;
2524

26-
import org.opensearch.common.settings.Settings;
2725
import org.opensearch.core.ParseField;
2826
import org.opensearch.core.common.Strings;
2927
import org.opensearch.core.common.io.stream.StreamInput;
@@ -33,15 +31,14 @@
3331
import org.opensearch.core.xcontent.ToXContentObject;
3432
import org.opensearch.core.xcontent.XContentBuilder;
3533
import org.opensearch.core.xcontent.XContentParser;
36-
import org.opensearch.index.analysis.NameOrDefinition;
3734
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
35+
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
3836

3937
import com.google.common.base.Preconditions;
4038

4139
import lombok.Getter;
4240
import lombok.NoArgsConstructor;
4341
import lombok.Setter;
44-
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
4542

4643
/**
4744
* Defines parameters for generative QA search pipelines.
@@ -93,7 +90,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
9390

9491
static {
9592
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
96-
// ObjectParser<MessageBlock, Void> objectParser = new ObjectParser<>("llm_message_parser", MessageBlock::new);
9793
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
9894
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
9995
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
@@ -146,8 +142,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
146142
@Getter
147143
private List<MessageBlock> llmMessages = new ArrayList<>();
148144

149-
// private List<MessageBlock> blockList = null;
150-
151145
public GenerativeQAParameters(
152146
String conversationId,
153147
String llmModel,
@@ -159,7 +153,18 @@ public GenerativeQAParameters(
159153
Integer timeout,
160154
String llmResponseField
161155
) {
162-
this(conversationId, llmModel, llmQuestion, systemPrompt, userInstructions, contextSize, interactionSize, timeout, llmResponseField, null);
156+
this(
157+
conversationId,
158+
llmModel,
159+
llmQuestion,
160+
systemPrompt,
161+
userInstructions,
162+
contextSize,
163+
interactionSize,
164+
timeout,
165+
llmResponseField,
166+
null
167+
);
163168
}
164169

165170
public GenerativeQAParameters(
@@ -264,8 +269,4 @@ public boolean equals(Object o) {
264269
public void setMessageBlock(List<MessageBlock> blockList) {
265270
this.llmMessages = blockList;
266271
}
267-
268-
public MessageBlock getMessageBlock() {
269-
return null;
270-
}
271272
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
121121
chatCompletionInput.getLlmMessages()
122122
);
123123
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
124-
// log.info("Messages to LLM: {}", messages);
125124
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK
126125
|| chatCompletionInput.getModelProvider() == ModelProvider.COHERE
127126
|| chatCompletionInput.getLlmResponseField() != null) {
@@ -138,7 +137,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
138137
)
139138
);
140139
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) {
141-
// inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
142140
// Bedrock Converse API does not include the system prompt as part of the Messages block.
143141
String messages = PromptUtil
144142
.getChatCompletionPrompt(
@@ -159,7 +157,6 @@ protected Map<String, String> getInputParameters(ChatCompletionInput chatComplet
159157
);
160158
}
161159

162-
// log.info("LLM input parameters: {}", inputParameters.toString());
163160
return inputParameters;
164161
}
165162

@@ -208,6 +205,9 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
208205
answers.add(answer);
209206
} else {
210207
Map error = (Map) output.get("error");
208+
if (error == null) {
209+
throw new RuntimeException("Unexpected output: " + output);
210+
}
211211
errors.add((String) error.get("message"));
212212
}
213213
} else {

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java

+20-27
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,25 @@
1717
*/
1818
package org.opensearch.searchpipelines.questionanswering.generative.llm;
1919

20-
import com.google.common.base.Preconditions;
21-
import lombok.Getter;
22-
import lombok.Setter;
20+
import java.io.IOException;
21+
import java.util.ArrayList;
22+
import java.util.HashMap;
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.Objects;
26+
2327
import org.opensearch.core.common.io.stream.StreamInput;
2428
import org.opensearch.core.common.io.stream.StreamOutput;
2529
import org.opensearch.core.common.io.stream.Writeable;
2630
import org.opensearch.core.xcontent.ToXContent;
2731
import org.opensearch.core.xcontent.XContentBuilder;
2832
import org.opensearch.core.xcontent.XContentParseException;
2933
import org.opensearch.core.xcontent.XContentParser;
30-
import org.opensearch.index.analysis.NameOrDefinition;
3134

32-
import javax.print.Doc;
33-
import java.io.IOException;
34-
import java.util.ArrayList;
35-
import java.util.HashMap;
36-
import java.util.List;
37-
import java.util.Map;
38-
import java.util.Objects;
35+
import com.google.common.base.Preconditions;
36+
37+
import lombok.Getter;
38+
import lombok.Setter;
3939

4040
public class MessageBlock implements Writeable, ToXContent {
4141

@@ -70,10 +70,7 @@ public static MessageBlock fromXContent(XContentParser parser) throws IOExceptio
7070
if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
7171
return new MessageBlock(parser.map());
7272
}
73-
throw new XContentParseException(
74-
parser.getTokenLocation(),
75-
"Expected [VALUE_STRING] or [START_OBJECT], got " + parser.currentToken()
76-
);
73+
throw new XContentParseException(parser.getTokenLocation(), "Expected [START_OBJECT], got " + parser.currentToken());
7774
}
7875

7976
@Override
@@ -129,7 +126,7 @@ public TextBlock(StreamInput in) throws IOException {
129126

130127
@Override
131128
public void writeTo(StreamOutput out) throws IOException {
132-
out.writeString("text");
129+
out.writeString(this.type);
133130
out.writeString(this.text);
134131
}
135132

@@ -175,6 +172,7 @@ public ImageBlock(Map<String, ?> imageBlock) {
175172
}
176173

177174
}
175+
178176
public ImageBlock(String format, String data, String url) {
179177
Preconditions.checkNotNull(format, "format cannot be null.");
180178
if (data == null && url == null) {
@@ -193,7 +191,7 @@ public ImageBlock(StreamInput in) throws IOException {
193191

194192
@Override
195193
public void writeTo(StreamOutput out) throws IOException {
196-
out.writeString("image");
194+
out.writeString(this.type);
197195
out.writeString(this.format);
198196
out.writeOptionalString(this.data);
199197
out.writeOptionalString(this.url);
@@ -215,7 +213,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
215213
}
216214
}
217215

218-
static class DocumentBlock extends AbstractBlock {
216+
public static class DocumentBlock extends AbstractBlock {
219217

220218
@Getter
221219
String type = "document";
@@ -260,6 +258,7 @@ public DocumentBlock(StreamInput in) throws IOException {
260258

261259
@Override
262260
public void writeTo(StreamOutput out) throws IOException {
261+
out.writeString(this.type);
263262
out.writeString(this.format);
264263
out.writeString(this.name);
265264
out.writeString(this.data);
@@ -279,9 +278,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
279278
}
280279

281280
@Getter
281+
@Setter
282282
private String role;
283283

284284
@Getter
285+
@Setter
285286
private List<AbstractBlock> blockList = new ArrayList<>();
286287

287288
public MessageBlock() {}
@@ -290,13 +291,9 @@ public MessageBlock(Map<String, ?> map) {
290291
setMessageBlock(map);
291292
}
292293

293-
// public <T extends AbstractBlock> T get(int index) {
294-
// return (T) this.blockList.get(index);
295-
// }
296-
297294
public void setMessageBlock(Map<String, ?> message) {
298295
Preconditions.checkNotNull(message, "message cannot be null.");
299-
Preconditions.checkState(message.containsKey("role"),"message must have role." );
296+
Preconditions.checkState(message.containsKey("role"), "message must have role.");
300297
Preconditions.checkState(message.containsKey("content"), "message must have content.");
301298

302299
this.role = (String) message.get("role");
@@ -326,7 +323,3 @@ public int hashCode() {
326323
return Objects.hashCode(this.role) + Objects.hashCode(this.blockList);
327324
}
328325
}
329-
330-
331-
332-

0 commit comments

Comments
 (0)