Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fix argument pass #1993

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -23,6 +26,8 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {
Expand Down Expand Up @@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) {
throw new IllegalArgumentException("Model output is invalid");
}
}

/**
 * Converts a tool output object into its string representation.
 *
 * <ul>
 * <li>A {@code String} is returned unchanged.</li>
 * <li>For a {@code ModelTensorOutput}, only the first tensor of the first output is inspected: its
 * {@code dataAsMap} is serialized to JSON when non-null, otherwise its {@code result} is returned.</li>
 * <li>Any other object is serialized to JSON.</li>
 * </ul>
 *
 * NOTE(review): the {@code get(0)} lookups assume at least one output with at least one tensor - confirm with callers.
 *
 * @param output the raw output to convert
 * @return the string form of {@code output}
 * @throws PrivilegedActionException if the privileged JSON serialization throws a checked exception
 */
public static String outputToOutputString(Object output) throws PrivilegedActionException {
    if (output instanceof String) {
        return (String) output;
    }
    if (!(output instanceof ModelTensorOutput)) {
        // Generic fallback: serialize whatever was returned.
        return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
    }
    final ModelTensor firstTensor = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
    if (firstTensor.getDataAsMap() == null) {
        // No structured payload, so use the plain result text.
        return firstTensor.getResult();
    }
    return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(firstTensor.getDataAsMap()));
}

/**
 * Reads the {@code "action_input"} entry of an LLM response map as a string.
 *
 * @param retMap parsed LLM response
 * @return the JSON serialization when the entry is a {@code Map}; otherwise {@code String.valueOf} of the
 *         entry, which yields the literal text {@code "null"} when the key is absent
 */
