Skip to content

Commit 4a5464e

Browse files
xinyualgithub-actions[bot]
authored andcommitted
Fix argument pass (#1941)
* add logs Signed-off-by: xinyual <xinyual@amazon.com> * option2FixBug Signed-off-by: xinyual <xinyual@amazon.com> * remove useless log Signed-off-by: xinyual <xinyual@amazon.com> * remove useless log Signed-off-by: xinyual <xinyual@amazon.com> * fix spot less Signed-off-by: xinyual <xinyual@amazon.com> * change argument parsing Signed-off-by: xinyual <xinyual@amazon.com> * move common function to utils Signed-off-by: xinyual <xinyual@amazon.com> * checkout for typo Signed-off-by: xinyual <xinyual@amazon.com> * remove useless code Signed-off-by: xinyual <xinyual@amazon.com> * add UTs Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * add more uts Signed-off-by: xinyual <xinyual@amazon.com> * modify import Signed-off-by: xinyual <xinyual@amazon.com> * protect original parameters Signed-off-by: xinyual <xinyual@amazon.com> --------- Signed-off-by: xinyual <xinyual@amazon.com> (cherry picked from commit 4a6ceba)
1 parent 3c9a76d commit 4a5464e

File tree

5 files changed

+215
-20
lines changed

5 files changed

+215
-20
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

+33
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
1616
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
1717

18+
import java.security.AccessController;
19+
import java.security.PrivilegedActionException;
20+
import java.security.PrivilegedExceptionAction;
1821
import java.util.HashMap;
1922
import java.util.List;
2023
import java.util.Map;
@@ -23,6 +26,8 @@
2326
import java.util.regex.Pattern;
2427

2528
import org.apache.commons.text.StringSubstitutor;
29+
import org.opensearch.ml.common.output.model.ModelTensor;
30+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2631
import org.opensearch.ml.common.spi.tools.Tool;
2732

2833
public class AgentUtils {
@@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) {
152157
throw new IllegalArgumentException("Model output is invalid");
153158
}
154159
}
160+
161+
public static String outputToOutputString(Object output) throws PrivilegedActionException {
162+
String outputString;
163+
if (output instanceof ModelTensorOutput) {
164+
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
165+
if (outputModel.getDataAsMap() != null) {
166+
outputString = AccessController
167+
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
168+
} else {
169+
outputString = outputModel.getResult();
170+
}
171+
} else if (output instanceof String) {
172+
outputString = (String) output;
173+
} else {
174+
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
175+
}
176+
return outputString;
177+
}
178+
179+
public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
180+
Object actionInput = retMap.get("action_input");
181+
if (actionInput instanceof Map) {
182+
return gson.toJson(actionInput);
183+
} else {
184+
return String.valueOf(actionInput);
185+
}
186+
187+
}
155188
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

+15-11
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
1010
import static org.opensearch.ml.common.utils.StringUtils.gson;
1111
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
12+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
13+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;
1214

