Skip to content

Commit 7a7f001

Browse files
committed
add more user based permission check in Memory
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent d0895bb commit 7a7f001

15 files changed

+327
-65
lines changed

memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java

+7
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ public void createInteraction(
253253
*/
254254
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
255255

256+
/**
257+
* Update an interaction
258+
* @param updateContent update content for the conversations index
259+
* @param listener receives the update response
260+
*/
261+
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
262+
256263
/**
257264
* Get a single ConversationMeta object
258265
* @param conversationId id of the conversation to get

memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java

+2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
9494
Map<String, String> additionalInfo = request.getAdditionalInfo();
9595
String parintIid = request.getParentIid();
9696
Integer traceNumber = request.getTraceNumber();
97+
log.info("parintIid is : " + parintIid);
98+
log.info("traceNumber is : " + traceNumber);
9799
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
98100
ActionListener<CreateInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
99101
ActionListener<String> al = ActionListener.wrap(iid -> {

memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ public void doExecute(Task task, GetInteractionRequest request, ActionListener<G
8080
return;
8181
}
8282
String interactionId = request.getInteractionId();
83+
log.info("interactionId is: " + interactionId);
8384
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
8485
ActionListener<GetInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
8586
ActionListener<Interaction> al = ActionListener.wrap(interaction -> {

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationRequest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.action.ValidateActions.addValidationError;
99
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD;
10+
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.USER_FIELD;
1011

1112
import java.io.ByteArrayInputStream;
1213
import java.io.ByteArrayOutputStream;
@@ -35,7 +36,7 @@ public class UpdateConversationRequest extends ActionRequest {
3536
private String conversationId;
3637
private Map<String, Object> updateContent;
3738

38-
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD));
39+
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD, USER_FIELD));
3940

4041
@Builder
4142
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) {

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java

+40-16
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@
88
import java.time.Instant;
99
import java.util.Map;
1010

11+
import org.opensearch.OpenSearchException;
1112
import org.opensearch.action.ActionRequest;
1213
import org.opensearch.action.DocWriteResponse;
1314
import org.opensearch.action.support.ActionFilters;
1415
import org.opensearch.action.support.HandledTransportAction;
15-
import org.opensearch.action.support.WriteRequest;
16-
import org.opensearch.action.update.UpdateRequest;
1716
import org.opensearch.action.update.UpdateResponse;
1817
import org.opensearch.client.Client;
18+
import org.opensearch.cluster.service.ClusterService;
1919
import org.opensearch.common.inject.Inject;
2020
import org.opensearch.common.util.concurrent.ThreadContext;
2121
import org.opensearch.core.action.ActionListener;
2222
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
23+
import org.opensearch.ml.memory.ConversationalMemoryHandler;
24+
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
2325
import org.opensearch.tasks.Task;
2426
import org.opensearch.transport.TransportService;
2527

@@ -28,29 +30,51 @@
2830
@Log4j2
2931
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
3032
Client client;
33+
private ConversationalMemoryHandler cmHandler;
34+
35+
private volatile boolean featureIsEnabled;
3136

3237
@Inject
33-
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
38+
public UpdateConversationTransportAction(
39+
TransportService transportService,
40+
ActionFilters actionFilters,
41+
Client client,
42+
OpenSearchConversationalMemoryHandler cmHandler,
43+
ClusterService clusterService
44+
) {
3445
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
3546
this.client = client;
47+
this.cmHandler = cmHandler;
48+
System.out.println(clusterService.getSettings());
49+
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
50+
clusterService
51+
.getClusterSettings()
52+
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
3653
}
3754

3855
@Override
3956
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
40-
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
41-
String conversationId = updateConversationRequest.getConversationId();
42-
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
43-
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
44-
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
57+
if (!featureIsEnabled) {
58+
listener
59+
.onFailure(
60+
new OpenSearchException(
61+
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
62+
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
63+
)
64+
);
65+
return;
66+
} else {
67+
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
68+
String conversationId = updateConversationRequest.getConversationId();
69+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
70+
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
71+
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
4572

46-
updateRequest.doc(updateContent);
47-
updateRequest.docAsUpsert(true);
48-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
49-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
50-
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
51-
} catch (Exception e) {
52-
log.error("Failed to update Conversation for conversation id" + conversationId, e);
53-
listener.onFailure(e);
73+
cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context));
74+
} catch (Exception e) {
75+
log.error("Failed to update Conversation " + conversationId, e);
76+
listener.onFailure(e);
77+
}
5478
}
5579
}
5680

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java

+35-10
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55

66
package org.opensearch.ml.memory.action.conversation;
77

8+
import java.util.Map;
9+
10+
import org.opensearch.OpenSearchException;
811
import org.opensearch.action.ActionRequest;
912
import org.opensearch.action.DocWriteResponse;
1013
import org.opensearch.action.support.ActionFilters;
1114
import org.opensearch.action.support.HandledTransportAction;
12-
import org.opensearch.action.support.WriteRequest;
13-
import org.opensearch.action.update.UpdateRequest;
1415
import org.opensearch.action.update.UpdateResponse;
1516
import org.opensearch.client.Client;
17+
import org.opensearch.cluster.service.ClusterService;
1618
import org.opensearch.common.inject.Inject;
1719
import org.opensearch.common.util.concurrent.ThreadContext;
1820
import org.opensearch.core.action.ActionListener;
1921
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
22+
import org.opensearch.ml.memory.ConversationalMemoryHandler;
23+
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
2024
import org.opensearch.tasks.Task;
2125
import org.opensearch.transport.TransportService;
2226

@@ -25,26 +29,47 @@
2529
@Log4j2
2630
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
2731
Client client;
32+
private ConversationalMemoryHandler cmHandler;
33+
34+
private volatile boolean featureIsEnabled;
2835

2936
@Inject
30-
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
37+
public UpdateInteractionTransportAction(
38+
TransportService transportService,
39+
ActionFilters actionFilters,
40+
Client client,
41+
OpenSearchConversationalMemoryHandler cmHandler,
42+
ClusterService clusterService
43+
) {
3144
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
3245
this.client = client;
46+
this.cmHandler = cmHandler;
47+
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
48+
clusterService
49+
.getClusterSettings()
50+
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
3351
}
3452

3553
@Override
3654
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
55+
if (!featureIsEnabled) {
56+
listener
57+
.onFailure(
58+
new OpenSearchException(
59+
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
60+
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
61+
)
62+
);
63+
return;
64+
}
3765
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
3866
String interactionId = updateInteractionRequest.getInteractionId();
39-
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
40-
updateRequest.doc(updateInteractionRequest.getUpdateContent());
41-
updateRequest.docAsUpsert(true);
42-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
67+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
68+
Map<String, Object> updateContent = updateInteractionRequest.getUpdateContent();
4369

44-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
45-
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
70+
cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context));
4671
} catch (Exception e) {
47-
log.error("Failed to update Interaction for interaction id " + interactionId, e);
72+
log.error("Failed to update Interaction " + interactionId, e);
4873
listener.onFailure(e);
4974
}
5075
}

memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
296296
return;
297297
}
298298
String userstr = getUserStrFromThreadContext();
299+
log.info("user name is :" + User.parse(userstr).getName());
299300
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
300301
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
301302
GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId);
@@ -312,6 +313,7 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
312313
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
313314
String user = User.parse(userstr).getName();
314315
// If you're not the owner of this conversation, you do not have permission
316+
log.info("conversation user is :" + conversation.getUser());
315317
if (!user.equals(conversation.getUser())) {
316318
internalListener.onResponse(false);
317319
return;
@@ -364,14 +366,32 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
364366
* @param updateRequest original update request
365367
* @param listener receives the update response for the wrapped query
366368
*/
367-
public void updateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
369+
public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
368370
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
369371
listener
370372
.onFailure(
371373
new IndexNotFoundException("cannot update conversation since the conversation index does not exist", META_INDEX_NAME)
372374
);
373375
return;
374376
}
377+
378+
this.checkAccess(conversationId, ActionListener.wrap(access -> {
379+
if (access) {
380+
innerUpdateConversation(updateRequest, listener);
381+
} else {
382+
String userstr = client
383+
.threadPool()
384+
.getThreadContext()
385+
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
386+
String user = User.parse(userstr) == null
387+
? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS
388+
: User.parse(userstr).getName();
389+
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
390+
}
391+
}, e -> { listener.onFailure(e);}));
392+
}
393+
394+
private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
375395
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
376396
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
377397
client.update(updateRequest, internalListener);

0 commit comments

Comments
 (0)