@@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() {
188
188
assertEquals ("parsed final answer" , modelTensor2 .getResult ());
189
189
}
190
190
191
+ @ Test
192
+ public void testParsingJsonBlockFromResponse2 () {
193
+ // Prepare the response with JSON block
194
+ String jsonBlock = "{\" thought\" :\" parsed thought\" , \" action\" :\" parsed action\" , "
195
+ + "\" action_input\" :\" parsed action input\" , \" final_answer\" :\" parsed final answer\" }" ;
196
+ String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text" ;
197
+
198
+ // Mock LLM response to not contain "thought" but contain "response" with JSON block
199
+ Map <String , String > llmResponse = new HashMap <>();
200
+ llmResponse .put ("response" , responseWithJsonBlock );
201
+ doAnswer (getLLMAnswer (llmResponse ))
202
+ .when (client )
203
+ .execute (any (ActionType .class ), any (ActionRequest .class ), isA (ActionListener .class ));
204
+
205
+ // Create an MLAgent and run the MLChatAgentRunner
206
+ MLAgent mlAgent = createMLAgentWithTools ();
207
+ Map <String , String > params = new HashMap <>();
208
+ params .put (MLAgentExecutor .PARENT_INTERACTION_ID , "parent_interaction_id" );
209
+ params .put ("verbose" , "true" );
210
+ mlChatAgentRunner .run (mlAgent , params , agentActionListener );
211
+
212
+ // Capture the response passed to the listener
213
+ ArgumentCaptor <Object > responseCaptor = ArgumentCaptor .forClass (Object .class );
214
+ verify (agentActionListener ).onResponse (responseCaptor .capture ());
215
+
216
+ // Extract the captured response
217
+ Object capturedResponse = responseCaptor .getValue ();
218
+ assertTrue (capturedResponse instanceof ModelTensorOutput );
219
+ ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) capturedResponse ;
220
+
221
+ ModelTensor parentInteractionModelTensor = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (1 );
222
+ ModelTensor modelTensor1 = modelTensorOutput .getMlModelOutputs ().get (1 ).getMlModelTensors ().get (0 );
223
+ ModelTensor modelTensor2 = modelTensorOutput .getMlModelOutputs ().get (2 ).getMlModelTensors ().get (0 );
224
+
225
+ // Verify that the parsed values from JSON block are correctly set
226
+ assertEquals ("parent_interaction_id" , parentInteractionModelTensor .getResult ());
227
+ assertEquals ("Thought: parsed thought" , modelTensor1 .getResult ());
228
+ assertEquals ("parsed final answer" , modelTensor2 .getResult ());
229
+ }
230
+
231
+ @ Test
232
+ public void testParsingJsonBlockFromResponse3 () {
233
+ // Prepare the response with JSON block
234
+ String jsonBlock = "{\" thought\" :\" parsed thought\" , \" action\" :\" parsed action\" , "
235
+ + "\" action_input\" :{\" a\" :\" n\" }, \" final_answer\" :\" parsed final answer\" }" ;
236
+ String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text" ;
237
+
238
+ // Mock LLM response to not contain "thought" but contain "response" with JSON block
239
+ Map <String , String > llmResponse = new HashMap <>();
240
+ llmResponse .put ("response" , responseWithJsonBlock );
241
+ doAnswer (getLLMAnswer (llmResponse ))
242
+ .when (client )
243
+ .execute (any (ActionType .class ), any (ActionRequest .class ), isA (ActionListener .class ));
244
+
245
+ // Create an MLAgent and run the MLChatAgentRunner
246
+ MLAgent mlAgent = createMLAgentWithTools ();
247
+ Map <String , String > params = new HashMap <>();
248
+ params .put (MLAgentExecutor .PARENT_INTERACTION_ID , "parent_interaction_id" );
249
+ params .put ("verbose" , "true" );
250
+ mlChatAgentRunner .run (mlAgent , params , agentActionListener );
251
+
252
+ // Capture the response passed to the listener
253
+ ArgumentCaptor <Object > responseCaptor = ArgumentCaptor .forClass (Object .class );
254
+ verify (agentActionListener ).onResponse (responseCaptor .capture ());
255
+
256
+ // Extract the captured response
257
+ Object capturedResponse = responseCaptor .getValue ();
258
+ assertTrue (capturedResponse instanceof ModelTensorOutput );
259
+ ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) capturedResponse ;
260
+
261
+ ModelTensor parentInteractionModelTensor = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (1 );
262
+ ModelTensor modelTensor1 = modelTensorOutput .getMlModelOutputs ().get (1 ).getMlModelTensors ().get (0 );
263
+ ModelTensor modelTensor2 = modelTensorOutput .getMlModelOutputs ().get (2 ).getMlModelTensors ().get (0 );
264
+
265
+ // Verify that the parsed values from JSON block are correctly set
266
+ assertEquals ("parent_interaction_id" , parentInteractionModelTensor .getResult ());
267
+ assertEquals ("Thought: parsed thought" , modelTensor1 .getResult ());
268
+ assertEquals ("parsed final answer" , modelTensor2 .getResult ());
269
+ }
270
+
191
271
@ Test
192
272
public void testRunWithIncludeOutputNotSet () {
193
273
LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
@@ -209,6 +289,35 @@ public void testRunWithIncludeOutputNotSet() {
209
289
assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
210
290
}
211
291
292
+ @ Test
293
+ public void testRunWithIncludeOutputMLModel () {
294
+ LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
295
+ Mockito
296
+ .doAnswer (generateToolResponseAsMLModelResult ("First tool response" , 1 ))
297
+ .when (firstTool )
298
+ .run (Mockito .anyMap (), toolListenerCaptor .capture ());
299
+ Mockito
300
+ .doAnswer (generateToolResponseAsMLModelResult ("Second tool response" , 2 ))
301
+ .when (secondTool )
302
+ .run (Mockito .anyMap (), toolListenerCaptor .capture ());
303
+ MLToolSpec firstToolSpec = MLToolSpec .builder ().name (FIRST_TOOL ).type (FIRST_TOOL ).build ();
304
+ MLToolSpec secondToolSpec = MLToolSpec .builder ().name (SECOND_TOOL ).type (SECOND_TOOL ).build ();
305
+ final MLAgent mlAgent = MLAgent
306
+ .builder ()
307
+ .name ("TestAgent" )
308
+ .llm (llmSpec )
309
+ .memory (mlMemorySpec )
310
+ .tools (Arrays .asList (firstToolSpec , secondToolSpec ))
311
+ .build ();
312
+ mlChatAgentRunner .run (mlAgent , new HashMap <>(), agentActionListener );
313
+ Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
314
+ ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) objectCaptor .getValue ();
315
+ List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ();
316
+ assertEquals (1 , agentOutput .size ());
317
+ // Respond with last tool output
318
+ assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
319
+ }
320
+
212
321
@ Test
213
322
public void testRunWithIncludeOutputSet () {
214
323
LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
@@ -512,6 +621,23 @@ private Answer generateToolResponse(String response) {
512
621
};
513
622
}
514
623
624
+ private Answer generateToolResponseAsMLModelResult (String response , int type ) {
625
+ ModelTensor modelTensor ;
626
+ if (type == 1 ) {
627
+ modelTensor = ModelTensor .builder ().dataAsMap (ImmutableMap .of ("return" , response )).build ();
628
+ } else {
629
+ modelTensor = ModelTensor .builder ().result (response ).build ();
630
+ }
631
+ ModelTensors modelTensors = ModelTensors .builder ().mlModelTensors (Arrays .asList (modelTensor )).build ();
632
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput .builder ().mlModelOutputs (Arrays .asList (modelTensors )).build ();
633
+
634
+ return invocation -> {
635
+ ActionListener <Object > listener = invocation .getArgument (1 );
636
+ listener .onResponse (mlModelTensorOutput );
637
+ return null ;
638
+ };
639
+ }
640
+
515
641
private Answer generateToolFailure (Exception e ) {
516
642
return invocation -> {
517
643
ActionListener <Object > listener = invocation .getArgument (1 );
0 commit comments