13-
import java.security.AccessController;
14-
import java.security.PrivilegedExceptionAction;
1515
import java.util.ArrayList;
1616
import java.util.Collections;
1717
import java.util.HashMap;
@@ -325,7 +325,7 @@ private void runReAct(
325325
}
326326
String thought = String.valueOf(dataAsMap.get("thought"));
327327
String action = String.valueOf(dataAsMap.get("action"));
328-
String actionInput = String.valueOf(dataAsMap.get("action_input"));
328+
String actionInput = parseInputFromLLMReturn(dataAsMap);
329329
String finalAnswer = (String) dataAsMap.get("final_answer");
330330
if (!dataAsMap.containsKey("thought")) {
331331
String response = (String) dataAsMap.get("response");
@@ -336,7 +336,7 @@ private void runReAct(
336336
Map map = gson.fromJson(jsonBlock, Map.class);
337337
thought = String.valueOf(map.get("thought"));
338338
action = String.valueOf(map.get("action"));
339-
actionInput = String.valueOf(map.get("action_input"));
339+
actionInput = parseInputFromLLMReturn(map);
340340
finalAnswer = (String) map.get("final_answer");
341341
} else {
342342
finalAnswer = response;
@@ -509,9 +509,7 @@ private void runReAct(
509509
} else {
510510
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
511511
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
512-
String outputString = output instanceof String
513-
? (String) output
514-
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
512+
String outputString = outputToOutputString(output);
515513

516514
String toolOutputKey = String.format("%s.output", toolSpec.getType());
517515
if (additionalInfo.get(toolOutputKey) != null) {
@@ -531,7 +529,13 @@ private void runReAct(
531529
.singletonList(
532530
ModelTensor
533531
.builder()
534-
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output))
532+
.dataAsMap(
533+
ImmutableMap
534+
.of(
535+
"response",
536+
lastThought.get() + "\nObservation: " + outputToOutputString(output)
537+
)
538+
)
535539
.build()
536540
)
537541
)
@@ -540,7 +544,7 @@ private void runReAct(
540544

541545
String toolResponse = tmpParameters.get("prompt.tool_response");
542546
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(
543-
ImmutableMap.of("observation", output),
547+
ImmutableMap.of("observation", outputToOutputString(output)),
544548
"${parameters.",
545549
"}"
546550
);
@@ -552,7 +556,7 @@ private void runReAct(
552556
.conversationIndexMessageBuilder()
553557
.type("ReAct")
554558
.question(lastActionInput.get())
555-
.response((String) output)
559+
.response(outputToOutputString(output))
556560
.finalAnswer(false)
557561
.sessionId(sessionId)
558562
.build();
@@ -567,7 +571,7 @@ private void runReAct(
567571
newPrompt.set(substitutor.replace(finalPrompt));
568572
tmpParameters.put(PROMPT, newPrompt.get());
569573

570-
sessionMsgAnswerBuilder.append("\nObservation: ").append(output);
574+
sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
571575
cotModelTensors
572576
.add(
573577
ModelTensors

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

+19-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
package org.opensearch.ml.engine.tools;
77

8+
import static org.opensearch.ml.common.utils.StringUtils.gson;
9+
10+
import java.util.HashMap;
811
import java.util.Map;
912

1013
import org.opensearch.action.ActionRequest;
@@ -51,11 +54,12 @@ public AgentTool(Client client, String agentId) {
5154

5255
@Override
5356
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
57+
Map<String, String> extractedParameters = extractInputParameters(parameters);
5458
AgentMLInput agentMLInput = AgentMLInput
5559
.AgentMLInputBuilder()
5660
.agentId(agentId)
5761
.functionName(FunctionName.AGENT)
58-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
62+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
5963
.build();
6064
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
6165
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
@@ -135,4 +139,18 @@ public String getDefaultVersion() {
135139
return null;
136140
}
137141
}
142+
143+
private Map<String, String> extractInputParameters(Map<String, String> parameters) {
144+
Map<String, String> extractedParameters = new HashMap<>();
145+
extractedParameters.putAll(parameters);
146+
if (parameters.containsKey("input")) {
147+
try {
148+
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
149+
extractedParameters.putAll(chatParameters);
150+
} catch (Exception exception) {
151+
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
152+
}
153+
}
154+
return extractedParameters;
155+
}
138156
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

+126
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() {
188188
assertEquals("parsed final answer", modelTensor2.getResult());
189189
}
190190

191+
@Test
192+
public void testParsingJsonBlockFromResponse2() {
193+
// Prepare the response with JSON block
194+
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
195+
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
196+
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
197+
198+
// Mock LLM response to not contain "thought" but contain "response" with JSON block
199+
Map<String, String> llmResponse = new HashMap<>();
200+
llmResponse.put("response", responseWithJsonBlock);
201+
doAnswer(getLLMAnswer(llmResponse))
202+
.when(client)
203+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
204+
205+
// Create an MLAgent and run the MLChatAgentRunner
206+
MLAgent mlAgent = createMLAgentWithTools();
207+
Map<String, String> params = new HashMap<>();
208+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
209+
params.put("verbose", "true");
210+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
211+
212+
// Capture the response passed to the listener
213+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
214+
verify(agentActionListener).onResponse(responseCaptor.capture());
215+
216+
// Extract the captured response
217+
Object capturedResponse = responseCaptor.getValue();
218+
assertTrue(capturedResponse instanceof ModelTensorOutput);
219+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
220+
221+
ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
222+
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
223+
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
224+
225+
// Verify that the parsed values from JSON block are correctly set
226+
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
227+
assertEquals("Thought: parsed thought", modelTensor1.getResult());
228+
assertEquals("parsed final answer", modelTensor2.getResult());
229+
}
230+
231+
@Test
232+
public void testParsingJsonBlockFromResponse3() {
233+
// Prepare the response with JSON block
234+
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
235+
+ "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}";
236+
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";
237+
238+
// Mock LLM response to not contain "thought" but contain "response" with JSON block
239+
Map<String, String> llmResponse = new HashMap<>();
240+
llmResponse.put("response", responseWithJsonBlock);
241+
doAnswer(getLLMAnswer(llmResponse))
242+
.when(client)
243+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
244+
245+
// Create an MLAgent and run the MLChatAgentRunner
246+
MLAgent mlAgent = createMLAgentWithTools();
247+
Map<String, String> params = new HashMap<>();
248+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
249+
params.put("verbose", "true");
250+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
251+
252+
// Capture the response passed to the listener
253+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
254+
verify(agentActionListener).onResponse(responseCaptor.capture());
255+
256+
// Extract the captured response
257+
Object capturedResponse = responseCaptor.getValue();
258+
assertTrue(capturedResponse instanceof ModelTensorOutput);
259+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
260+
261+
ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
262+
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
263+
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);
264+
265+
// Verify that the parsed values from JSON block are correctly set
266+
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
267+
assertEquals("Thought: parsed thought", modelTensor1.getResult());
268+
assertEquals("parsed final answer", modelTensor2.getResult());
269+
}
270+
191271
@Test
192272
public void testRunWithIncludeOutputNotSet() {
193273
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -209,6 +289,35 @@ public void testRunWithIncludeOutputNotSet() {
209289
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
210290
}
211291

292+
@Test
293+
public void testRunWithIncludeOutputMLModel() {
294+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
295+
Mockito
296+
.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1))
297+
.when(firstTool)
298+
.run(Mockito.anyMap(), toolListenerCaptor.capture());
299+
Mockito
300+
.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2))
301+
.when(secondTool)
302+
.run(Mockito.anyMap(), toolListenerCaptor.capture());
303+
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
304+
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
305+
final MLAgent mlAgent = MLAgent
306+
.builder()
307+
.name("TestAgent")
308+
.llm(llmSpec)
309+
.memory(mlMemorySpec)
310+
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
311+
.build();
312+
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
313+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
314+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
315+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
316+
assertEquals(1, agentOutput.size());
317+
// Respond with last tool output
318+
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
319+
}
320+
212321
@Test
213322
public void testRunWithIncludeOutputSet() {
214323
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
@@ -512,6 +621,23 @@ private Answer generateToolResponse(String response) {
512621
};
513622
}
514623

624+
private Answer generateToolResponseAsMLModelResult(String response, int type) {
625+
ModelTensor modelTensor;
626+
if (type == 1) {
627+
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build();
628+
} else {
629+
modelTensor = ModelTensor.builder().result(response).build();
630+
}
631+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
632+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
633+
634+
return invocation -> {
635+
ActionListener<Object> listener = invocation.getArgument(1);
636+
listener.onResponse(mlModelTensorOutput);
637+
return null;
638+
};
639+
}
640+
515641
private Answer generateToolFailure(Exception e) {
516642
return invocation -> {
517643
ActionListener<Object> listener = invocation.getArgument(1);

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java

+22-8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.mockito.ArgumentMatchers.eq;
1111
import static org.mockito.Mockito.doAnswer;
1212
import static org.mockito.Mockito.verify;
13+
import static org.opensearch.ml.common.utils.StringUtils.gson;
1314
import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION;
1415

1516
import java.util.Arrays;
@@ -24,8 +25,6 @@
2425
import org.opensearch.client.Client;
2526
import org.opensearch.core.action.ActionListener;
2627
import org.opensearch.ml.common.FunctionName;
27-
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
28-
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
2928
import org.opensearch.ml.common.output.model.ModelTensor;
3029
import org.opensearch.ml.common.output.model.ModelTensorOutput;
3130
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -67,13 +66,28 @@ public void setup() {
6766
public void testAgenttestRunMethod() {
6867
Map<String, String> parameters = new HashMap<>();
6968
parameters.put("testKey", "testValue");
70-
AgentMLInput agentMLInput = AgentMLInput
71-
.AgentMLInputBuilder()
72-
.agentId("agentId")
73-
.functionName(FunctionName.AGENT)
74-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
75-
.build();
69+
doTestRunMethod(parameters);
70+
}
71+
72+
@Test
73+
public void testAgentWithChatAgentInput() {
74+
Map<String, String> parameters = new HashMap<>();
75+
parameters.put("testKey", "testValue");
76+
Map<String, String> chatAgentInput = new HashMap<>();
77+
chatAgentInput.put("input", gson.toJson(parameters));
78+
doTestRunMethod(chatAgentInput);
79+
assertEquals(chatAgentInput.size(), 1);
80+
assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters
81+
}
82+
83+
@Test
84+
public void testAgentWithChatAgentInputWrongFormat() {
85+
Map<String, String> chatAgentInput = new HashMap<>();
86+
chatAgentInput.put("input", "wrong format");
87+
doTestRunMethod(chatAgentInput);
88+
}
7689

90+
private void doTestRunMethod(Map<String, String> parameters) {
7791
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
7892
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
7993
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

0 commit comments

Comments
 (0)