Skip to content

Commit ce78f8f

Browse files
committed
fixing create index step and array input for processors
Signed-off-by: Amit Galitzky <amgalitz@amazon.com>
1 parent 149e22a commit ce78f8f

10 files changed

+196
-19
lines changed

src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ public enum DefaultUseCases {
9393
"substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json"
9494
),
9595
/** defaults file and substitution ready template for hybrid search, no model creation*/
96-
HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json");
96+
HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"),
97+
/** defaults file and substitution ready template for conversational search with cohere chat model*/
98+
CONVERSATIONAL_SEARCH_WITH_COHERE_DEPLOY(
99+
"conversational_search_with_llm_deploy",
100+
"defaults/conversational-search-defaults.json",
101+
"substitutionTemplates/conversational-search-with-cohere-model-template.json"
102+
);
97103

98104
private final String useCaseName;
99105
private final String defaultsFile;

src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
131131
try {
132132
XContentParser parser = request.contentParser();
133133
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
134-
Map<String, String> userDefaults = ParseUtils.parseStringToStringMap(parser);
134+
Map<String, Object> userDefaults = ParseUtils.parseStringToObjectMap(parser);
135135
// updates the default params with anything user has given that matches
136-
for (Map.Entry<String, String> userDefaultsEntry : userDefaults.entrySet()) {
136+
for (Map.Entry<String, Object> userDefaultsEntry : userDefaults.entrySet()) {
137137
String key = userDefaultsEntry.getKey();
138-
String value = userDefaultsEntry.getValue();
138+
String value = userDefaultsEntry.getValue().toString();
139139
if (useCaseDefaultsMap.containsKey(key)) {
140140
useCaseDefaultsMap.put(key, value);
141141
}
@@ -154,7 +154,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
154154
null,
155155
useCaseDefaultsMap
156156
);
157-
158157
XContentParser parserTestJson = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat);
159158
ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson);
160159
template = Template.parse(parserTestJson);

src/main/java/org/opensearch/flowframework/util/ParseUtils.java

+28
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.io.IOException;
3535
import java.io.InputStream;
3636
import java.time.Instant;
37+
import java.util.ArrayList;
3738
import java.util.HashMap;
3839
import java.util.HashSet;
3940
import java.util.List;
@@ -169,6 +170,7 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
169170
/**
170171
* Parses an XContent object representing a map of String keys to Object values.
171172
* The Object value here can either be a string or a map
173+
* If an array is found in the given parser we conver the array to a string representation of the array
172174
* @param parser An XContent parser whose position is at the start of the map object to parse
173175
* @return A map as identified by the key-value pairs in the XContent
174176
* @throws IOException on a parse failure
@@ -182,6 +184,15 @@ public static Map<String, Object> parseStringToObjectMap(XContentParser parser)
182184
if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
183185
// If the current token is a START_OBJECT, parse it as Map<String, String>
184186
map.put(fieldName, parseStringToStringMap(parser));
187+
} else if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
188+
// If an array, parse it to a string
189+
// Handle array: convert it to a string representation
190+
List<String> elements = new ArrayList<>();
191+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
192+
elements.add("\"" + parser.text() + "\""); // Adding escaped quotes around each element
193+
}
194+
String arrayString = "[" + String.join(", ", elements) + "]";
195+
map.put(fieldName, arrayString);
185196
} else {
186197
// Otherwise, parse it as a string
187198
map.put(fieldName, parser.text());
@@ -413,4 +424,21 @@ public static Map<String, String> parseJsonFileToStringToStringMap(String path)
413424
Map<String, String> mappedJsonFile = mapper.readValue(jsonContent, Map.class);
414425
return mappedJsonFile;
415426
}
427+
428+
/**
429+
* Takes an input string, then checks if there is an array in the string with backslashes around strings
430+
* (e.g. "[\"text\", \"hello\"]" to "["text", "hello"]"), this is needed for processors that take in string arrays,
431+
* This also removes the quotations around the array making the array valid to consume
432+
* (e.g. "weights": "[0.7, 0.3]" -> "weights": [0.7, 0.3])
433+
* @param input The inputString given to be transformed
434+
* @return the transformed string
435+
*/
436+
public static String removingBackslashesAndQuotesInArrayInJsonString(String input) {
437+
return Pattern.compile("\"\\[(.*?)]\"").matcher(input).replaceAll(matchResult -> {
438+
// Extract matched content and remove backslashes before quotes
439+
String withoutEscapes = matchResult.group(1).replaceAll("\\\\\"", "\"");
440+
// Return the transformed string with the brackets but without the outer quotes
441+
return "[" + withoutEscapes + "]";
442+
});
443+
}
416444
}

