Skip to content

Commit 7eee9f6

Browse files
Enhance Message and Memory API Validation and storage (opensearch-project#3283) (opensearch-project#3295)
* Enchance Message and Memory API Validation and storage Throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Skip saving empty fields in interactions and conversations to optimize storage usage. Modify GET requests for interactions and conversations to return only non-null fields. Throw an exception if all fields in a create interaction call are empty or null. Add unit tests to cover the above cases. Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Update unit test to check for null instead of empty map Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Refactored userstr to Camel Case Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Addressing comments Used assertThrows and added promptTemplate with empty string in test_ToXContent to ensure well rounded testing of expected functionality Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Undo: throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> --------- Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> (cherry picked from commit 06d39b9) Co-authored-by: Rithin Pullela <rithinp@amazon.com>
1 parent 13985ef commit 7eee9f6

File tree

13 files changed

+160
-93
lines changed

13 files changed

+160
-93
lines changed

common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java

+12-4
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
184184
builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId);
185185
builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id);
186186
builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime);
187-
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
188-
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
189-
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
190-
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
187+
if (input != null && !input.trim().isEmpty()) {
188+
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
189+
}
190+
if (promptTemplate != null && !promptTemplate.trim().isEmpty()) {
191+
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
192+
}
193+
if (response != null && !response.trim().isEmpty()) {
194+
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
195+
}
196+
if (origin != null && !origin.trim().isEmpty()) {
197+
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
198+
}
191199
if (additionalInfo != null) {
192200
builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
193201
}

common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,17 @@ public void test_ToXContent() throws IOException {
122122
.builder()
123123
.conversationId("conversation id")
124124
.origin("amazon bedrock")
125+
.promptTemplate(" ")
125126
.parentInteractionId("parant id")
126127
.additionalInfo(Collections.singletonMap("suggestion", "new suggestion"))
128+
.response("sample response")
127129
.traceNum(1)
128130
.build();
129131
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
130132
interaction.toXContent(builder, EMPTY_PARAMS);
131133
String interactionContent = TestHelper.xContentBuilderToString(builder);
132134
assertEquals(
133-
"{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}",
135+
"{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}",
134136
interactionContent
135137
);
136138
}

memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java

+21-5
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,28 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
137137
}
138138
try (XContentParser parser = restRequest.contentParser()) {
139139
Map<String, Object> body = parser.map();
140+
String name = null;
141+
String applicationType = null;
142+
Map<String, String> additionalInfo = null;
143+
144+
for (String key : body.keySet()) {
145+
switch (key) {
146+
case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD:
147+
name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD);
148+
break;
149+
case APPLICATION_TYPE_FIELD:
150+
applicationType = (String) body.get(APPLICATION_TYPE_FIELD);
151+
break;
152+
case META_ADDITIONAL_INFO_FIELD:
153+
additionalInfo = (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD);
154+
break;
155+
default:
156+
parser.skipChildren();
157+
break;
158+
}
159+
}
140160
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
141-
return new CreateConversationRequest(
142-
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
143-
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
144-
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
145-
);
161+
return new CreateConversationRequest(name, applicationType, additionalInfo);
146162
} else {
147163
return new CreateConversationRequest();
148164
}

memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java

+10
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,16 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro
171171
}
172172
}
173173

174+
boolean allFieldsEmpty = (input == null || input.trim().isEmpty())
175+
&& (prompt == null || prompt.trim().isEmpty())
176+
&& (response == null || response.trim().isEmpty())
177+
&& (origin == null || origin.trim().isEmpty())
178+
&& (addinf == null || addinf.isEmpty());
179+
if (allFieldsEmpty) {
180+
throw new IllegalArgumentException(
181+
"At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info"
182+
);
183+
}
174184
return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum);
175185
}
176186

memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java

