Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argument pass #1941

Merged
merged 14 commits into from
Feb 2, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -23,6 +26,8 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {
Expand Down Expand Up @@ -152,4 +157,32 @@
throw new IllegalArgumentException("Model output is invalid");
}
}

public static String outputToOutputString(Object output) throws PrivilegedActionException {
String outputString;
if (output instanceof ModelTensorOutput) {
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
if (outputModel.getDataAsMap() != null) {
outputString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
} else {
outputString = outputModel.getResult();
}
} else if (output instanceof String) {
outputString = (String) output;
} else {
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));

Check warning on line 174 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java#L174

Added line #L174 was not covered by tests
}
return outputString;
}

public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
Object actionInput = retMap.get("action_input");
if (actionInput instanceof Map) {
return gson.toJson(actionInput);
} else {
return String.valueOf(actionInput);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -325,7 +325,7 @@
}
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = String.valueOf(dataAsMap.get("action_input"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");
if (!dataAsMap.containsKey("thought")) {
String response = (String) dataAsMap.get("response");
Expand All @@ -336,7 +336,7 @@
Map map = gson.fromJson(jsonBlock, Map.class);
thought = String.valueOf(map.get("thought"));
action = String.valueOf(map.get("action"));
actionInput = String.valueOf(map.get("action_input"));
actionInput = parseInputFromLLMReturn(map);

Check warning on line 339 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java#L339

Added line #L339 was not covered by tests
finalAnswer = (String) map.get("final_answer");
} else {
finalAnswer = response;
Expand Down Expand Up @@ -524,9 +524,7 @@
} else {
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
String outputString = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
String outputString = outputToOutputString(output);

String toolOutputKey = String.format("%s.output", toolSpec.getType());
if (additionalInfo.get(toolOutputKey) != null) {
Expand All @@ -546,7 +544,13 @@
.singletonList(
ModelTensor
.builder()
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output))
.dataAsMap(
ImmutableMap
.of(
"response",
lastThought.get() + "\nObservation: " + outputToOutputString(output)
)
)
.build()
)
)
Expand All @@ -555,7 +559,7 @@

String toolResponse = tmpParameters.get("prompt.tool_response");
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(
ImmutableMap.of("observation", output),
ImmutableMap.of("observation", outputToOutputString(output)),
"${parameters.",
"}"
);
Expand All @@ -567,7 +571,7 @@
.conversationIndexMessageBuilder()
.type("ReAct")
.question(lastActionInput.get())
.response((String) output)
.response(outputToOutputString(output))
.finalAnswer(false)
.sessionId(sessionId)
.build();
Expand All @@ -582,7 +586,7 @@
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());

sessionMsgAnswerBuilder.append("\nObservation: ").append(output);
sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
cotModelTensors
.add(
ModelTensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -51,11 +54,12 @@ public AgentTool(Client client, String agentId) {

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
Map<String, String> extractedParameters = extractInputParameters(parameters);
AgentMLInput agentMLInput = AgentMLInput
.AgentMLInputBuilder()
.agentId(agentId)
.functionName(FunctionName.AGENT)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
.build();
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
Expand Down Expand Up @@ -135,4 +139,18 @@ public String getDefaultVersion() {
return null;
}
}

private Map<String, String> extractInputParameters(Map<String, String> parameters) {
Map<String, String> extractedParameters = new HashMap<>();
extractedParameters.putAll(parameters);
if (parameters.containsKey("input")) {
try {
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
extractedParameters.putAll(chatParameters);
} catch (Exception exception) {
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
}
}
return extractedParameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() {
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testParsingJsonBlockFromResponse2() {
// Prepare the response with JSON block
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";

// Mock LLM response to not contain "thought" but contain "response" with JSON block
Map<String, String> llmResponse = new HashMap<>();
llmResponse.put("response", responseWithJsonBlock);
doAnswer(getLLMAnswer(llmResponse))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

// Create an MLAgent and run the MLChatAgentRunner
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Capture the response passed to the listener
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

// Extract the captured response
Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);

// Verify that the parsed values from JSON block are correctly set
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
assertEquals("Thought: parsed thought", modelTensor1.getResult());
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testParsingJsonBlockFromResponse3() {
// Prepare the response with JSON block
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+ "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}";
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";

// Mock LLM response to not contain "thought" but contain "response" with JSON block
Map<String, String> llmResponse = new HashMap<>();
llmResponse.put("response", responseWithJsonBlock);
doAnswer(getLLMAnswer(llmResponse))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

// Create an MLAgent and run the MLChatAgentRunner
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Capture the response passed to the listener
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

// Extract the captured response
Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);

// Verify that the parsed values from JSON block are correctly set
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
assertEquals("Thought: parsed thought", modelTensor1.getResult());
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testRunWithIncludeOutputNotSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand All @@ -209,6 +289,35 @@ public void testRunWithIncludeOutputNotSet() {
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
}

@Test
public void testRunWithIncludeOutputMLModel() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Mockito
.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1))
.when(firstTool)
.run(Mockito.anyMap(), toolListenerCaptor.capture());
Mockito
.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2))
.when(secondTool)
.run(Mockito.anyMap(), toolListenerCaptor.capture());
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.llm(llmSpec)
.memory(mlMemorySpec)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
assertEquals(1, agentOutput.size());
// Respond with last tool output
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
}

@Test
public void testRunWithIncludeOutputSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down Expand Up @@ -512,6 +621,23 @@ private Answer generateToolResponse(String response) {
};
}

private Answer generateToolResponseAsMLModelResult(String response, int type) {
ModelTensor modelTensor;
if (type == 1) {
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build();
} else {
modelTensor = ModelTensor.builder().result(response).build();
}
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(mlModelTensorOutput);
return null;
};
}

private Answer generateToolFailure(Exception e) {
return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION;

import java.util.Arrays;
Expand All @@ -24,8 +25,6 @@
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand Down Expand Up @@ -67,13 +66,28 @@ public void setup() {
public void testAgenttestRunMethod() {
Map<String, String> parameters = new HashMap<>();
parameters.put("testKey", "testValue");
AgentMLInput agentMLInput = AgentMLInput
.AgentMLInputBuilder()
.agentId("agentId")
.functionName(FunctionName.AGENT)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
.build();
doTestRunMethod(parameters);
}

@Test
public void testAgentWithChatAgentInput() {
Map<String, String> parameters = new HashMap<>();
parameters.put("testKey", "testValue");
Map<String, String> chatAgentInput = new HashMap<>();
chatAgentInput.put("input", gson.toJson(parameters));
doTestRunMethod(chatAgentInput);
assertEquals(chatAgentInput.size(), 1);
assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters
}

@Test
public void testAgentWithChatAgentInputWrongFormat() {
Map<String, String> chatAgentInput = new HashMap<>();
chatAgentInput.put("input", "wrong format");
doTestRunMethod(chatAgentInput);
}

private void doTestRunMethod(Map<String, String> parameters) {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
Expand Down
Loading