Skip to content

Commit 51cba97

Browse files
[Backport 2.x] Fix MLModelTool returns null if the response of LLM is a pure json object (opensearch-project#2675) (opensearch-project#2685)
* Fix MLModelTool returns null if the response of LLM is a pure json object (opensearch-project#2655) * Fix MLModelTool returns null if the response of LLM is a pure json object Signed-off-by: Heng Qian <qianheng@amazon.com> * Fix UT failure Signed-off-by: Heng Qian <qianheng@amazon.com> * Avoid NPE Signed-off-by: Heng Qian <qianheng@amazon.com> * spotlessApply Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com> (cherry picked from commit 007b914) * remove java21 API for backporting to 2.x Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com> (cherry picked from commit 0a6a2b0) Co-authored-by: qianheng <qianheng@amazon.com>
1 parent dcfe439 commit 51cba97

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
2222
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
2323
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
24+
import org.opensearch.ml.common.utils.StringUtils;
2425
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
2526

2627
import lombok.Getter;
@@ -54,6 +55,7 @@ public class MLModelTool implements Tool {
5455
private Parser inputParser;
5556
@Setter
5657
@Getter
58+
@VisibleForTesting
5759
private Parser outputParser;
5860
@Setter
5961
@Getter
@@ -65,8 +67,18 @@ public MLModelTool(Client client, String modelId, String responseField) {
6567
this.responseField = responseField;
6668

6769
outputParser = o -> {
68-
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
69-
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField);
70+
try {
71+
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
72+
Map<String, ?> dataAsMap = mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap();
73+
// Return the response field if it exists, otherwise return the whole response as json string.
74+
if (dataAsMap.containsKey(responseField)) {
75+
return dataAsMap.get(responseField);
76+
} else {
77+
return StringUtils.toJson(dataAsMap);
78+
}
79+
} catch (Exception e) {
80+
throw new IllegalStateException("LLM returns wrong or empty tensors", e);
81+
}
7082
};
7183
}
7284

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throw
124124
tool.run(null, listener);
125125

126126
future.join();
127-
assertEquals(null, future.get());
127+
assertEquals("{\"response\":\"response 1\",\"action\":\"action1\"}", future.get());
128128
}
129129

130130
@Test
@@ -170,6 +170,26 @@ public void testOutputParserLambda() {
170170
assertEquals("testResponse", result);
171171
}
172172

173+
@Test
174+
public void testOutputParserWithJsonResponse() {
175+
Parser outputParser = new MLModelTool(client, "modelId", "response").getOutputParser();
176+
String expectedJson = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
177+
178+
// Create a mock ModelTensors with json object
179+
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("key1", "value1", "key2", "value2")).build();
180+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
181+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
182+
Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
183+
assertEquals(expectedJson, result);
184+
185+
// Create a mock ModelTensors with response string
186+
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "{\"key1\":\"value1\",\"key2\":\"value2\"}")).build();
187+
modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
188+
mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
189+
result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs());
190+
assertEquals(expectedJson, result);
191+
}
192+
173193
@Test
174194
public void testRunWithError() {
175195
// Mocking the client.execute to simulate an error

0 commit comments

Comments
 (0)