@@ -268,6 +268,46 @@ public void testParsingJsonBlockFromResponse3() {
268
268
assertEquals ("parsed final answer" , modelTensor2 .getResult ());
269
269
}
270
270
271
+ @ Test
272
+ public void testParsingJsonBlockFromResponse4 () {
273
+ // Prepare the response with JSON block
274
+ String jsonBlock = "{\" thought\" :\" parsed thought\" , \" action\" :\" parsed action\" , "
275
+ + "\" action_input\" :\" parsed action input\" , \" final_answer\" :\" parsed final answer\" }" ;
276
+ String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text" ;
277
+
278
+ // Mock LLM response to not contain "thought" but contain "response" with JSON block
279
+ Map <String , String > llmResponse = new HashMap <>();
280
+ llmResponse .put ("response" , responseWithJsonBlock );
281
+ doAnswer (getLLMAnswer (llmResponse ))
282
+ .when (client )
283
+ .execute (any (ActionType .class ), any (ActionRequest .class ), isA (ActionListener .class ));
284
+
285
+ // Create an MLAgent and run the MLChatAgentRunner
286
+ MLAgent mlAgent = createMLAgentWithTools ();
287
+ Map <String , String > params = new HashMap <>();
288
+ params .put (MLAgentExecutor .PARENT_INTERACTION_ID , "parent_interaction_id" );
289
+ params .put ("verbose" , "false" );
290
+ mlChatAgentRunner .run (mlAgent , params , agentActionListener );
291
+
292
+ // Capture the response passed to the listener
293
+ ArgumentCaptor <Object > responseCaptor = ArgumentCaptor .forClass (Object .class );
294
+ verify (agentActionListener ).onResponse (responseCaptor .capture ());
295
+
296
+ // Extract the captured response
297
+ Object capturedResponse = responseCaptor .getValue ();
298
+ assertTrue (capturedResponse instanceof ModelTensorOutput );
299
+ ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) capturedResponse ;
300
+
301
+ ModelTensor memoryIdModelTensor = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 );
302
+ ModelTensor parentInteractionModelTensor = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (1 );
303
+
304
+ // Verify that the parsed values from JSON block are correctly set
305
+ assertEquals ("memory_id" , memoryIdModelTensor .getName ());
306
+ assertEquals ("conversation_id" , memoryIdModelTensor .getResult ());
307
+ assertEquals ("parent_interaction_id" , parentInteractionModelTensor .getName ());
308
+ assertEquals ("parent_interaction_id" , parentInteractionModelTensor .getResult ());
309
+ }
310
+
271
311
@ Test
272
312
public void testRunWithIncludeOutputNotSet () {
273
313
LLMSpec llmSpec = LLMSpec .builder ().modelId ("MODEL_ID" ).build ();
@@ -293,7 +333,7 @@ public void testRunWithIncludeOutputNotSet() {
293
333
mlChatAgentRunner .run (mlAgent , new HashMap <>(), agentActionListener );
294
334
Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
295
335
ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) objectCaptor .getValue ();
296
- List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ();
336
+ List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (1 ).getMlModelTensors ();
297
337
assertEquals (1 , agentOutput .size ());
298
338
// Respond with last tool output
299
339
assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
@@ -322,7 +362,7 @@ public void testRunWithIncludeOutputMLModel() {
322
362
mlChatAgentRunner .run (mlAgent , new HashMap <>(), agentActionListener );
323
363
Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
324
364
ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) objectCaptor .getValue ();
325
- List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ();
365
+ List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (1 ).getMlModelTensors ();
326
366
assertEquals (1 , agentOutput .size ());
327
367
// Respond with last tool output
328
368
assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
@@ -356,7 +396,7 @@ public void testRunWithIncludeOutputSet() {
356
396
mlChatAgentRunner .run (mlAgent , params , agentActionListener );
357
397
Mockito .verify (agentActionListener ).onResponse (objectCaptor .capture ());
358
398
ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) objectCaptor .getValue ();
359
- List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ();
399
+ List <ModelTensor > agentOutput = modelTensorOutput .getMlModelOutputs ().get (1 ).getMlModelTensors ();
360
400
assertEquals (1 , agentOutput .size ());
361
401
// Respond with last tool output
362
402
assertEquals ("This is the final answer" , agentOutput .get (0 ).getDataAsMap ().get ("response" ));
0 commit comments