|
7 | 7 |
|
8 | 8 | import static org.opensearch.ml.common.utils.StringUtils.gson;
|
9 | 9 |
|
| 10 | +import java.util.HashMap; |
10 | 11 | import java.util.Map;
|
11 | 12 |
|
12 | 13 | import org.opensearch.action.ActionRequest;
|
@@ -53,12 +54,12 @@ public AgentTool(Client client, String agentId) {
|
53 | 54 |
|
54 | 55 | @Override
|
55 | 56 | public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
|
56 |
| - extractFromChatParameters(parameters); |
| 57 | + Map<String, String> extractedParameters = extractInputParameters(parameters); |
57 | 58 | AgentMLInput agentMLInput = AgentMLInput
|
58 | 59 | .AgentMLInputBuilder()
|
59 | 60 | .agentId(agentId)
|
60 | 61 | .functionName(FunctionName.AGENT)
|
61 |
| - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) |
| 62 | + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build()) |
62 | 63 | .build();
|
63 | 64 | ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
|
64 | 65 | client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
|
@@ -139,14 +140,17 @@ public String getDefaultVersion() {
|
139 | 140 | }
|
140 | 141 | }
|
141 | 142 |
|
142 |
| - private void extractFromChatParameters(Map<String, String> parameters) { |
| 143 | + private Map<String, String> extractInputParameters(Map<String, String> parameters) { |
| 144 | + Map<String, String> extractedParameters = new HashMap<>(); |
| 145 | + extractedParameters.putAll(parameters); |
143 | 146 | if (parameters.containsKey("input")) {
|
144 | 147 | try {
|
145 | 148 | Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
|
146 |
| - parameters.putAll(chatParameters); |
| 149 | + extractedParameters.putAll(chatParameters); |
147 | 150 | } catch (Exception exception) {
|
148 | 151 | log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
|
149 | 152 | }
|
150 | 153 | }
|
| 154 | + return extractedParameters; |
151 | 155 | }
|
152 | 156 | }
|
0 commit comments