Skip to content

Commit 54d7daf

Browse files
Fix custom prompt substitute with List issue in ml inference search response processor (opensearch-project#2871) (opensearch-project#2874)
(cherry picked from commit 49d4a01) Co-authored-by: Mingshi Liu <mingshl@amazon.com>
1 parent a1a7dbb commit 54d7daf

File tree

7 files changed

+473
-158
lines changed

7 files changed

+473
-158
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+4-27
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
1111
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1212
import static org.opensearch.ml.common.utils.StringUtils.isJson;
13+
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
1314

1415
import java.io.IOException;
1516
import java.time.Instant;
@@ -322,40 +323,16 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
322323
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
323324
String payload = connectorAction.get().getRequestBody();
324325
payload = fillNullParameters(parameters, payload);
326+
parseParameters(parameters);
325327
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
326328
payload = substitutor.replace(payload);
329+
327330
if (!isJson(payload)) {
328-
String payloadAfterEscape = connectorAction.get().getRequestBody();
329-
Map<String, String> escapedParameters = escapeMapValues(parameters);
330-
StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
331-
payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape);
332-
if (!isJson(payloadAfterEscape)) {
333-
throw new IllegalArgumentException("Invalid payload: " + payload);
334-
} else {
335-
payload = payloadAfterEscape;
336-
}
331+
throw new IllegalArgumentException("Invalid payload: " + payload);
337332
}
338333
return (T) payload;
339334
}
340335
return (T) parameters.get("http_body");
341-
342-
}
343-
344-
public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
345-
Map<String, String> escapedMap = new HashMap<>();
346-
if (parameters != null) {
347-
for (Map.Entry<String, String> entry : parameters.entrySet()) {
348-
String key = entry.getKey();
349-
String value = entry.getValue();
350-
String escapedValue = escapeValue(value);
351-
escapedMap.put(key, escapedValue);
352-
}
353-
}
354-
return escapedMap;
355-
}
356-
357-
private static String escapeValue(String value) {
358-
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
359336
}
360337

361338
protected String fillNullParameters(Map<String, String> parameters, String payload) {

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+46
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public class StringUtils {
5050
static {
5151
gson = new Gson();
5252
}
53+
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
5354

5455
public static boolean isValidJsonString(String Json) {
5556
try {
@@ -233,4 +234,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
233234
return errorMessage + " Model ID: " + modelId;
234235
}
235236
}
237+
238+
/**
239+
* Collects the prefixes of the toString() method calls present in the values of the given map.
240+
*
241+
* @param map A map containing key-value pairs where the values may contain toString() method calls.
242+
* @return A list of prefixes for the toString() method calls found in the map values.
243+
*/
244+
public static List<String> collectToStringPrefixes(Map<String, String> map) {
245+
List<String> prefixes = new ArrayList<>();
246+
for (String key : map.keySet()) {
247+
String value = map.get(key);
248+
if (value != null) {
249+
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
250+
Matcher matcher = pattern.matcher(value);
251+
while (matcher.find()) {
252+
String prefix = matcher.group(1);
253+
prefixes.add(prefix);
254+
}
255+
}
256+
}
257+
return prefixes;
258+
}
259+
260+
/**
261+
* Parses the given parameters map and processes the values containing toString() method calls.
262+
*
263+
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
264+
* @return A new map with the processed values for the toString() method calls.
265+
*/
266+
public static Map<String, String> parseParameters(Map<String, String> parameters) {
267+
if (parameters != null) {
268+
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);
269+
270+
if (!toStringParametersPrefixes.isEmpty()) {
271+
for (String prefix : toStringParametersPrefixes) {
272+
String value = parameters.get(prefix);
273+
if (value != null) {
274+
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
275+
}
276+
}
277+
}
278+
}
279+
return parameters;
280+
}
281+
236282
}

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

-109
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.ml.common.connector;
77

88
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9-
import static org.opensearch.ml.common.utils.StringUtils.toJson;
109

1110
import java.io.IOException;
1211
import java.util.ArrayList;
@@ -184,114 +183,6 @@ public void createPayload_InvalidJson() {
184183
connector.validatePayload(predictPayload);
185184
}
186185

