|
130 | 130 | import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
|
131 | 131 | import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
|
132 | 132 | import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
|
| 133 | +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; |
| 134 | +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; |
| 135 | +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; |
133 | 136 | import org.opensearch.search.SearchHit;
|
134 | 137 | import org.opensearch.search.SearchHits;
|
135 | 138 | import org.opensearch.search.aggregations.InternalAggregations;
|
@@ -205,6 +208,9 @@ public class MachineLearningNodeClientTest {
|
205 | 208 | @Mock
|
206 | 209 | ActionListener<ToolMetadata> getToolActionListener;
|
207 | 210 |
|
| 211 | + @Mock |
| 212 | + ActionListener<CreateConversationResponse> createConversationResponseActionListener; |
| 213 | + |
208 | 214 | @InjectMocks
|
209 | 215 | MachineLearningNodeClient machineLearningNodeClient;
|
210 | 216 |
|
@@ -950,6 +956,26 @@ public void listTools() {
|
950 | 956 | assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
|
951 | 957 | }
|
952 | 958 |
|
| 959 | + @Test |
| 960 | + public void createConversation() { |
| 961 | + String name = "Conversation for a RAG pipeline"; |
| 962 | + String conversationId = "conversationId"; |
| 963 | + |
| 964 | + doAnswer(invocation -> { |
| 965 | + ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2); |
| 966 | + CreateConversationResponse output = new CreateConversationResponse(conversationId); |
| 967 | + actionListener.onResponse(output); |
| 968 | + return null; |
| 969 | + }).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any()); |
| 970 | + |
| 971 | + ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); |
| 972 | + machineLearningNodeClient.createConversation(name, createConversationResponseActionListener); |
| 973 | + |
| 974 | + verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any()); |
| 975 | + verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture()); |
| 976 | + assertEquals(conversationId, argumentCaptor.getValue().getId()); |
| 977 | + } |
| 978 | + |
953 | 979 | private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
|
954 | 980 | XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
|
955 | 981 |
|
|
0 commit comments