Skip to content

Commit f64e3f3

Browse files
authored
Support list in response body (opensearch-project#2811)
1 parent 9663053 commit f64e3f3

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+27-2
Original file line numberDiff line numberDiff line change
@@ -324,13 +324,38 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
324324
payload = fillNullParameters(parameters, payload);
325325
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
326326
payload = substitutor.replace(payload);
327-
328327
if (!isJson(payload)) {
329-
throw new IllegalArgumentException("Invalid payload: " + payload);
328+
String payloadAfterEscape = connectorAction.get().getRequestBody();
329+
Map<String, String> escapedParameters = escapeMapValues(parameters);
330+
StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
331+
payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape);
332+
if (!isJson(payloadAfterEscape)) {
333+
throw new IllegalArgumentException("Invalid payload: " + payload);
334+
} else {
335+
payload = payloadAfterEscape;
336+
}
330337
}
331338
return (T) payload;
332339
}
333340
return (T) parameters.get("http_body");
341+
342+
}
343+
344+
public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
345+
Map<String, String> escapedMap = new HashMap<>();
346+
if (parameters != null) {
347+
for (Map.Entry<String, String> entry : parameters.entrySet()) {
348+
String key = entry.getKey();
349+
String value = entry.getValue();
350+
String escapedValue = escapeValue(value);
351+
escapedMap.put(key, escapedValue);
352+
}
353+
}
354+
return escapedMap;
355+
}
356+
357+
private static String escapeValue(String value) {
358+
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
334359
}
335360

336361
protected String fillNullParameters(Map<String, String> parameters, String payload) {

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

+109
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.common.connector;
77

88
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9+
import static org.opensearch.ml.common.utils.StringUtils.toJson;
910

1011
import java.io.IOException;
1112
import java.util.ArrayList;
@@ -183,6 +184,114 @@ public void createPayload_InvalidJson() {
183184
connector.validatePayload(predictPayload);
184185
}
185186

187+
@Test
188+
public void createPayloadWithString() {
189+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
190+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
191+
Map<String, String> parameters = new HashMap<>();
192+
193+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
194+
parameters.put("context", "document1");
195+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
196+
connector.validatePayload(predictPayload);
197+
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
198+
}
199+
200+
@Test
201+
public void createPayloadWithList() {
202+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
203+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
204+
Map<String, String> parameters = new HashMap<>();
205+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
206+
ArrayList<String> listOfDocuments = new ArrayList<>();
207+
listOfDocuments.add("document1");
208+
listOfDocuments.add("document2");
209+
parameters.put("context", toJson(listOfDocuments));
210+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
211+
connector.validatePayload(predictPayload);
212+
}
213+
214+
@Test
215+
public void createPayloadWithNestedList() {
216+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
217+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
218+
Map<String, String> parameters = new HashMap<>();
219+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
220+
ArrayList<String> listOfDocuments = new ArrayList<>();
221+
listOfDocuments.add("document1");
222+
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
223+
NestedListOfDocuments.add("document2");
224+
listOfDocuments.add(toJson(NestedListOfDocuments));
225+
parameters.put("context", toJson(listOfDocuments));
226+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
227+
connector.validatePayload(predictPayload);
228+
}
229+
230+
@Test
231+
public void createPayloadWithMap() {
232+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
233+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
234+
Map<String, String> parameters = new HashMap<>();
235+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
236+
Map<String, String> mapOfDocuments = new HashMap<>();
237+
mapOfDocuments.put("name", "John");
238+
parameters.put("context", toJson(mapOfDocuments));
239+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
240+
connector.validatePayload(predictPayload);
241+
}
242+
243+
@Test
244+
public void createPayloadWithNestedMapOfString() {
245+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
246+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
247+
Map<String, String> parameters = new HashMap<>();
248+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
249+
Map<String, String> mapOfDocuments = new HashMap<>();
250+
mapOfDocuments.put("name", "John");
251+
Map<String, String> nestedMapOfDocuments = new HashMap<>();
252+
nestedMapOfDocuments.put("city", "New York");
253+
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
254+
parameters.put("context", toJson(mapOfDocuments));
255+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
256+
connector.validatePayload(predictPayload);
257+
}
258+
259+
@Test
260+
public void createPayloadWithNestedMapOfObject() {
261+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
262+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
263+
Map<String, String> parameters = new HashMap<>();
264+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
265+
Map<String, Object> mapOfDocuments = new HashMap<>();
266+
mapOfDocuments.put("name", "John");
267+
Map<String, String> nestedMapOfDocuments = new HashMap<>();
268+
nestedMapOfDocuments.put("city", "New York");
269+
mapOfDocuments.put("hometown", nestedMapOfDocuments);
270+
parameters.put("context", toJson(mapOfDocuments));
271+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
272+
connector.validatePayload(predictPayload);
273+
}
274+
275+
@Test
276+
public void createPayloadWithNestedListOfMapOfObject() {
277+
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
278+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
279+
Map<String, String> parameters = new HashMap<>();
280+
parameters.put("prompt", "answer question based on context: ${parameters.context}");
281+
ArrayList<String> listOfDocuments = new ArrayList<>();
282+
listOfDocuments.add("document1");
283+
ArrayList<Object> NestedListOfDocuments = new ArrayList<>();
284+
Map<String, Object> mapOfDocuments = new HashMap<>();
285+
mapOfDocuments.put("name", "John");
286+
Map<String, String> nestedMapOfDocuments = new HashMap<>();
287+
nestedMapOfDocuments.put("city", "New York");
288+
mapOfDocuments.put("hometown", nestedMapOfDocuments);
289+
listOfDocuments.add(toJson(NestedListOfDocuments));
290+
parameters.put("context", toJson(listOfDocuments));
291+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
292+
connector.validatePayload(predictPayload);
293+
}
294+
186295
@Test
187296
public void createPayload() {
188297
HttpConnector connector = createHttpConnector();

0 commit comments

Comments
 (0)