Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Conversation API in MLClient #3475

Merged
merged 18 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ plugins {
dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-memory")
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -553,4 +554,22 @@ default void getConfig(String configId, ActionListener<MLConfig> listener) {
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @return the result future
*/
default ActionFuture<CreateConversationResponse> createConversation(String name) {
PlainActionFuture<CreateConversationResponse> actionFuture = PlainActionFuture.newFuture();
createConversation(name, actionFuture);
return actionFuture;
}

/**
* Create conversational memory for conversation
* @param name name of the conversation, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/memory-apis/create-memory/
* @param listener action listener
*/
void createConversation(String name, ActionListener<CreateConversationResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.transport.client.Client;

import lombok.AccessLevel;
Expand Down Expand Up @@ -318,6 +321,11 @@ public void getConfig(String configId, String tenantId, ActionListener<MLConfig>
client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}

public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
CreateConversationRequest createConversationRequest = new CreateConversationRequest(name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test coverage is missing, I saw you added a unit test. Could you please check why it's not reflecting?

client.execute(CreateConversationAction.INSTANCE, createConversationRequest, getCreateConversationResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down Expand Up @@ -386,6 +394,16 @@ private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseAction
return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse);
}

private ActionListener<CreateConversationResponse> getCreateConversationResponseActionListener(
ActionListener<CreateConversationResponse> listener
) {
ActionListener<CreateConversationResponse> actionListener = wrapActionListener(listener, response -> {
CreateConversationResponse conversationResponse = CreateConversationResponse.fromActionResponse(response);
return conversationResponse;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -59,6 +58,7 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -107,7 +107,7 @@ public class MachineLearningClientTest {
MLRegisterAgentResponse registerAgentResponse;

@Mock
MLConfigGetResponse configGetResponse;
CreateConversationResponse createConversationResponse;

private final String modekId = "test_model_id";
private MLModel mlModel;
Expand Down Expand Up @@ -256,6 +256,11 @@ public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteRe
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}

@Override
public void createConversation(String name, ActionListener<CreateConversationResponse> listener) {
listener.onResponse(createConversationResponse);
}
};
}

Expand Down Expand Up @@ -554,4 +559,9 @@ public void listTools() {
public void getConfig() {
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
}

@Test
public void createConversation() {
assertEquals(createConversationResponse, machineLearningClient.createConversation("Conversation for a RAG pipeline").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -219,6 +222,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLConfig> getMlConfigListener;

@Mock
ActionListener<CreateConversationResponse> createConversationResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -1455,6 +1461,25 @@ public void onFailure(Exception e) {
verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
}

public void createConversation() {
String name = "Conversation for a RAG pipeline";
String conversationId = "conversationId";

doAnswer(invocation -> {
ActionListener<CreateConversationResponse> actionListener = invocation.getArgument(2);
CreateConversationResponse output = new CreateConversationResponse(conversationId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any());

ArgumentCaptor<CreateConversationResponse> argumentCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class);
machineLearningNodeClient.createConversation(name, createConversationResponseActionListener);

verify(client).execute(eq(CreateConversationAction.INSTANCE), isA(CreateConversationRequest.class), any());
verify(createConversationResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(conversationId, argumentCaptor.getValue().getId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import lombok.AllArgsConstructor;

Expand Down Expand Up @@ -67,4 +73,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
return builder;
}

public static CreateConversationResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCreateConnectorResponse) {
return (CreateConversationResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new CreateConversationResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into CreateConversationResponse", e);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@
*/
package org.opensearch.ml.memory.action.conversation;

import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -32,6 +38,13 @@

public class CreateConversationResponseTests extends OpenSearchTestCase {

CreateConversationResponse response;

@Before
public void setup() {
response = new CreateConversationResponse("test-id");
}

public void testCreateConversationResponseStreaming() throws IOException {
CreateConversationResponse response = new CreateConversationResponse("test-id");
assert (response.getId().equals("test-id"));
Expand All @@ -51,4 +64,34 @@ public void testToXContent() throws IOException {
String result = BytesReference.bytes(builder).utf8ToString();
assert (result.equals(expected));
}

@Test
public void fromActionResponseWithCreateConversationResponseSuccess() {
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(response);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test
public void fromActionResponseSuccess() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
response.writeTo(out);
}
};
CreateConversationResponse responseFromActionResponse = CreateConversationResponse.fromActionResponse(actionResponse);
assertNotSame(response, responseFromActionResponse);
assertEquals(response.getId(), responseFromActionResponse.getId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionResponseIOException() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
CreateConversationResponse.fromActionResponse(actionResponse);
}
}
Loading