+34-33
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import java.io.IOException;
2424
import java.time.Instant;
25+
import java.util.HashMap;
2526
import java.util.LinkedList;
2627
import java.util.List;
2728
import java.util.Map;
@@ -139,24 +140,24 @@ public void createConversation(
139140
) {
140141
initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> {
141142
if (indexExists) {
142-
String userstr = getUserStrFromThreadContext();
143+
String userStr = getUserStrFromThreadContext();
143144
Instant now = Instant.now();
144-
IndexRequest request = Requests
145-
.indexRequest(META_INDEX_NAME)
146-
.source(
147-
ConversationalIndexConstants.META_CREATED_TIME_FIELD,
148-
now,
149-
ConversationalIndexConstants.META_UPDATED_TIME_FIELD,
150-
now,
151-
ConversationalIndexConstants.META_NAME_FIELD,
152-
name,
153-
ConversationalIndexConstants.USER_FIELD,
154-
userstr == null ? null : User.parse(userstr).getName(),
155-
ConversationalIndexConstants.APPLICATION_TYPE_FIELD,
156-
applicationType,
157-
ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD,
158-
additionalInfos == null ? Map.of() : additionalInfos
159-
);
145+
Map<String, Object> sourceMap = new HashMap<>();
146+
sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now);
147+
sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now);
148+
if (name != null && !name.trim().isEmpty()) {
149+
sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name);
150+
}
151+
if (userStr != null && !userStr.trim().isEmpty()) {
152+
sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName());
153+
}
154+
if (applicationType != null && !applicationType.trim().isEmpty()) {
155+
sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType);
156+
}
157+
if (additionalInfos != null && !additionalInfos.isEmpty()) {
158+
sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos);
159+
}
160+
IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap);
160161
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
161162
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
162163
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
@@ -210,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener<List<Conve
210211
return;
211212
}
212213
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
213-
String userstr = getUserStrFromThreadContext();
214+
String userStr = getUserStrFromThreadContext();
214215
QueryBuilder queryBuilder;
215-
if (userstr == null)
216+
if (userStr == null)
216217
queryBuilder = new MatchAllQueryBuilder();
217218
else
218-
queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName());
219+
queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName());
219220
request.source().query(queryBuilder);
220221
request.source().from(from).size(maxResults);
221222
request.source().sort(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, SortOrder.DESC);
@@ -264,8 +265,8 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
264265
return;
265266
}
266267
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
267-
String userstr = getUserStrFromThreadContext();
268-
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
268+
String userStr = getUserStrFromThreadContext();
269+
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
269270
this.checkAccess(conversationId, ActionListener.wrap(access -> {
270271
if (access) {
271272
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
@@ -308,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
308309
listener.onResponse(true);
309310
return;
310311
}
311-
String userstr = getUserStrFromThreadContext();
312+
String userStr = getUserStrFromThreadContext();
312313
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
313314
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
314315
GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId);
@@ -318,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
318319
throw new ResourceNotFoundException("Memory [" + conversationId + "] not found");
319320
}
320321
// If security is off - User doesn't exist - you have permission
321-
if (userstr == null || User.parse(userstr) == null) {
322+
if (userStr == null || User.parse(userStr) == null) {
322323
internalListener.onResponse(true);
323324
return;
324325
}
325326
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
326-
String user = User.parse(userstr).getName();
327+
String user = User.parse(userStr).getName();
327328
// If you're not the owner of this conversation, you do not have permission
328329
if (!user.equals(conversation.getUser())) {
329330
internalListener.onResponse(false);
@@ -353,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
353354
QueryBuilder originalQuery = request.source().query();
354355
BoolQueryBuilder newQuery = new BoolQueryBuilder();
355356
newQuery.must(originalQuery);
356-
String userstr = getUserStrFromThreadContext();
357-
if (userstr != null) {
358-
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
357+
String userStr = getUserStrFromThreadContext();
358+
if (userStr != null) {
359+
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
359360
newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user));
360361
}
361362
request.source().query(newQuery);
@@ -388,11 +389,11 @@ public void updateConversation(String conversationId, UpdateRequest updateReques
388389
if (access) {
389390
innerUpdateConversation(updateRequest, listener);
390391
} else {
391-
String userstr = client
392+
String userStr = client
392393
.threadPool()
393394
.getThreadContext()
394395
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
395-
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
396+
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
396397
throw new OpenSearchStatusException(
397398
"User [" + user + "] does not have access to memory " + conversationId,
398399
RestStatus.UNAUTHORIZED
@@ -421,7 +422,7 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
421422
listener.onFailure(new IndexNotFoundException("cannot get memory since the memory index does not exist", META_INDEX_NAME));
422423
return;
423424
}
424-
String userstr = getUserStrFromThreadContext();
425+
String userStr = getUserStrFromThreadContext();
425426
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
426427
ActionListener<ConversationMeta> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
427428
GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId);
@@ -432,12 +433,12 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
432433
}
433434
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
434435
// If no security, return conversation
435-
if (userstr == null || User.parse(userstr) == null) {
436+
if (userStr == null || User.parse(userStr) == null) {
436437
internalListener.onResponse(conversation);
437438
return;
438439
}
439440
// If security and correct user, return conversation
440-
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
441+
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
441442
if (user.equals(conversation.getUser())) {
442443
internalListener.onResponse(conversation);
443444
log.info("Successfully get the memory for {}", conversationId);

0 commit comments

Comments
 (0)