187-
@Test
188-
public void createPayloadWithString() {
189-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
190-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
191-
Map<String, String> parameters = new HashMap<>();
192-
193-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
194-
parameters.put("context", "document1");
195-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
196-
connector.validatePayload(predictPayload);
197-
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
198-
}
199-
200-
@Test
201-
public void createPayloadWithList() {
202-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
203-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
204-
Map<String, String> parameters = new HashMap<>();
205-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
206-
ArrayList<String> listOfDocuments = new ArrayList<>();
207-
listOfDocuments.add("document1");
208-
listOfDocuments.add("document2");
209-
parameters.put("context", toJson(listOfDocuments));
210-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
211-
connector.validatePayload(predictPayload);
212-
}
213-
214-
@Test
215-
public void createPayloadWithNestedList() {
216-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
217-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
218-
Map<String, String> parameters = new HashMap<>();
219-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
220-
ArrayList<String> listOfDocuments = new ArrayList<>();
221-
listOfDocuments.add("document1");
222-
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
223-
NestedListOfDocuments.add("document2");
224-
listOfDocuments.add(toJson(NestedListOfDocuments));
225-
parameters.put("context", toJson(listOfDocuments));
226-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
227-
connector.validatePayload(predictPayload);
228-
}
229-
230-
@Test
231-
public void createPayloadWithMap() {
232-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
233-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
234-
Map<String, String> parameters = new HashMap<>();
235-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
236-
Map<String, String> mapOfDocuments = new HashMap<>();
237-
mapOfDocuments.put("name", "John");
238-
parameters.put("context", toJson(mapOfDocuments));
239-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
240-
connector.validatePayload(predictPayload);
241-
}
242-
243-
@Test
244-
public void createPayloadWithNestedMapOfString() {
245-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
246-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
247-
Map<String, String> parameters = new HashMap<>();
248-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
249-
Map<String, String> mapOfDocuments = new HashMap<>();
250-
mapOfDocuments.put("name", "John");
251-
Map<String, String> nestedMapOfDocuments = new HashMap<>();
252-
nestedMapOfDocuments.put("city", "New York");
253-
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
254-
parameters.put("context", toJson(mapOfDocuments));
255-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
256-
connector.validatePayload(predictPayload);
257-
}
258-
259-
@Test
260-
public void createPayloadWithNestedMapOfObject() {
261-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
262-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
263-
Map<String, String> parameters = new HashMap<>();
264-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
265-
Map<String, Object> mapOfDocuments = new HashMap<>();
266-
mapOfDocuments.put("name", "John");
267-
Map<String, String> nestedMapOfDocuments = new HashMap<>();
268-
nestedMapOfDocuments.put("city", "New York");
269-
mapOfDocuments.put("hometown", nestedMapOfDocuments);
270-
parameters.put("context", toJson(mapOfDocuments));
271-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
272-
connector.validatePayload(predictPayload);
273-
}
274-
275-
@Test
276-
public void createPayloadWithNestedListOfMapOfObject() {
277-
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
278-
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
279-
Map<String, String> parameters = new HashMap<>();
280-
parameters.put("prompt", "answer question based on context: ${parameters.context}");
281-
ArrayList<String> listOfDocuments = new ArrayList<>();
282-
listOfDocuments.add("document1");
283-
ArrayList<Object> NestedListOfDocuments = new ArrayList<>();
284-
Map<String, Object> mapOfDocuments = new HashMap<>();
285-
mapOfDocuments.put("name", "John");
286-
Map<String, String> nestedMapOfDocuments = new HashMap<>();
287-
nestedMapOfDocuments.put("city", "New York");
288-
mapOfDocuments.put("hometown", nestedMapOfDocuments);
289-
listOfDocuments.add(toJson(NestedListOfDocuments));
290-
parameters.put("context", toJson(listOfDocuments));
291-
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
292-
connector.validatePayload(predictPayload);
293-
}
294-
295186
@Test
296187
public void createPayload() {
297188
HttpConnector connector = createHttpConnector();

0 commit comments

Comments
 (0)