public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
    final Object rawInput = retMap.get("action_input");
    return rawInput instanceof Map ? gson.toJson(rawInput) : String.valueOf(rawInput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -325,7 +325,7 @@ private void runReAct(
}
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = String.valueOf(dataAsMap.get("action_input"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");
if (!dataAsMap.containsKey("thought")) {
String response = (String) dataAsMap.get("response");
Expand All @@ -336,7 +336,7 @@ private void runReAct(
Map map = gson.fromJson(jsonBlock, Map.class);
thought = String.valueOf(map.get("thought"));
action = String.valueOf(map.get("action"));
actionInput = String.valueOf(map.get("action_input"));
actionInput = parseInputFromLLMReturn(map);
finalAnswer = (String) map.get("final_answer");
} else {
finalAnswer = response;
Expand Down Expand Up @@ -509,9 +509,7 @@ private void runReAct(
} else {
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
String outputString = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
String outputString = outputToOutputString(output);

String toolOutputKey = String.format("%s.output", toolSpec.getType());
if (additionalInfo.get(toolOutputKey) != null) {
Expand All @@ -531,7 +529,13 @@ private void runReAct(
.singletonList(
ModelTensor
.builder()
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output))
.dataAsMap(
ImmutableMap
.of(
"response",
lastThought.get() + "\nObservation: " + outputToOutputString(output)
)
)
.build()
)
)
Expand All @@ -540,7 +544,7 @@ private void runReAct(

String toolResponse = tmpParameters.get("prompt.tool_response");
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(
ImmutableMap.of("observation", output),
ImmutableMap.of("observation", outputToOutputString(output)),
"${parameters.",
"}"
);
Expand All @@ -552,7 +556,7 @@ private void runReAct(
.conversationIndexMessageBuilder()
.type("ReAct")
.question(lastActionInput.get())
.response((String) output)
.response(outputToOutputString(output))
.finalAnswer(false)
.sessionId(sessionId)
.build();
Expand All @@ -567,7 +571,7 @@ private void runReAct(
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());

sessionMsgAnswerBuilder.append("\nObservation: ").append(output);
sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
cotModelTensors
.add(
ModelTensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -51,11 +54,12 @@ public AgentTool(Client client, String agentId) {

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
Map<String, String> extractedParameters = extractInputParameters(parameters);
AgentMLInput agentMLInput = AgentMLInput
.AgentMLInputBuilder()
.agentId(agentId)
.functionName(FunctionName.AGENT)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
.build();
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
Expand Down Expand Up @@ -135,4 +139,18 @@ public String getDefaultVersion() {
return null;
}
}

/**
 * Builds the parameter map passed to the wrapped agent.
 *
 * The result is a copy of {@code parameters}; the caller's map is never modified. When an
 * {@code "input"} entry is present and parses as a JSON object, its entries are merged on top of the copy.
 * String values are kept as is; any other non-null value (number, boolean, nested object or array) is
 * re-serialized to JSON so that the map really contains only strings. Parsing is best effort: if
 * {@code "input"} is not a JSON object the failure is logged and the plain copy is returned.
 *
 * @param parameters the parameters the tool was invoked with
 * @return a new map holding the original entries plus those extracted from {@code "input"}
 */
private Map<String, String> extractInputParameters(Map<String, String> parameters) {
    Map<String, String> extractedParameters = new HashMap<>(parameters);
    if (!parameters.containsKey("input")) {
        return extractedParameters;
    }
    try {
        Map<?, ?> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
        // Gson returns null for null/empty input; there is nothing to merge in that case.
        if (chatParameters != null) {
            // Convert into a scratch map first so a failure cannot leave a partial merge behind.
            Map<String, String> converted = new HashMap<>();
            for (Map.Entry<?, ?> entry : chatParameters.entrySet()) {
                Object value = entry.getValue();
                String text = (value == null || value instanceof String) ? (String) value : gson.toJson(value);
                converted.put(String.valueOf(entry.getKey()), text);
            }
            extractedParameters.putAll(converted);
        }
    } catch (Exception exception) {
        log.info("fail extract parameters from key 'input' due to " + exception.getMessage());
    }
    return extractedParameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() {
assertEquals("parsed final answer", modelTensor2.getResult());
}

// The LLM reply has no top-level "thought"; its "response" text embeds a fenced JSON block whose
// "action_input" is a plain string. The runner is expected to surface the values parsed from that block.
@Test
public void testParsingJsonBlockFromResponse2() {
    // Build the LLM reply: a fenced JSON block surrounded by free text
    String fencedJson = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
        + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
    Map<String, String> llmReply = new HashMap<>();
    llmReply.put("response", "Some text```json" + fencedJson + "```More text");
    doAnswer(getLLMAnswer(llmReply)).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    // Run the chat agent in verbose mode with a parent interaction id
    Map<String, String> runParams = new HashMap<>();
    runParams.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    runParams.put("verbose", "true");
    MLAgent agent = createMLAgentWithTools();
    mlChatAgentRunner.run(agent, runParams, agentActionListener);

    // Grab what the runner handed to the listener
    ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
    verify(agentActionListener).onResponse(captor.capture());
    Object captured = captor.getValue();
    assertTrue(captured instanceof ModelTensorOutput);

    // Pick out the tensors of interest before asserting on any of them
    List<ModelTensors> outputs = ((ModelTensorOutput) captured).getMlModelOutputs();
    ModelTensor parentInteractionTensor = outputs.get(0).getMlModelTensors().get(1);
    ModelTensor thoughtTensor = outputs.get(1).getMlModelTensors().get(0);
    ModelTensor finalAnswerTensor = outputs.get(2).getMlModelTensors().get(0);

    // Values must come from the JSON block
    assertEquals("parent_interaction_id", parentInteractionTensor.getResult());
    assertEquals("Thought: parsed thought", thoughtTensor.getResult());
    assertEquals("parsed final answer", finalAnswerTensor.getResult());
}

// Same scenario as the string-valued case, except that "action_input" inside the fenced JSON block is a
// JSON object rather than a string.
@Test
public void testParsingJsonBlockFromResponse3() {
    // Build the LLM reply: a fenced JSON block (object-valued action_input) surrounded by free text
    String fencedJson = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
        + "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}";
    Map<String, String> llmReply = new HashMap<>();
    llmReply.put("response", "Some text```json" + fencedJson + "```More text");
    doAnswer(getLLMAnswer(llmReply)).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    // Run the chat agent in verbose mode with a parent interaction id
    Map<String, String> runParams = new HashMap<>();
    runParams.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    runParams.put("verbose", "true");
    MLAgent agent = createMLAgentWithTools();
    mlChatAgentRunner.run(agent, runParams, agentActionListener);

    // Grab what the runner handed to the listener
    ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
    verify(agentActionListener).onResponse(captor.capture());
    Object captured = captor.getValue();
    assertTrue(captured instanceof ModelTensorOutput);

    // Pick out the tensors of interest before asserting on any of them
    List<ModelTensors> outputs = ((ModelTensorOutput) captured).getMlModelOutputs();
    ModelTensor parentInteractionTensor = outputs.get(0).getMlModelTensors().get(1);
    ModelTensor thoughtTensor = outputs.get(1).getMlModelTensors().get(0);
    ModelTensor finalAnswerTensor = outputs.get(2).getMlModelTensors().get(0);

    // Values must come from the JSON block
    assertEquals("parent_interaction_id", parentInteractionTensor.getResult());
    assertEquals("Thought: parsed thought", thoughtTensor.getResult());
    assertEquals("parsed final answer", finalAnswerTensor.getResult());
}

@Test
public void testRunWithIncludeOutputNotSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand All @@ -209,6 +289,35 @@ public void testRunWithIncludeOutputNotSet() {
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
}

// Runs the agent with two tools that both reply with a ModelTensorOutput (one carrying dataAsMap, one a
// plain result) and checks that the agent response consists of a single tensor with the final answer.
@Test
public void testRunWithIncludeOutputMLModel() {
    // First tool replies with a dataAsMap tensor (type 1), second with a plain-result tensor
    Mockito.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)).when(firstTool).run(Mockito.anyMap(), toolListenerCaptor.capture());
    Mockito.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)).when(secondTool).run(Mockito.anyMap(), toolListenerCaptor.capture());

    // Tool specs are built without the include-output flag
    MLToolSpec specOne = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    MLToolSpec specTwo = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
    LLMSpec llm = LLMSpec.builder().modelId("MODEL_ID").build();
    final MLAgent agent = MLAgent.builder().name("TestAgent").llm(llm).memory(mlMemorySpec).tools(Arrays.asList(specOne, specTwo)).build();

    mlChatAgentRunner.run(agent, new HashMap<>(), agentActionListener);

    Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
    ModelTensorOutput captured = (ModelTensorOutput) objectCaptor.getValue();
    List<ModelTensor> agentTensors = captured.getMlModelOutputs().get(0).getMlModelTensors();
    // Exactly one tensor, and it carries the final answer
    assertEquals(1, agentTensors.size());
    assertEquals("This is the final answer", agentTensors.get(0).getDataAsMap().get("response"));
}

@Test
public void testRunWithIncludeOutputSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down Expand Up @@ -512,6 +621,23 @@ private Answer generateToolResponse(String response) {
};
}

/**
 * Creates a Mockito answer that replies to a tool invocation with a {@code ModelTensorOutput}.
 *
 * @param response text to place in the single tensor of the output
 * @param type     {@code 1} stores {@code response} in the tensor's dataAsMap under the key "return";
 *                 any other value stores it as the tensor's result
 * @return an answer that passes the prepared output to the listener given as the second invocation argument
 */
private Answer generateToolResponseAsMLModelResult(String response, int type) {
    ModelTensor tensor = type == 1
        ? ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build()
        : ModelTensor.builder().result(response).build();
    ModelTensors wrappedTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();
    ModelTensorOutput toolOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(wrappedTensors)).build();

    return invocation -> {
        // The listener is the second argument of Tool.run(parameters, listener)
        ActionListener<Object> toolListener = invocation.getArgument(1);
        toolListener.onResponse(toolOutput);
        return null;
    };
}

private Answer generateToolFailure(Exception e) {
return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION;

import java.util.Arrays;
Expand All @@ -24,8 +25,6 @@
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand Down Expand Up @@ -67,13 +66,28 @@ public void setup() {
public void testAgenttestRunMethod() {
    // Plain parameters without an "input" key; execution and verification live in doTestRunMethod.
    Map<String, String> parameters = new HashMap<>();
    parameters.put("testKey", "testValue");
    doTestRunMethod(parameters);
}

// Passes the parameters JSON-encoded under the "input" key and checks that running the tool leaves the
// caller's map untouched.
@Test
public void testAgentWithChatAgentInput() {
    Map<String, String> parameters = new HashMap<>();
    parameters.put("testKey", "testValue");
    Map<String, String> chatAgentInput = new HashMap<>();
    chatAgentInput.put("input", gson.toJson(parameters));
    doTestRunMethod(chatAgentInput);
    // assert no influence on original parameters (expected value first, actual second)
    assertEquals(1, chatAgentInput.size());
    assertEquals(gson.toJson(parameters), chatAgentInput.get("input"));
}

// Supplies an "input" value that is not JSON; the shared run/verify helper is expected to complete anyway.
@Test
public void testAgentWithChatAgentInputWrongFormat() {
    Map<String, String> malformedInput = new HashMap<>();
    malformedInput.put("input", "wrong format");
    doTestRunMethod(malformedInput);
}

private void doTestRunMethod(Map<String, String> parameters) {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
Expand Down
Loading