Skip to content

Commit 6ebc15c

Browse files
committed
add more user based permission check in Memory
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent d0895bb commit 6ebc15c

8 files changed

+290
-57
lines changed

memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java

+7
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ public void createInteraction(
253253
*/
254254
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
255255

256+
/**
257+
* Update an interaction
258+
* @param updateContent update content for the conversations index
259+
* @param listener receives the update response
260+
*/
261+
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);
262+
256263
/**
257264
* Get a single ConversationMeta object
258265
* @param conversationId id of the conversation to get

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java

+40-16
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@
88
import java.time.Instant;
99
import java.util.Map;
1010

11+
import org.opensearch.OpenSearchException;
1112
import org.opensearch.action.ActionRequest;
1213
import org.opensearch.action.DocWriteResponse;
1314
import org.opensearch.action.support.ActionFilters;
1415
import org.opensearch.action.support.HandledTransportAction;
15-
import org.opensearch.action.support.WriteRequest;
16-
import org.opensearch.action.update.UpdateRequest;
1716
import org.opensearch.action.update.UpdateResponse;
1817
import org.opensearch.client.Client;
18+
import org.opensearch.cluster.service.ClusterService;
1919
import org.opensearch.common.inject.Inject;
2020
import org.opensearch.common.util.concurrent.ThreadContext;
2121
import org.opensearch.core.action.ActionListener;
2222
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
23+
import org.opensearch.ml.memory.ConversationalMemoryHandler;
24+
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
2325
import org.opensearch.tasks.Task;
2426
import org.opensearch.transport.TransportService;
2527

@@ -28,29 +30,51 @@
2830
@Log4j2
2931
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
3032
Client client;
33+
private ConversationalMemoryHandler cmHandler;
34+
35+
private volatile boolean featureIsEnabled;
3136

3237
@Inject
33-
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
38+
public UpdateConversationTransportAction(
39+
TransportService transportService,
40+
ActionFilters actionFilters,
41+
Client client,
42+
OpenSearchConversationalMemoryHandler cmHandler,
43+
ClusterService clusterService
44+
) {
3445
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
3546
this.client = client;
47+
this.cmHandler = cmHandler;
48+
System.out.println(clusterService.getSettings());
49+
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
50+
clusterService
51+
.getClusterSettings()
52+
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
3653
}
3754

3855
@Override
3956
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
40-
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
41-
String conversationId = updateConversationRequest.getConversationId();
42-
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
43-
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
44-
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
57+
if (!featureIsEnabled) {
58+
listener
59+
.onFailure(
60+
new OpenSearchException(
61+
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
62+
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
63+
)
64+
);
65+
return;
66+
} else {
67+
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
68+
String conversationId = updateConversationRequest.getConversationId();
69+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
70+
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
71+
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
4572

46-
updateRequest.doc(updateContent);
47-
updateRequest.docAsUpsert(true);
48-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
49-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
50-
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
51-
} catch (Exception e) {
52-
log.error("Failed to update Conversation for conversation id" + conversationId, e);
53-
listener.onFailure(e);
73+
cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context));
74+
} catch (Exception e) {
75+
log.error("Failed to update Conversation " + conversationId, e);
76+
listener.onFailure(e);
77+
}
5478
}
5579
}
5680

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java

+35-10
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55

66
package org.opensearch.ml.memory.action.conversation;
77

8+
import java.util.Map;
9+
10+
import org.opensearch.OpenSearchException;
811
import org.opensearch.action.ActionRequest;
912
import org.opensearch.action.DocWriteResponse;
1013
import org.opensearch.action.support.ActionFilters;
1114
import org.opensearch.action.support.HandledTransportAction;
12-
import org.opensearch.action.support.WriteRequest;
13-
import org.opensearch.action.update.UpdateRequest;
1415
import org.opensearch.action.update.UpdateResponse;
1516
import org.opensearch.client.Client;
17+
import org.opensearch.cluster.service.ClusterService;
1618
import org.opensearch.common.inject.Inject;
1719
import org.opensearch.common.util.concurrent.ThreadContext;
1820
import org.opensearch.core.action.ActionListener;
1921
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
22+
import org.opensearch.ml.memory.ConversationalMemoryHandler;
23+
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
2024
import org.opensearch.tasks.Task;
2125
import org.opensearch.transport.TransportService;
2226

@@ -25,26 +29,47 @@
2529
@Log4j2
2630
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
2731
Client client;
32+
private ConversationalMemoryHandler cmHandler;
33+
34+
private volatile boolean featureIsEnabled;
2835

2936
@Inject
30-
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
37+
public UpdateInteractionTransportAction(
38+
TransportService transportService,
39+
ActionFilters actionFilters,
40+
Client client,
41+
OpenSearchConversationalMemoryHandler cmHandler,
42+
ClusterService clusterService
43+
) {
3144
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
3245
this.client = client;
46+
this.cmHandler = cmHandler;
47+
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
48+
clusterService
49+
.getClusterSettings()
50+
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
3351
}
3452

3553
@Override
3654
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
55+
if (!featureIsEnabled) {
56+
listener
57+
.onFailure(
58+
new OpenSearchException(
59+
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
60+
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
61+
)
62+
);
63+
return;
64+
}
3765
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
3866
String interactionId = updateInteractionRequest.getInteractionId();
39-
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
40-
updateRequest.doc(updateInteractionRequest.getUpdateContent());
41-
updateRequest.docAsUpsert(true);
42-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
67+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
68+
Map<String, Object> updateContent = updateInteractionRequest.getUpdateContent();
4369

44-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
45-
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
70+
cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context));
4671
} catch (Exception e) {
47-
log.error("Failed to update Interaction for interaction id " + interactionId, e);
72+
log.error("Failed to update Interaction " + interactionId, e);
4873
listener.onFailure(e);
4974
}
5075
}

memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java

+123-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import org.opensearch.action.index.IndexResponse;
4141
import org.opensearch.action.search.SearchRequest;
4242
import org.opensearch.action.search.SearchResponse;
43+
import org.opensearch.action.update.UpdateRequest;
44+
import org.opensearch.action.update.UpdateResponse;
4345
import org.opensearch.client.Client;
4446
import org.opensearch.client.Requests;
4547
import org.opensearch.cluster.service.ClusterService;
@@ -330,6 +332,47 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList
330332
listener.onResponse(List.of());
331333
return;
332334
}
335+
336+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
337+
ActionListener<List<Interaction>> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
338+
GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId);
339+
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
340+
// If the interaction doesn't exist, fail
341+
if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) {
342+
throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found");
343+
}
344+
Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap());
345+
// checks if the user has permission to access the conversation that the interaction belongs to
346+
String conversationId = interaction.getConversationId();
347+
ActionListener<Boolean> accessListener = ActionListener.wrap(access -> {
348+
if (access) {
349+
innerGetTraces(interactionId, from, maxResults, listener);
350+
} else {
351+
String userstr = client
352+
.threadPool()
353+
.getThreadContext()
354+
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
355+
String user = User.parse(userstr) == null
356+
? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS
357+
: User.parse(userstr).getName();
358+
throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId);
359+
}
360+
}, e -> { listener.onFailure(e); });
361+
conversationMetaIndex.checkAccess(conversationId, accessListener);
362+
}, e -> { internalListener.onFailure(e); });
363+
client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
364+
client.get(request, al);
365+
}, e -> {
366+
log.error("Failed to refresh interactions index during get interaction ", e);
367+
internalListener.onFailure(e);
368+
}));
369+
} catch (Exception e) {
370+
listener.onFailure(e);
371+
}
372+
}
373+
374+
@VisibleForTesting
375+
void innerGetTraces(String interactionId, int from, int maxResults, ActionListener<List<Interaction>> listener) {
333376
SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME);
334377
// Build the query
335378
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
@@ -509,12 +552,13 @@ public void getInteraction(String interactionId, ActionListener<Interaction> lis
509552
ActionListener<Interaction> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
510553
GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId);
511554
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
512-
// If the conversation doesn't exist, fail
555+
// If the interaction doesn't exist, fail
513556
if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) {
514557
throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found");
515558
}
516559
Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap());
517-
internalListener.onResponse(interaction);
560+
// checks if the user has permission to access the conversation that the interaction belongs to
561+
checkInteractionPermission(interactionId, interaction, internalListener);
518562
}, e -> { internalListener.onFailure(e); });
519563
client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
520564
client.get(request, al);
@@ -526,4 +570,81 @@ public void getInteraction(String interactionId, ActionListener<Interaction> lis
526570
listener.onFailure(e);
527571
}
528572
}
573+
574+
public void updateInteraction(String interactionId, UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
575+
if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) {
576+
listener
577+
.onFailure(
578+
new IndexNotFoundException(
579+
"cannot update interaction since the interaction index does not exist",
580+
INTERACTIONS_INDEX_NAME
581+
)
582+
);
583+
return;
584+
}
585+
586+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
587+
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
588+
GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId);
589+
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
590+
// If the interaction doesn't exist, fail
591+
if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) {
592+
throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found");
593+
}
594+
Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap());
595+
// checks if the user has permission to access the conversation that the interaction belongs to
596+
String conversationId = interaction.getConversationId();
597+
ActionListener<Boolean> accessListener = ActionListener.wrap(access -> {
598+
if (access) {
599+
innerUpdateInteraction(updateRequest, internalListener);
600+
} else {
601+
String userstr = client
602+
.threadPool()
603+
.getThreadContext()
604+
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
605+
String user = User.parse(userstr) == null
606+
? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS
607+
: User.parse(userstr).getName();
608+
throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId);
609+
}
610+
}, e -> { listener.onFailure(e); });
611+
conversationMetaIndex.checkAccess(conversationId, accessListener);
612+
}, e -> { internalListener.onFailure(e); });
613+
client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
614+
client.get(request, al);
615+
}, e -> {
616+
log.error("Failed to refresh interactions index during get interaction ", e);
617+
internalListener.onFailure(e);
618+
}));
619+
} catch (Exception e) {
620+
listener.onFailure(e);
621+
}
622+
}
623+
624+
private void innerUpdateInteraction(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
625+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
626+
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
627+
client.update(updateRequest, internalListener);
628+
} catch (Exception e) {
629+
log.error("Failed to update Conversation. Details {}:", e);
630+
listener.onFailure(e);
631+
}
632+
}
633+
634+
private void checkInteractionPermission(String interactionId, Interaction interaction, ActionListener<Interaction> internalListener) {
635+
String conversationId = interaction.getConversationId();
636+
ActionListener<Boolean> accessListener = ActionListener.wrap(access -> {
637+
if (access) {
638+
internalListener.onResponse(interaction);
639+
} else {
640+
String userstr = client
641+
.threadPool()
642+
.getThreadContext()
643+
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
644+
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
645+
throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId);
646+
}
647+
}, e -> { internalListener.onFailure(e); });
648+
conversationMetaIndex.checkAccess(conversationId, accessListener);
649+
}
529650
}

memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java

+10
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@ public void updateConversation(String conversationId, Map<String, Object> update
399399
conversationMetaIndex.updateConversation(updateRequest, listener);
400400
}
401401

402+
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener) {
403+
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
404+
405+
updateRequest.doc(updateContent);
406+
updateRequest.docAsUpsert(true);
407+
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
408+
409+
interactionsIndex.updateInteraction(interactionId, updateRequest, listener);
410+
}
411+
402412
/**
403413
* Get a single ConversationMeta object
404414
* @param conversationId id of the conversation to get

0 commit comments

Comments
 (0)