Skip to content

Commit dea4cf6

Browse files
committed
protect original parameters
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent f14a509 commit dea4cf6

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

+8-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.ml.common.utils.StringUtils.gson;
99

10+
import java.util.HashMap;
1011
import java.util.Map;
1112

1213
import org.opensearch.action.ActionRequest;
@@ -53,12 +54,12 @@ public AgentTool(Client client, String agentId) {
5354

5455
@Override
5556
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
56-
extractFromChatParameters(parameters);
57+
Map<String, String> extractedParameters = extractInputParameters(parameters);
5758
AgentMLInput agentMLInput = AgentMLInput
5859
.AgentMLInputBuilder()
5960
.agentId(agentId)
6061
.functionName(FunctionName.AGENT)
61-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
62+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
6263
.build();
6364
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
6465
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
@@ -139,14 +140,17 @@ public String getDefaultVersion() {
139140
}
140141
}
141142

142-
private void extractFromChatParameters(Map<String, String> parameters) {
143+
private Map<String, String> extractInputParameters(Map<String, String> parameters) {
144+
Map<String, String> extractedParameters = new HashMap<>();
145+
extractedParameters.putAll(parameters);
143146
if (parameters.containsKey("input")) {
144147
try {
145148
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
146-
parameters.putAll(chatParameters);
149+
extractedParameters.putAll(chatParameters);
147150
} catch (Exception exception) {
148151
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
149152
}
150153
}
154+
return extractedParameters;
151155
}
152156
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ public void testAgentWithChatAgentInput() {
7676
Map<String, String> chatAgentInput = new HashMap<>();
7777
chatAgentInput.put("input", gson.toJson(parameters));
7878
doTestRunMethod(chatAgentInput);
79+
assertEquals(chatAgentInput.size(), 1);
80+
assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters
7981
}
8082

8183
@Test

0 commit comments

Comments
 (0)