Skip to content

Commit f18868f

Browse files
committed
Address comments and add one more UT to cover uncovered line
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 0fcad86 commit f18868f

File tree

5 files changed

+68
-24
lines changed

5 files changed

+68
-24
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2121
- Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670))
2222
### Bug Fixes
2323
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
24-
- Fix may type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
24+
- Fix map type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
2525
### Infrastructure
2626
### Documentation
2727
### Maintenance

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
177177
fieldMap,
178178
1,
179179
ProcessorDocumentUtils.getMaxDepth(sourceAndMetadataMap, clusterService, environment),
180-
true
180+
false
181181
);
182182
}
183183

src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java

+38-20
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,29 @@
88
import org.opensearch.cluster.metadata.IndexMetadata;
99
import org.opensearch.cluster.service.ClusterService;
1010
import org.opensearch.common.settings.Settings;
11+
import org.opensearch.core.common.util.CollectionUtils;
1112
import org.opensearch.env.Environment;
1213
import org.opensearch.index.mapper.IndexFieldMapper;
1314
import org.opensearch.index.mapper.MapperService;
1415

1516
import java.util.List;
1617
import java.util.Locale;
1718
import java.util.Map;
18-
import java.util.Objects;
1919

20+
/**
21+
* Utility class that holds the common code for parsing, validating and processing documents, shared by multiple
22+
* pipeline processors.
23+
*/
2024
public class ProcessorDocumentUtils {
2125

26+
/**
27+
* Gets the max depth limit, taken either from the index settings or from the system settings.
28+
*
29+
* @param sourceAndMetadataMap _source and metadata info in document.
30+
* @param clusterService cluster service passed from OpenSearch core.
31+
* @param environment environment passed from OpenSearch core.
32+
* @return the max depth limit taken from the index settings or from the system settings.
33+
*/
2234
public static long getMaxDepth(Map<String, Object> sourceAndMetadataMap, ClusterService clusterService, Environment environment) {
2335
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
2436
IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
@@ -29,12 +41,23 @@ public static long getMaxDepth(Map<String, Object> sourceAndMetadataMap, Cluster
2941
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings());
3042
}
3143

44+
/**
45+
* Validates a map type value recursively up to a specified depth. Supports Map type, List type and String type.
46+
* If the current sourceValue is of Map or List type, its values are validated recursively; otherwise the value itself is validated.
47+
*
48+
* @param sourceKey the key of the source map being validated; at the first level this is always the "field_map" key.
49+
* @param sourceValue the source map being validated; at the first level this is always the sourceAndMetadataMap.
50+
* @param fieldMap the configuration map for validation; at the first level this is always the value of "field_map" in the processor configuration.
51+
* @param depth the current depth of recursion
52+
* @param maxDepth the maximum allowed depth for recursion
53+
* @param allowEmpty flag to allow empty values in map type validation.
54+
*/
3255
@SuppressWarnings({ "rawtypes", "unchecked" })
3356
public static void validateMapTypeValue(
3457
final String sourceKey,
3558
final Map<String, Object> sourceValue,
3659
final Object fieldMap,
37-
final int depth,
60+
final long depth,
3861
final long maxDepth,
3962
final boolean allowEmpty
4063
) {
@@ -73,18 +96,19 @@ private static void validateListTypeValue(
7396
String sourceKey,
7497
List sourceValue,
7598
Object fieldMap,
76-
int depth,
99+
long depth,
77100
long maxDepth,
78101
boolean allowEmpty
79102
) {
80103
validateDepth(sourceKey, depth, maxDepth);
81-
if (sourceValue == null || sourceValue.isEmpty()) return;
82-
Object firstNonNullElement = sourceValue.stream().filter(Objects::nonNull).findFirst().orElse(null);
83-
if (firstNonNullElement == null) return;
104+
if (CollectionUtils.isEmpty(sourceValue)) return;
84105
for (Object element : sourceValue) {
85-
if (firstNonNullElement instanceof List) { // nested list case.
86-
validateListTypeValue(sourceKey, (List) element, fieldMap, depth + 1, maxDepth, allowEmpty);
87-
} else if (firstNonNullElement instanceof Map) {
106+
if (element == null) {
107+
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
108+
}
109+
if (element instanceof List) { // nested list case.
110+
throw new IllegalArgumentException("list type field [" + sourceKey + "] is nested list type, cannot process it");
111+
} else if (element instanceof Map) {
88112
validateMapTypeValue(
89113
sourceKey,
90114
(Map<String, Object>) element,
@@ -93,23 +117,17 @@ private static void validateListTypeValue(
93117
maxDepth,
94118
allowEmpty
95119
);
96-
} else if (!(firstNonNullElement instanceof String)) {
120+
} else if (!(element instanceof String)) {
97121
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
98-
} else {
99-
if (element == null) {
100-
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
101-
} else if (!(element instanceof String)) {
102-
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
103-
} else if (!allowEmpty && StringUtils.isBlank(element.toString())) {
104-
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
105-
}
122+
} else if (!allowEmpty && StringUtils.isBlank(element.toString())) {
123+
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
106124
}
107125
}
108126
}
109127

110-
private static void validateDepth(String sourceKey, int depth, long maxDepth) {
128+
private static void validateDepth(String sourceKey, long depth, long maxDepth) {
111129
if (depth > maxDepth) {
112-
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
130+
throw new IllegalArgumentException("map type field [" + sourceKey + "] reaches max depth limit, cannot process it");
113131
}
114132
}
115133
}

src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ public void testExecute_withFixedTokenLength_andMaxDepthLimitExceedFieldMap_then
607607
IllegalArgumentException.class,
608608
() -> processor.execute(ingestDocument)
609609
);
610-
assertEquals("map type field [body] reached max depth limit, cannot process it", illegalArgumentException.getMessage());
610+
assertEquals("map type field [body] reaches max depth limit, cannot process it", illegalArgumentException.getMessage());
611611
}
612612

613613
@SneakyThrows
@@ -657,7 +657,10 @@ public void testExecute_withFixedTokenLength_andSourceDataListWithHybridType_the
657657
IllegalArgumentException.class,
658658
() -> processor.execute(ingestDocument)
659659
);
660-
assertEquals("list type field [body] has non string value, cannot process it", illegalArgumentException.getMessage());
660+
assertEquals(
661+
"[body] configuration doesn't match actual value type, configuration type is: java.lang.String, actual value type is: com.google.common.collect.RegularImmutableMap",
662+
illegalArgumentException.getMessage()
663+
);
661664
}
662665

