|
8 | 8 | import java.time.Instant;
|
9 | 9 | import java.util.Map;
|
10 | 10 |
|
| 11 | +import org.opensearch.OpenSearchException; |
11 | 12 | import org.opensearch.action.ActionRequest;
|
12 | 13 | import org.opensearch.action.DocWriteResponse;
|
13 | 14 | import org.opensearch.action.support.ActionFilters;
|
14 | 15 | import org.opensearch.action.support.HandledTransportAction;
|
15 |
| -import org.opensearch.action.support.WriteRequest; |
16 |
| -import org.opensearch.action.update.UpdateRequest; |
17 | 16 | import org.opensearch.action.update.UpdateResponse;
|
18 | 17 | import org.opensearch.client.Client;
|
| 18 | +import org.opensearch.cluster.service.ClusterService; |
19 | 19 | import org.opensearch.common.inject.Inject;
|
20 | 20 | import org.opensearch.common.util.concurrent.ThreadContext;
|
21 | 21 | import org.opensearch.core.action.ActionListener;
|
22 | 22 | import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
|
| 23 | +import org.opensearch.ml.memory.ConversationalMemoryHandler; |
| 24 | +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; |
23 | 25 | import org.opensearch.tasks.Task;
|
24 | 26 | import org.opensearch.transport.TransportService;
|
25 | 27 |
|
|
28 | 30 | @Log4j2
|
29 | 31 | public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
|
30 | 32 | Client client;
|
| 33 | + private ConversationalMemoryHandler cmHandler; |
| 34 | + |
| 35 | + private volatile boolean featureIsEnabled; |
31 | 36 |
|
32 | 37 | @Inject
|
33 |
| - public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { |
| 38 | + public UpdateConversationTransportAction( |
| 39 | + TransportService transportService, |
| 40 | + ActionFilters actionFilters, |
| 41 | + Client client, |
| 42 | + OpenSearchConversationalMemoryHandler cmHandler, |
| 43 | + ClusterService clusterService |
| 44 | + ) { |
34 | 45 | super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
|
35 | 46 | this.client = client;
|
| 47 | + this.cmHandler = cmHandler; |
| 48 | + System.out.println(clusterService.getSettings()); |
| 49 | + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); |
| 50 | + clusterService |
| 51 | + .getClusterSettings() |
| 52 | + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); |
36 | 53 | }
|
37 | 54 |
|
38 | 55 | @Override
|
39 | 56 | protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
|
40 |
| - UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); |
41 |
| - String conversationId = updateConversationRequest.getConversationId(); |
42 |
| - UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); |
43 |
| - Map<String, Object> updateContent = updateConversationRequest.getUpdateContent(); |
44 |
| - updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); |
| 57 | + if (!featureIsEnabled) { |
| 58 | + listener |
| 59 | + .onFailure( |
| 60 | + new OpenSearchException( |
| 61 | + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " |
| 62 | + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() |
| 63 | + ) |
| 64 | + ); |
| 65 | + return; |
| 66 | + } else { |
| 67 | + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); |
| 68 | + String conversationId = updateConversationRequest.getConversationId(); |
| 69 | + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { |
| 70 | + Map<String, Object> updateContent = updateConversationRequest.getUpdateContent(); |
| 71 | + updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); |
45 | 72 |
|
46 |
| - updateRequest.doc(updateContent); |
47 |
| - updateRequest.docAsUpsert(true); |
48 |
| - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); |
49 |
| - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { |
50 |
| - client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); |
51 |
| - } catch (Exception e) { |
52 |
| - log.error("Failed to update Conversation for conversation id" + conversationId, e); |
53 |
| - listener.onFailure(e); |
| 73 | + cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context)); |
| 74 | + } catch (Exception e) { |
| 75 | + log.error("Failed to update Conversation " + conversationId, e); |
| 76 | + listener.onFailure(e); |
| 77 | + } |
54 | 78 | }
|
55 | 79 | }
|
56 | 80 |
|
|
0 commit comments