src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,11 @@ public PlainActionFuture<WorkflowData> execute(
8585
String pipelineId = (String) inputs.get(PIPELINE_ID);
8686
String configurations = (String) inputs.get(CONFIGURATIONS);
8787

88-
// Special case for processors that have arrays that need to have the quotes removed
89-
// (e.g. "weights": "[0.7, 0.3]" -> "weights": [0.7, 0.3]
90-
// Define a regular expression pattern to match stringified arrays
91-
String transformedJsonString = configurations.replaceAll("\"\\[(.*?)]\"", "[$1]");
88+
// Special case for processors that have arrays that need to have the quotes around or
89+
// backslashes around strings in array removed
90+
String transformedJsonStringForStringArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(configurations);
9291

93-
byte[] byteArr = transformedJsonString.getBytes(StandardCharsets.UTF_8);
92+
byte[] byteArr = transformedJsonStringForStringArray.getBytes(StandardCharsets.UTF_8);
9493
BytesReference configurationsBytes = new BytesArray(byteArr);
9594

9695
String pipelineToBeCreated = this.getName();

src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java

+34-2
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,27 @@
1414
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
1515
import org.opensearch.action.support.PlainActionFuture;
1616
import org.opensearch.client.Client;
17-
import org.opensearch.common.xcontent.XContentType;
17+
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
18+
import org.opensearch.common.xcontent.XContentHelper;
1819
import org.opensearch.core.action.ActionListener;
1920
import org.opensearch.core.common.bytes.BytesArray;
2021
import org.opensearch.core.common.bytes.BytesReference;
22+
import org.opensearch.core.rest.RestStatus;
23+
import org.opensearch.core.xcontent.MediaTypeRegistry;
2124
import org.opensearch.flowframework.exception.FlowFrameworkException;
25+
import org.opensearch.flowframework.exception.WorkflowStepException;
2226
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
2327
import org.opensearch.flowframework.util.ParseUtils;
28+
import org.opensearch.index.mapper.MapperService;
2429

2530
import java.io.IOException;
2631
import java.nio.charset.StandardCharsets;
2732
import java.util.Collections;
33+
import java.util.HashMap;
2834
import java.util.Map;
2935
import java.util.Set;
3036

37+
import static java.util.Collections.singletonMap;
3138
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
3239
import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME;
3340
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;
@@ -85,8 +92,13 @@ public PlainActionFuture<WorkflowData> execute(
8592

8693
byte[] byteArr = configurations.getBytes(StandardCharsets.UTF_8);
8794
BytesReference configurationsBytes = new BytesArray(byteArr);
95+
CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName);
96+
if (!configurations.isEmpty()) {
97+
Map<String, Object> sourceAsMap = XContentHelper.convertToMap(configurationsBytes, false, MediaTypeRegistry.JSON).v2();
98+
sourceAsMap = prepareMappings(sourceAsMap);
99+
createIndexRequest.source(sourceAsMap, LoggingDeprecationHandler.INSTANCE);
100+
}
88101

89-
CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).source(configurationsBytes, XContentType.JSON);
90102
client.admin().indices().create(createIndexRequest, ActionListener.wrap(acknowledgedResponse -> {
91103
String resourceName = getResourceByWorkflowStep(getName());
92104
logger.info("Created index: {}", indexName);
@@ -129,6 +141,26 @@ public PlainActionFuture<WorkflowData> execute(
129141
return createIndexFuture;
130142
}
131143

144+
// This method to check if the mapping contains a type `_doc` and if yes we fail the request
145+
// is to duplicate the behavior we have today through create index rest API, we want users
146+
// to encounter the same behavior and not suddenly have to add `_doc` while using our create_index step
147+
private static Map<String, Object> prepareMappings(Map<String, Object> source) {
148+
if (source.containsKey("mappings") == false || (source.get("mappings") instanceof Map) == false) {
149+
return source;
150+
}
151+
152+
Map<String, Object> newSource = new HashMap<>(source);
153+
154+
@SuppressWarnings("unchecked")
155+
Map<String, Object> mappings = (Map<String, Object>) source.get("mappings");
156+
if (MapperService.isMappingSourceTyped(MapperService.SINGLE_MAPPING_NAME, mappings)) {
157+
throw new WorkflowStepException("The mapping definition cannot be nested under a type", RestStatus.BAD_REQUEST);
158+
}
159+
160+
newSource.put("mappings", singletonMap(MapperService.SINGLE_MAPPING_NAME, mappings));
161+
return newSource;
162+
}
163+
132164
@Override
133165
public String getName() {
134166
return NAME;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"template.name": "deploy-cohere-chat-model",
3+
"template.description": "deploying cohere chat model",
4+
"create_connector.name": "Cohere Chat Model",
5+
"create_connector.description": "The connector to Cohere's public chat API",
6+
"create_connector.protocol": "http",
7+
"create_connector.model": "command",
8+
"create_connector.endpoint": "api.cohere.ai",
9+
"create_connector.credential.key": "123",
10+
"create_connector.actions.url": "https://api.cohere.ai/v1/chat",
11+
"create_connector.actions.request_body": "{ \"message\": \"${parameters.message}\", \"model\": \"${parameters.model}\" }",
12+
"register_remote_model.name": "Cohere chat model",
13+
"register_remote_model.description": "cohere-chat-model",
14+
"create_search_pipeline.pipeline_id": "rag-pipeline",
15+
"create_search_pipeline.retrieval_augmented_generation.tag": "openai_pipeline_demo",
16+
"create_search_pipeline.retrieval_augmented_generation.description": "Demo pipeline Using cohere Connector",
17+
"create_search_pipeline.retrieval_augmented_generation.context_field_list": "[\"text\", \"hello\"]",
18+
"create_search_pipeline.retrieval_augmented_generation.system_prompt": "You are a helpful assistant",
19+
"create_search_pipeline.retrieval_augmented_generation.user_instructions": "Generate a concise and informative answer in less than 100 words for the given question"
20+
}

src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"create_connector.name": "Amazon Bedrock Connector: multi-modal embedding",
55
"create_connector.description": "The connector to bedrock Titan multi-modal embedding model",
66
"create_connector.region": "us-east-1",
7-
"create_connector.input_docs_processed_step_size": 2,
7+
"create_connector.input_docs_processed_step_size": "2",
88
"create_connector.endpoint": "api.openai.com",
99
"create_connector.credential.access_key": "123",
1010
"create_connector.credential.secret_key": "123",
@@ -17,12 +17,12 @@
1717
"register_remote_model.description": "bedrock-multi-modal-embedding-model",
1818
"create_ingest_pipeline.pipeline_id": "nlp-multimodal-ingest-pipeline",
1919
"create_ingest_pipeline.description": "A text/image embedding pipeline",
20-
"create_ingest_pipeline.embedding": "vector_embedding",
20+
"text_image_embedding.create_ingest_pipeline.embedding": "vector_embedding",
2121
"text_image_embedding.field_map.text": "image_description",
2222
"text_image_embedding.field_map.image": "image_binary",
2323
"create_index.name": "my-multimodal-nlp-index",
24-
"create_index.settings.number_of_shards": 2,
25-
"text_image_embedding.field_map.output.dimension": 1024,
24+
"create_index.settings.number_of_shards": "2",
25+
"text_image_embedding.field_map.output.dimension": "1024",
2626
"create_index.mappings.method.engine": "lucene",
2727
"create_index.mappings.method.name": "hnsw"
2828
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "SEMANTIC_SEARCH",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "create_connector",
17+
"type": "create_connector",
18+
"user_inputs": {
19+
"name": "${{create_connector}}",
20+
"description": "${{create_connector.description}}",
21+
"version": "1",
22+
"protocol": "${{create_connector.protocol}}",
23+
"parameters": {
24+
"endpoint": "${{create_connector.endpoint}}",
25+
"model": "${{create_connector.model}}"
26+
},
27+
"credential": {
28+
"key": "${{create_connector.credential.key}}"
29+
},
30+
"actions": [
31+
{
32+
"action_type": "predict",
33+
"method": "POST",
34+
"url": "${{create_connector.actions.url}}",
35+
"headers": {
36+
"Authorization": "Bearer ${credential.key}"
37+
},
38+
"request_body": "${{create_connector.actions.request_body}}"
39+
}
40+
]
41+
}
42+
},
43+
{
44+
"id": "register_model",
45+
"type": "register_remote_model",
46+
"previous_node_inputs": {
47+
"create_connector": "parameters"
48+
},
49+
"user_inputs": {
50+
"name": "${{register_remote_model.name}}",
51+
"function_name": "remote",
52+
"description": "${{register_remote_model.description}}",
53+
"deploy": true
54+
}
55+
},
56+
{
57+
"id": "create_search_pipeline",
58+
"type": "create_search_pipeline",
59+
"previous_node_inputs": {
60+
"register_model": "model_id"
61+
},
62+
"user_inputs": {
63+
"pipeline_id": "${{create_search_pipeline.pipeline_id}}",
64+
"configurations": {
65+
"response_processors": [
66+
{
67+
"retrieval_augmented_generation": {
68+
"tag": "${{create_search_pipeline.retrieval_augmented_generation.tag}}",
69+
"description": "${{create_search_pipeline.retrieval_augmented_generation.description}}",
70+
"model_id": "${{register_model.model_id}}",
71+
"context_field_list": "${{create_search_pipeline.retrieval_augmented_generation.context_field_list}}",
72+
"system_prompt": "${{create_search_pipeline.retrieval_augmented_generation.system_prompt}}",
73+
"user_instructions": "${{create_search_pipeline.retrieval_augmented_generation.user_instructions}}"
74+
}
75+
}
76+
]
77+
}
78+
}
79+
}
80+
]
81+
}
82+
}
83+
}

