Skip to content

Commit 05d8ee4

Browse files
committed
Added Conversation API in MLClient
Signed-off-by: Owais <owaiskazi19@gmail.com>
1 parent 103fbe7 commit 05d8ee4

File tree

7 files changed

+140
-2
lines changed

7 files changed

+140
-2
lines changed

client/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ plugins {
1616
dependencies {
1717
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
1818
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
19+
implementation project(path: ":${rootProject.name}-memory")
1920
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
2021
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
2122
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+19
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
3535
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
3636
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
37+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
3738

3839
/**
3940
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
@@ -553,4 +554,22 @@ default void getConfig(String configId, ActionListener<MLConfig> listener) {
553554
* @param listener a listener to be notified of the result
554555
*/
555556
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);
557+
558+
/**
559+
* Create conversational memory for conversation
560+
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
561+
* @return the result future
562+
*/
563+
default ActionFuture<CreateConversationResponse> createConversation(String name) {
564+
PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
565+
createConversation(name, actionFuture);
566+
return actionFuture;
567+
}
568+
569+
/**
570+
* Create conversational memory for conversation
571+
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
572+
* @param listener action listener
573+
*/
574+
void createConversation(String name, ActionListener<CreateConversationResponse> listener);
556575
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+18
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@
8888
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
8989
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
9090
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
91+
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
92+
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
93+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
9194
import org.opensearch.transport.client.Client;
9295

9396
import lombok.AccessLevel;
@@ -318,6 +321,11 @@ public void getConfig(String configId, String tenantId, ActionListener<MLConfig>
318321
client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
319322
}
320323

324+
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
325+
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
326+
client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
327+
}
328+
321329
private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
322330
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
323331
listener.onResponse(mlModelListResponse.getToolMetadataList());
@@ -386,6 +394,16 @@ private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseAction
386394
return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse);
387395
}
388396

397+
private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
398+
ActionListener<CreateConversationResponse> listener
399+
) {
400+
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
401+
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
402+
return conversationResponse;
403+
});
404+
return actionListener;
405+
}
406+
389407
private <T extends ActionResponse> ActionListener<T> wrapActionListener(
390408
final ActionListener<T> listener,
391409
final Function<ActionResponse, T> recreate

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+12-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
import org.opensearch.ml.common.output.MLOutput;
5050
import org.opensearch.ml.common.output.MLTrainingOutput;
5151
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
52-
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
5352
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
5453
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
5554
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
@@ -59,6 +58,7 @@
5958
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
6059
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
6160
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
61+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
6262

6363
public class MachineLearningClientTest {
6464

@@ -107,7 +107,7 @@ public class MachineLearningClientTest {
107107
MLRegisterAgentResponse registerAgentResponse;
108108

109109
@Mock
110-
MLConfigGetResponse configGetResponse;
110+
CreateConversationResponse createConversationResponse;
111111

112112
private final String modekId = "test_model_id";
113113
private MLModel mlModel;
@@ -256,6 +256,11 @@ public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteRe
256256
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
257257
listener.onResponse(mlConfig);
258258
}
259+
260+
@Override
261+
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
262+
listener.onResponse(createConversationResponse);
263+
}
259264
};
260265
}
261266

