Skip to content

Commit c62ac95

Browse files
committed
add more user based permission check in Memory
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent d0895bb commit c62ac95

13 files changed

+318
-65
lines changed

memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java

+7
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ public void createInteraction(
253253
*/
254254
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
255255

256+
/**
257+
* Update an interaction
258+
* @param updateContent update content for the conversations index
259+
* @param listener receives the update response
260+
*/
261+
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
262+
256263
/**
257264
* Get a single ConversationMeta object
258265
* @param conversationId id of the conversation to get

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.action.ValidateActions.addValidationError;
99
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD;
10+
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.USER_FIELD;
1011

1112
import java.io.ByteArrayInputStream;
1213
import java.io.ByteArrayOutputStream;
@@ -35,7 +36,7 @@ public class UpdateConversationRequest extends ActionRequest {
3536
private String conversationId;
3637
private Map<String, Object> updateContent;
3738

38-
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD));
39+
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD, USER_FIELD));
3940

4041
@Builder
4142
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) {

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java

+40-16
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@
88
import java.time.Instant;
99
import java.util.Map;
1010

11+
import org.opensearch.OpenSearchException;
1112
import org.opensearch.action.ActionRequest;
1213
import org.opensearch.action.DocWriteResponse;
1314
import org.opensearch.action.support.ActionFilters;
1415
import org.opensearch.action.support.HandledTransportAction;
15-
import org.opensearch.action.support.WriteRequest;
16-
import org.opensearch.action.update.UpdateRequest;
1716
import org.opensearch.action.update.UpdateResponse;
1817
import org.opensearch.client.Client;
18+
import org.opensearch.cluster.service.ClusterService;
1919
import org.opensearch.common.inject.Inject;
2020
import org.opensearch.common.util.concurrent.ThreadContext;
2121
import org.opensearch.core.action.ActionListener;
2222
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
23+
import org.opensearch.ml.memory.ConversationalMemoryHandler;
24+
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
2325
import org.opensearch.tasks.Task;
2426
import org.opensearch.transport.TransportService;
2527

@@ -28,29 +30,51 @@
2830
@Log4j2
2931
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
3032
Client client;
33+
private ConversationalMemoryHandler cmHandler;
34+
35+
private volatile boolean featureIsEnabled;
3136

3237
@Inject
33-
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
38+
public UpdateConversationTransportAction(
39+
TransportService transportService,
40+
ActionFilters actionFilters,
41+
Client client,
42+
OpenSearchConversationalMemoryHandler cmHandler,
43+
ClusterService clusterService
44+
) {
3445
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
3546
this.client = client;
47+
this.cmHandler = cmHandler;
48+
System.out.println(clusterService.getSettings());
49+
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
50+
clusterService
51+
.getClusterSettings()
52+
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
3653
}
3754

3855
@Override
3956
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
40-
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
41-
String conversationId = updateConversationRequest.getConversationId();
42-
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
43-
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
44-
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
57+
if (!featureIsEnabled) {
58+
listener
59+
.onFailure(
60+
new OpenSearchException(
61+
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
62+
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
63+
)
64+
);
65+
return;
66+
} else {
67+
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
68+
String conversationId = updateConversationRequest.getConversationId();
69+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
70+
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
71+
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
4572

46-
updateRequest.doc(updateContent);
47-
updateRequest.docAsUpsert(true);
48-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
49-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
50-
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
51-
} catch (Exception e) {
52-
log.error("Failed to update Conversation for conversation id" + conversationId, e);
53-
listener.onFailure(e);
73+
cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context));
74+
} catch (Exception e) {
75+
log.error("Failed to update Conversation " + conversationId, e);
76+
listener.onFailure(e);
77+
}
5478
}
5579
}
5680

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java

+35-10
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55

66
package org.opensearch.ml.memory.action.conversation;
77

8+
import java.util.Map;
9+
10+
import org.opensearch.OpenSearchException;
811
import org.opensearch.action.ActionRequest;
912
import org.opensearch.action.DocWriteResponse;
1013
import org.opensearch.action.support.ActionFilters;
1114
import org.opensearch.action.support.HandledTransportAction;
12-
import org.opensearch.action.support.WriteRequest;
13-
import org.opensearch.action.update.UpdateRequest;
1415
import org.opensearch.action.update.UpdateResponse;
1516
import org.opensearch.client.Client;
17+
import org.opensearch.cluster.service.ClusterService;
1618
import org.opensearch.common.inject.Inject;
1719
import org.opensearch.common.util.concurrent.ThreadContext;
1820
import org.opensearch.core.action.ActionListener;
1921
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
22+
import org.opensearch.ml.memory.ConversationalMemoryHandler;
23+
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
2024
import org.opensearch.tasks.Task;
2125
import org.opensearch.transport.TransportService;
2226

@@ -25,26 +29,47 @@
2529
@Log4j2
2630
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
2731
Client client;
32+
private ConversationalMemoryHandler cmHandler;
33+
34+
private volatile boolean featureIsEnabled;
2835

2936
@Inject
30-
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
37+
public UpdateInteractionTransportAction(
38+
TransportService transportService,
39+
ActionFilters actionFilters,
40+
Client client,
41+
OpenSearchConversationalMemoryHandler cmHandler,
42+
ClusterService clusterService
43+
) {
3144
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
3245
this.client = client;
46+
this.cmHandler = cmHandler;
47+
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
48+
clusterService
49+
.getClusterSettings()
50+
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
3351
}
3452

3553
@Override
3654
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
55+
if (!featureIsEnabled) {
56+
listener
57+
.onFailure(
58+
new OpenSearchException(
59+
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
60+
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
61+
)
62+
);
63+
return;
64+
}
3765
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
3866
String interactionId = updateInteractionRequest.getInteractionId();
39-
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
40-
updateRequest.doc(updateInteractionRequest.getUpdateContent());
41-
updateRequest.docAsUpsert(true);
42-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
67+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
68+
Map<String, Object> updateContent = updateInteractionRequest.getUpdateContent();
4369

44-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
45-
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
70+
cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context));
4671
} catch (Exception e) {
47-
log.error("Failed to update Interaction for interaction id " + interactionId, e);
72+
log.error("Failed to update Interaction " + interactionId, e);
4873
listener.onFailure(e);
4974
}
5075
}

memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java

+19-1
Original file line numberDiff line numberDiff line change
@@ -364,14 +364,32 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
364364
* @param updateRequest original update request
365365
* @param listener receives the update response for the wrapped query
366366
*/
367-
public void updateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
367+
public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
368368
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
369369
listener
370370
.onFailure(
371371
new IndexNotFoundException("cannot update conversation since the conversation index does not exist", META_INDEX_NAME)
372372
);
373373
return;
374374
}
375+
376+
this.checkAccess(conversationId, ActionListener.wrap(access -> {
377+
if (access) {
378+
innerUpdateConversation(updateRequest, listener);
379+
} else {
380+
String userstr = client
381+
.threadPool()
382+
.getThreadContext()
383+
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
384+
String user = User.parse(userstr) == null
385+
? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS
386+
: User.parse(userstr).getName();
387+
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
388+
}
389+
}, e -> { listener.onFailure(e);}));
390+
}
391+
392+
private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
375393
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
376394
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
377395
client.update(updateRequest, internalListener);

0 commit comments

Comments
 (0)