Skip to content

Commit 1e74bab

Browse files
committed
Added create conversation API in MLClient
Signed-off-by: Owais Kazi <owaiskazi19@gmail.com>
1 parent ab1e054 commit 1e74bab

File tree

6 files changed

+100
-0
lines changed

6 files changed

+100
-0
lines changed

client/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ plugins {
1616
dependencies {
1717
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
1818
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
19+
implementation project(path: ":${rootProject.name}-memory")
1920
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
2021
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
2122
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+19
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
3434
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
3535
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
36+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
3637

3738
/**
3839
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
@@ -428,4 +429,22 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
428429
*/
429430
void getTool(String toolName, ActionListener<ToolMetadata> listener);
430431

432+
/**
433+
* Create conversational memory for conversation
434+
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
435+
* @return the result future
436+
*/
437+
default ActionFuture<CreateConversationResponse> createConversation(String name) {
438+
PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
439+
createConversation(name, actionFuture);
440+
return actionFuture;
441+
}
442+
443+
/**
444+
* Create conversational memory for conversation
445+
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
446+
* @param listener action listener
447+
*/
448+
void createConversation(String name, ActionListener<CreateConversationResponse> listener);
449+
431450
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+18
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@
8585
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
8686
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
8787
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
88+
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
89+
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
90+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
8891

8992
import lombok.AccessLevel;
9093
import lombok.RequiredArgsConstructor;
@@ -309,6 +312,11 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
309312
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
310313
}
311314

315+
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
316+
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
317+
client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
318+
}
319+
312320
private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
313321
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
314322
listener.onResponse(mlModelListResponse.getToolMetadataList());
@@ -379,6 +387,16 @@ private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseAc
379387
return actionListener;
380388
}
381389

390+
private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
391+
ActionListener<CreateConversationResponse> listener
392+
) {
393+
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
394+
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
395+
return conversationResponse;
396+
});
397+
return actionListener;
398+
}
399+
382400
private ActionListener<MLRegisterModelGroupResponse> getMlRegisterModelGroupResponseActionListener(
383401
ActionListener<MLRegisterModelGroupResponse> listener
384402
) {

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+14
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
5555
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
5656
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
57+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
5758

5859
public class MachineLearningClientTest {
5960

@@ -98,6 +99,9 @@ public class MachineLearningClientTest {
9899
@Mock
99100
MLRegisterAgentResponse registerAgentResponse;
100101

102+
@Mock
103+
CreateConversationResponse createConversationResponse;
104+
101105
private String modekId = "test_model_id";
102106
private MLModel mlModel;
103107
private MLTask mlTask;
@@ -230,6 +234,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
230234
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
231235
listener.onResponse(deleteResponse);
232236
}
237+
238+
@Override
239+
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
240+
listener.onResponse(createConversationResponse);
241+
}
233242
};
234243
}
235244

@@ -502,4 +511,9 @@ public void getTool() {
502511
public void listTools() {
503512
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
504513
}
514+
515+
@Test
516+
public void createConversation() {
517+
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
518+
}
505519
}

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+26
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@
130130
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
131131
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
132132
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
133+
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
134+
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
135+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
133136
import org.opensearch.search.SearchHit;
134137
import org.opensearch.search.SearchHits;
135138
import org.opensearch.search.aggregations.InternalAggregations;
@@ -205,6 +208,9 @@ public class MachineLearningNodeClientTest {
205208
@Mock
206209
ActionListener<ToolMetadata> getToolActionListener;
207210

211+
@Mock
212+
ActionListener<CreateConversationResponse> createConversationResponseActionListener;
213+
208214
@InjectMocks
209215
MachineLearningNodeClient machineLearningNodeClient;
210216

@@ -950,6 +956,26 @@ public void listTools() {
950956
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
951957
}
952958

959+
@Test
960+
public void createConversation() {
961+
String name = "Conversation for a RAG pipeline";
962+
String conversationId = "conversationId";
963+
964+
doAnswer(invocation -> {
965+
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
966+
CreateConversationResponse output = new CreateConversationResponse(conversationId);
967+
actionListener.onResponse(output);
968+
return null;
969+
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());
970+
971+
ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
972+
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);
973+
974+
verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
975+
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
976+
assertEquals(conversationId, argumentCaptor.getValue().getId());
977+
}
978+
953979
private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
954980
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
955981

memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java

+22
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,21 @@
1717
*/
1818
package org.opensearch.ml.memory.action.conversation;
1919

20+
import java.io.ByteArrayInputStream;
21+
import java.io.ByteArrayOutputStream;
2022
import java.io.IOException;
23+
import java.io.UncheckedIOException;
2124

2225
import org.opensearch.core.action.ActionResponse;
26+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
27+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
2328
import org.opensearch.core.common.io.stream.StreamInput;
2429
import org.opensearch.core.common.io.stream.StreamOutput;
2530
import org.opensearch.core.xcontent.ToXContent;
2631
import org.opensearch.core.xcontent.ToXContentObject;
2732
import org.opensearch.core.xcontent.XContentBuilder;
2833
import org.opensearch.ml.common.conversation.ActionConstants;
34+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
2935

3036
import lombok.AllArgsConstructor;
3137

@@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
6773
return builder;
6874
}
6975

76+
public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
77+
if (actionResponse instanceof MLCreateConnectorResponse) {
78+
return (CreateConversationResponse) actionResponse;
79+
}
80+
81+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
82+
actionResponse.writeTo(osso);
83+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
84+
return new CreateConversationResponse(input);
85+
}
86+
} catch (IOException e) {
87+
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
88+
}
89+
90+
}
91+
7092
}

0 commit comments

Comments
 (0)