@@ -188,6 +188,46 @@ public void testParsingJsonBlockFromResponse() {
188
188
assertEquals ("parsed final answer" , modelTensor2 .getResult ());
189
189
}
190
190
191
+ @ Test
192
+ public void testParsingJsonBlockFromResponse2 () {
193
+ // Prepare the response with JSON block
194
+ String jsonBlock = "{\" thought\" :\" parsed thought\" , \" action\" :\" parsed action\" , "
195
+ + "\" action_input\" :\" parsed action input\" , \" final_answer\" :\" parsed final answer\" }" ;
196
+ String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text" ;
197
+
198
+ // Mock LLM response to not contain "thought" but contain "response" with JSON block
199
+ Map <String , String > llmResponse = new HashMap <>();
200
+ llmResponse .put ("response" , responseWithJsonBlock );
201
+ doAnswer (getLLMAnswer (llmResponse ))
202
+ .when (client )
203
+ .execute (any (ActionType .class ), any (ActionRequest .class ), isA (ActionListener .class ));
204
+
205
+ // Create an MLAgent and run the MLChatAgentRunner
206
+ MLAgent mlAgent = createMLAgentWithTools ();
207
+ Map <String , String > params = new HashMap <>();
208
+ params .put (MLAgentExecutor .PARENT_INTERACTION_ID , "parent_interaction_id" );
209
+ params .put ("verbose" , "true" );
210
+ mlChatAgentRunner .run (mlAgent , params , agentActionListener );
211
+
212
+ // Capture the response passed to the listener
213
+ ArgumentCaptor <Object > responseCaptor = ArgumentCaptor .forClass (Object .class );
214
+ verify (agentActionListener ).onResponse (responseCaptor .capture ());
215
+
216
+ // Extract the captured response
217
+ Object capturedResponse = responseCaptor .getValue ();
218
+ assertTrue (capturedResponse instanceof ModelTensorOutput );
219
+ ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) capturedResponse ;
220
+
221
+ ModelTensor parentInteractionModelTensor = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (1 );
222
+ ModelTensor modelTensor1 = modelTensorOutput .getMlModelOutputs ().get (1 ).getMlModelTensors ().get (0 );
223
+ ModelTensor modelTensor2 = modelTensorOutput .getMlModelOutputs ().get (2 ).getMlModelTensors ().get (0 );
224
+
225
+ // Verify that the parsed values from JSON block are correctly set
226
+ assertEquals ("parent_interaction_id" , parentInteractionModelTensor .getResult ());
227
+ assertEquals ("Thought: parsed thought" , modelTensor1 .getResult ());
228
+ assertEquals ("parsed final answer" , modelTensor2 .getResult ());
229
+ }
230
+
191
231
@ Test
192
232
public void testRunWithIncludeOutputNotSet () {
193
233
LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
@@ -209,6 +249,29 @@ public void testRunWithIncludeOutputNotSet() {
209
249
assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
210
250
}
211
251
252
+ @ Test
253
+ public void testRunWithIncludeOutputMLModel () {
254
+ LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
255
+ Mockito .doAnswer (generateToolResponseAsMLModelResult ("First tool response" , 1 )).when (firstTool ).run (Mockito .anyMap (), toolListenerCaptor .capture ());
256
+ Mockito .doAnswer (generateToolResponseAsMLModelResult ("Second tool response" , 2 )).when (secondTool ).run (Mockito .anyMap (), toolListenerCaptor .capture ());
257
+ MLToolSpec firstToolSpec = MLToolSpec .builder ().name (FIRST_TOOL ).type (FIRST_TOOL ).build ();
258
+ MLToolSpec secondToolSpec = MLToolSpec .builder ().name (SECOND_TOOL ).type (SECOND_TOOL ).build ();
259
+ final MLAgent mlAgent = MLAgent
260
+ .builder ()
261
+ .name ("TestAgent" )
262
+ .llm (llmSpec )
263
+ .memory (mlMemorySpec )
264
+ .tools (Arrays .asList (firstToolSpec , secondToolSpec ))
265
+ .build ();
266
+ mlChatAgentRunner .run (mlAgent , new HashMap <>(), agentActionListener );
267
+ Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
268
+ ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) objectCaptor .getValue ();
269
+ List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ();
270
+ assertEquals (1 , agentOutput .size ());
271
+ // Respond with last tool output
272
+ assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
273
+ }
274
+
212
275
@ Test
213
276
public void testRunWithIncludeOutputSet () {
214
277
LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
@@ -512,6 +575,24 @@ private Answer generateToolResponse(String response) {
512
575
};
513
576
}
514
577
578
+ private Answer generateToolResponseAsMLModelResult (String response , int type ) {
579
+ ModelTensor modelTensor ;
580
+ if (type == 1 ) {
581
+ modelTensor = ModelTensor .builder ().dataAsMap (ImmutableMap .of ("return" , response )).build ();
582
+ }
583
+ else {
584
+ modelTensor = ModelTensor .builder ().result (response ).build ();
585
+ }
586
+ ModelTensors modelTensors = ModelTensors .builder ().mlModelTensors (Arrays .asList (modelTensor )).build ();
587
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput .builder ().mlModelOutputs (Arrays .asList (modelTensors )).build ();
588
+
589
+ return invocation -> {
590
+ ActionListener <Object > listener = invocation .getArgument (1 );
591
+ listener .onResponse (mlModelTensorOutput );
592
+ return null ;
593
+ };
594
+ }
595
+
515
596
private Answer generateToolFailure (Exception e ) {
516
597
return invocation -> {
517
598
ActionListener <Object > listener = invocation .getArgument (1 );
0 commit comments