663666
@SneakyThrows

src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java

+23
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.function.Supplier;
2727

2828
import org.junit.Before;
29+
import org.mockito.ArgumentCaptor;
2930
import org.mockito.InjectMocks;
3031
import org.mockito.Mock;
3132
import org.mockito.MockitoAnnotations;
@@ -488,6 +489,28 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() {
488489
assertEquals(2, ((List<?>) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size());
489490
}
490491

492+
public void test_doublyNestedList_withMapType_successful() {
493+
Map<String, Object> config = createNestedListConfiguration();
494+
495+
Map<String, Object> toEmbeddings = new HashMap<>();
496+
toEmbeddings.put("textField", "text to embedding");
497+
List<Map<String, Object>> l1List = new ArrayList<>();
498+
l1List.add(toEmbeddings);
499+
List<List<Map<String, Object>>> l2List = new ArrayList<>();
500+
l2List.add(l1List);
501+
Map<String, Object> document = new HashMap<>();
502+
document.put("nestedField", l2List);
503+
document.put(IndexFieldMapper.NAME, "my_index");
504+
505+
IngestDocument ingestDocument = new IngestDocument(document, new HashMap<>());
506+
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
507+
BiConsumer handler = mock(BiConsumer.class);
508+
processor.execute(ingestDocument, handler);
509+
ArgumentCaptor<IllegalArgumentException> argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class);
510+
verify(handler).accept(isNull(), argumentCaptor.capture());
511+
assertEquals("list type field [nestedField] is nested list type, cannot process it", argumentCaptor.getValue().getMessage());
512+
}
513+
491514
private List<List<Float>> createMockVectorResult() {
492515
List<List<Float>> modelTensorList = new ArrayList<>();
493516
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);

0 commit comments

Comments
 (0)