@@ -554,4 +559,9 @@ public void listTools() {
554559
public void getConfig() {
555560
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
556561
}
562+
563+
@Test
564+
public void createConversation() {
565+
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
566+
}
557567
}

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+25
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@
140140
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
141141
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
142142
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
143+
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
144+
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
145+
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
143146
import org.opensearch.search.SearchHit;
144147
import org.opensearch.search.SearchHits;
145148
import org.opensearch.search.aggregations.InternalAggregations;
@@ -219,6 +222,9 @@ public class MachineLearningNodeClientTest {
219222
@Mock
220223
ActionListener<MLConfig> getMlConfigListener;
221224

225+
@Mock
226+
ActionListener<CreateConversationResponse> createConversationResponseActionListener;
227+
222228
@InjectMocks
223229
MachineLearningNodeClient machineLearningNodeClient;
224230

@@ -1455,6 +1461,25 @@ public void onFailure(Exception e) {
14551461
verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
14561462
}
14571463

1464+
public void createConversation() {
1465+
String name = "Conversation for a RAG pipeline";
1466+
String conversationId = "conversationId";
1467+
1468+
doAnswer(invocation -> {
1469+
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
1470+
CreateConversationResponse output = new CreateConversationResponse(conversationId);
1471+
actionListener.onResponse(output);
1472+
return null;
1473+
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());
1474+
1475+
ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
1476+
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);
1477+
1478+
verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
1479+
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
1480+
assertEquals(conversationId, argumentCaptor.getValue().getId());
1481+
}
1482+
14581483
private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
14591484
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
14601485

memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java

+22
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,21 @@
1717
*/
1818
package org.opensearch.ml.memory.action.conversation;
1919

20+
import java.io.ByteArrayInputStream;
21+
import java.io.ByteArrayOutputStream;
2022
import java.io.IOException;
23+
import java.io.UncheckedIOException;
2124

2225
import org.opensearch.core.action.ActionResponse;
26+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
27+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
2328
import org.opensearch.core.common.io.stream.StreamInput;
2429
import org.opensearch.core.common.io.stream.StreamOutput;
2530
import org.opensearch.core.xcontent.ToXContent;
2631
import org.opensearch.core.xcontent.ToXContentObject;
2732
import org.opensearch.core.xcontent.XContentBuilder;
2833
import org.opensearch.ml.common.conversation.ActionConstants;
34+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
2935

3036
import lombok.AllArgsConstructor;
3137

@@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
6773
return builder;
6874
}
6975

76+
public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
77+
if (actionResponse instanceof MLCreateConnectorResponse) {
78+
return (CreateConversationResponse) actionResponse;
79+
}
80+
81+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
82+
actionResponse.writeTo(osso);
83+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
84+
return new CreateConversationResponse(input);
85+
}
86+
} catch (IOException e) {
87+
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
88+
}
89+
90+
}
91+
7092
}

memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java

+43
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@
1717
*/
1818
package org.opensearch.ml.memory.action.conversation;
1919

20+
import static org.junit.Assert.assertEquals;
21+
2022
import java.io.IOException;
23+
import java.io.UncheckedIOException;
2124

25+
import org.junit.Before;
26+
import org.junit.Test;
2227
import org.opensearch.common.io.stream.BytesStreamOutput;
2328
import org.opensearch.common.xcontent.XContentType;
29+
import org.opensearch.core.action.ActionResponse;
2430
import org.opensearch.core.common.bytes.BytesReference;
2531
import org.opensearch.core.common.io.stream.BytesStreamInput;
2632
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
@@ -32,6 +38,13 @@
3238

3339
public class CreateConversationResponseTests extends OpenSearchTestCase {
3440

41+
CreateConversationResponse response;
42+
43+
@Before
44+
public void setup() {
45+
response = new CreateConversationResponse("test-id");
46+
}
47+
3548
public void testCreateConversationResponseStreaming() throws IOException {
3649
CreateConversationResponse response = new CreateConversationResponse("test-id");
3750
assert (response.getId().equals("test-id"));
@@ -51,4 +64,34 @@ public void testToXContent() throws IOException {
5164
String result = BytesReference.bytes(builder).utf8ToString();
5265
assert (result.equals(expected));
5366
}
67+
68+
@Test
69+
public void fromActionResponseWithCreateConversationResponseSuccess() {
70+
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(response);
71+
assertEquals(response.getId(), responseFromActionResponse.getId());
72+
}
73+
74+
@Test
75+
public void fromActionResponseSuccess() {
76+
ActionResponse actionResponse = new ActionResponse() {
77+
@Override
78+
public void writeTo(StreamOutput out) throws IOException {
79+
response.writeTo(out);
80+
}
81+
};
82+
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(actionResponse);
83+
assertNotSame(response, responseFromActionResponse);
84+
assertEquals(response.getId(), responseFromActionResponse.getId());
85+
}
86+
87+
@Test(expected = UncheckedIOException.class)
88+
public void fromActionResponseIOException() {
89+
ActionResponse actionResponse = new ActionResponse() {
90+
@Override
91+
public void writeTo(StreamOutput out) throws IOException {
92+
throw new IOException();
93+
}
94+
};
95+
CreateConversationResponse.fromActionResponse(actionResponse);
96+
}
5497
}

0 commit comments

Comments
 (0)