Skip to content

Commit f1a80ac

Browse files
committed
Fix parser poc
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent dfc4e0e commit f1a80ac

File tree

3 files changed

+97
-18
lines changed

3 files changed

+97
-18
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+57-6
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;
1616

1717
import java.io.IOException;
18-
import java.util.ArrayList;
19-
import java.util.HashMap;
20-
import java.util.List;
21-
import java.util.Map;
22-
import java.util.Set;
23-
import java.util.Optional;
18+
import java.util.*;
2419
import java.util.function.Function;
2520

21+
import org.json.JSONObject;
22+
import org.json.JSONArray;
2623
import org.apache.commons.lang3.StringUtils;
2724
import org.apache.commons.text.StringSubstitutor;
2825
import org.opensearch.OpenSearchParseException;
@@ -289,4 +286,58 @@ public static void validateSchema(String schemaString, String instanceString) {
289286
throw new RuntimeException(e.getMessage());
290287
}
291288
}
289+
290+
public static JSONObject processJsonObject(JSONObject jsonObject) {
291+
292+
JSONObject schema = wrapProperties(jsonObject);
293+
JSONObject modifiedSchema = new JSONObject();
294+
295+
if (!schema.has("type") && !schema.has("properties")) {
296+
modifiedSchema.put("properties", schema);
297+
return modifiedSchema;
298+
} else {
299+
return schema;
300+
}
301+
}
302+
303+
public static JSONObject wrapProperties(JSONObject jsonObject) {
304+
JSONObject newObject = new JSONObject();
305+
306+
for (String key : jsonObject.keySet()) {
307+
Object value = jsonObject.get(key);
308+
309+
if (value instanceof JSONObject) {
310+
JSONObject nestedObject = (JSONObject) value;
311+
// Check if the nested object has any schema keyword field
312+
if (nestedObject.has("type") || nestedObject.has("properties") || nestedObject.has("description")) {
313+
newObject.put(key, nestedObject); // Leave as is, because it's likely a schema definition
314+
} else {
315+
// Recurse to handle nested objects and wrap them with "properties"
316+
newObject.put(key, new JSONObject().put("properties", wrapProperties(nestedObject)));
317+
}
318+
} else if (value instanceof JSONArray) {
319+
newObject.put(key, wrapPropertiesInArray((JSONArray) value));
320+
} else {
321+
newObject.put(key, value); // Directly copy the value if it's not an object or array
322+
}
323+
}
324+
325+
return newObject;
326+
}
327+
328+
private static JSONArray wrapPropertiesInArray(JSONArray jsonArray) {
329+
JSONArray newArray = new JSONArray();
330+
for (int i = 0; i < jsonArray.length(); i++) {
331+
Object item = jsonArray.get(i);
332+
if (item instanceof JSONObject) {
333+
newArray.put(processJsonObject((JSONObject) item));
334+
} else if (item instanceof JSONArray) {
335+
newArray.put(wrapPropertiesInArray((JSONArray) item));
336+
} else {
337+
newArray.put(item);
338+
}
339+
}
340+
return newArray;
341+
}
342+
292343
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+18-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import java.util.List;
1515
import java.util.Map;
1616

17+
import org.json.JSONException;
18+
import org.json.JSONObject;
1719
import org.opensearch.OpenSearchStatusException;
1820
import org.opensearch.client.Client;
1921
import org.opensearch.cluster.service.ClusterService;
@@ -161,17 +163,27 @@ && getUserRateLimiterMap().get(user.getName()) != null
161163

162164
void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs);
163165

164-
private void validateInputSchema(MLInput mlInput) throws IOException {
166+
private void validateInputSchema(MLInput mlInput) {
165167
if (getConnector().getModelInterface() != null && getConnector().getModelInterface().get("input") != null) {
166-
String schemaString = getConnector().getModelInterface().get("input");
167-
ConnectorUtils.validateSchema(schemaString, mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
168+
String inputSchemaString = getConnector().getModelInterface().get("input");
169+
try {
170+
JSONObject inputSchemaObject = ConnectorUtils.processJsonObject(new JSONObject(inputSchemaString));
171+
ConnectorUtils.validateSchema(inputSchemaObject.toString(), mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
172+
} catch (IOException | JSONException e) {
173+
throw new IllegalArgumentException("Error validating input schema: " + e.getMessage());
174+
}
168175
}
169176
}
170177

171-
private void validateOutputSchema(ModelTensor modelTensor) throws IOException {
178+
private void validateOutputSchema(ModelTensor modelTensor) {
172179
if (getConnector().getModelInterface() != null && getConnector().getModelInterface().get("output") != null) {
173-
String schemaString = getConnector().getModelInterface().get("output");
174-
ConnectorUtils.validateSchema(schemaString, modelTensor.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
180+
String outputSchemaString = getConnector().getModelInterface().get("output");
181+
try {
182+
JSONObject outputSchemaObject = ConnectorUtils.processJsonObject(new JSONObject(outputSchemaString));
183+
ConnectorUtils.validateSchema(outputSchemaObject.toString(), modelTensor.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
184+
} catch (IOException | JSONException e) {
185+
throw new IllegalArgumentException("Error validating output schema: " + e.getMessage());
186+
}
175187
}
176188
}
177189
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java

+22-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.util.List;
1818
import java.util.Map;
1919

20+
import org.json.JSONObject;
2021
import org.junit.Assert;
2122
import org.junit.Before;
2223
import org.junit.Rule;
@@ -237,20 +238,35 @@ public void processOutput_PostprocessFunction() throws IOException {
237238

238239
@Test
239240
public void testValidateSchema() {
240-
exceptionRule.expect(RuntimeException.class);
241-
exceptionRule.expectMessage("Input is null");
242241
String schema = "{"
243242
+ "\"type\": \"object\","
244243
+ "\"properties\": {"
245244
+ " \"key1\": {\"type\": \"string\"},"
246-
+ " \"key2\": {\"type\": \"string\"}"
247-
+ "},"
248-
+ "\"required\": [\"name\", \"age\"]"
245+
+ " \"key2\": {\"type\": \"integer\"}"
246+
+ "}"
249247
+ "}";
250-
String json = "{\"key1\": true, \"key2\": 123}";
248+
String json = "{\"key1\": \"foo\", \"key2\": 123}";
251249
ConnectorUtils.validateSchema(schema, json);
252250
}
253251

252+
@Test
253+
public void testConvertingInterfaceToValidSchema() {
254+
String schemaString = "{"
255+
+ "\"department\": {"
256+
+ " \"name\": \"string\","
257+
+ " \"employees\": "
258+
+ " {\"name\": \"string\","
259+
+ " \"age\": {\"type\":\"integer\","
260+
+ " \"description\": \"This field should be above zero\"}},"
261+
+ " }, \"foo\": {\"properties\":{\"name\":\"string\"}}"
262+
+ "}";
263+
String expectedString = "{\"properties\":{\"foo\":{\"properties\":{\"name\":\"string\"}},\"department\":{\"properties\":{\"name\":\"string\",\"employees\":{\"properties\":{\"name\":\"string\",\"age\":{\"description\":\"This field should be above zero\",\"type\":\"integer\"}}}}}}}";
264+
JSONObject schemaObject = new JSONObject(schemaString);
265+
JSONObject validSchemaObject = ConnectorUtils.processJsonObject(schemaObject);
266+
Assert.assertEquals(expectedString, validSchemaObject.toString());
267+
}
268+
269+
254270
private void processInput_TextDocsInputDataSet_PreprocessFunction(
255271
String requestBody,
256272
List<String> inputs,

0 commit comments

Comments
 (0)