src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"input_docs_processed_step_size": "${{create_connector.input_docs_processed_step_size}}"
2727
},
2828
"credential": {
29-
"access_ key": "${{create_connector.credential.access_key}}",
29+
"access_key": "${{create_connector.credential.access_key}}",
3030
"secret_key": "${{create_connector.credential.secret_key}}",
3131
"session_token": "${{create_connector.credential.session_token}}"
3232
},
@@ -73,7 +73,7 @@
7373
{
7474
"text_image_embedding": {
7575
"model_id": "${{register_model.model_id}}",
76-
"embedding": "${{create_ingest_pipeline.embedding}}",
76+
"embedding": "${{text_image_embedding.create_ingest_pipeline.embedding}}",
7777
"field_map": {
7878
"text": "${{text_image_embedding.field_map.text}}",
7979
"image": "${{text_image_embedding.field_map.image}}"
@@ -103,7 +103,7 @@
103103
"id": {
104104
"type": "text"
105105
},
106-
"${{text_embedding.field_map.output}}": {
106+
"${{text_image_embedding.create_ingest_pipeline.embedding}}": {
107107
"type": "knn_vector",
108108
"dimension": "${{text_image_embedding.field_map.output.dimension}}",
109109
"method": {

src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java

+10
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@ public void testConditionallySubstituteWithUnmatchedPlaceholders() {
110110
assertEquals("This string has unmatched ${{placeholder}}", result);
111111
}
112112

113+
public void testRemovingBackslashesAndQuotesInArrayInJsonString() {
114+
String inputNumArray = "normalization-processor.combination.parameters.weights: \"[0.3, 0.7]\"";
115+
String outputNumArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(inputNumArray);
116+
assertEquals("normalization-processor.combination.parameters.weights: [0.3, 0.7]", outputNumArray);
117+
String inputStringArray =
118+
"create_search_pipeline.retrieval_augmented_generation.context_field_list: \"[\\\"text\\\", \\\"hello\\\"]\"";
119+
String outputStringArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(inputStringArray);
120+
assertEquals("create_search_pipeline.retrieval_augmented_generation.context_field_list: [\"text\", \"hello\"]", outputStringArray);
121+
}
122+
113123
public void testConditionallySubstituteWithOutputsSubstitution() {
114124
String input = "This string contains ${{node.step}}";
115125
Map<String, WorkflowData> outputs = new HashMap<>();

0 commit comments

Comments
 (0)