Skip to content

Commit 6d6ae29

Browse files
zhanghg08jackiehanyang
authored andcommitted
Add MLInputDatasetHandler to handle search query input (opensearch-project#36)
* Add MLInputDatasetHandler to handle search query input * Add type check for MLInputDatasetHandler. Add related test cases
1 parent 8e03fc4 commit 6d6ae29

File tree

2 files changed

+267
-0
lines changed

2 files changed

+267
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.indices;
14+
15+
import lombok.AccessLevel;
16+
import lombok.RequiredArgsConstructor;
17+
import lombok.experimental.FieldDefaults;
18+
import lombok.extern.log4j.Log4j2;
19+
import org.opensearch.action.ActionListener;
20+
import org.opensearch.action.search.SearchRequest;
21+
import org.opensearch.client.Client;
22+
import org.opensearch.ml.common.dataframe.DataFrame;
23+
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
24+
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
25+
import org.opensearch.ml.common.dataset.MLInputDataType;
26+
import org.opensearch.ml.common.dataset.MLInputDataset;
27+
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
28+
import org.opensearch.search.SearchHit;
29+
import org.opensearch.search.SearchHits;
30+
31+
import java.util.ArrayList;
32+
import java.util.List;
33+
import java.util.Map;
34+
35+
/**
36+
* Convert MLInputDataset to Dataframe
37+
*/
38+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
39+
@RequiredArgsConstructor
40+
@Log4j2
41+
public class MLInputDatasetHandler {
42+
Client client;
43+
44+
/**
45+
* Retrieve DataFrame from DataFrameInputDataset
46+
* @param mlInputDataset MLInputDataset
47+
* @return DataFrame
48+
*/
49+
public DataFrame parseDataFrameInput(MLInputDataset mlInputDataset) {
50+
if (!mlInputDataset.getInputDataType().equals(MLInputDataType.DATA_FRAME)) {
51+
throw new IllegalArgumentException("Input dataset is not DATA_FRAME type.");
52+
}
53+
DataFrameInputDataset inputDataset = (DataFrameInputDataset) mlInputDataset;
54+
return inputDataset.getDataFrame();
55+
}
56+
57+
/**
58+
* Create DataFrame based on given search query
59+
* @param mlInputDataset MLInputDataset
60+
* @param listener ActionListener
61+
*/
62+
public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<DataFrame> listener) {
63+
if (!mlInputDataset.getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
64+
throw new IllegalArgumentException("Input dataset is not SEARCH_QUERY type.");
65+
}
66+
SearchQueryInputDataset inputDataset = (SearchQueryInputDataset) mlInputDataset;
67+
SearchRequest searchRequest = new SearchRequest();
68+
searchRequest.source(inputDataset.getSearchSourceBuilder());
69+
List<String> indicesList = inputDataset.getIndices();
70+
String[] indices = new String[indicesList.size()];
71+
indices = indicesList.toArray(indices);
72+
searchRequest.indices(indices);
73+
74+
client.search(searchRequest, ActionListener.wrap(r -> {
75+
if (
76+
r == null ||
77+
r.getHits() == null ||
78+
r.getHits().getTotalHits() == null ||
79+
r.getHits().getTotalHits().value == 0
80+
) {
81+
// todo: add specific exception
82+
listener.onFailure(new RuntimeException("No document found"));
83+
return;
84+
}
85+
SearchHits hits = r.getHits();
86+
List<Map<String, Object>> input = new ArrayList<>();
87+
SearchHit[] searchHits = hits.getHits();
88+
for (SearchHit hit : searchHits) {
89+
input.add(hit.getSourceAsMap());
90+
}
91+
DataFrame dataFrame = DataFrameBuilder.load(input);
92+
listener.onResponse(dataFrame);
93+
return;
94+
}, e -> {
95+
log.error("Failed to search" + e);
96+
listener.onFailure(e);
97+
}));
98+
return;
99+
}
100+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*
11+
*/
12+
13+
package org.opensearch.ml.indices;
14+
15+
import org.apache.lucene.search.TotalHits;
16+
import org.junit.Assert;
17+
import org.junit.Before;
18+
import org.junit.Rule;
19+
import org.junit.Test;
20+
import org.junit.rules.ExpectedException;
21+
import org.opensearch.action.ActionListener;
22+
import org.opensearch.action.search.SearchResponse;
23+
import org.opensearch.client.Client;
24+
import org.opensearch.common.bytes.BytesArray;
25+
import org.opensearch.common.bytes.BytesReference;
26+
import org.opensearch.index.query.QueryBuilders;
27+
import org.opensearch.ml.common.dataframe.DataFrame;
28+
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
29+
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
30+
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
31+
import org.opensearch.search.SearchHit;
32+
import org.opensearch.search.SearchHits;
33+
import org.opensearch.search.builder.SearchSourceBuilder;
34+
35+
import java.util.ArrayList;
36+
import java.util.Arrays;
37+
import java.util.Collections;
38+
import java.util.HashMap;
39+
import java.util.List;
40+
import java.util.Map;
41+
42+
import static org.mockito.Matchers.any;
43+
import static org.mockito.Mockito.doAnswer;
44+
import static org.mockito.Mockito.mock;
45+
import static org.mockito.Mockito.spy;
46+
import static org.mockito.Mockito.times;
47+
import static org.mockito.Mockito.verify;
48+
import static org.mockito.Mockito.when;
49+
import org.mockito.ArgumentCaptor;
50+
51+
52+
public class MLInputDatasetHandlerTests{
53+
Client client;
54+
MLInputDatasetHandler mlInputDatasetHandler;
55+
ActionListener<DataFrame> listener;
56+
DataFrame dataFrame;
57+
SearchResponse searchResponse;
58+
59+
@Rule
60+
public ExpectedException expectedEx = ExpectedException.none();
61+
62+
@Before
63+
public void setup() {
64+
Map<String, Object> source = new HashMap<>();
65+
source.put("taskId", "111");
66+
List<Map<String, Object>> mapList = new ArrayList<>();
67+
mapList.add(source);
68+
dataFrame = DataFrameBuilder.load(mapList);
69+
client = mock(Client.class);
70+
mlInputDatasetHandler = new MLInputDatasetHandler(client);
71+
listener = spy(new ActionListener<DataFrame>() {
72+
@Override
73+
public void onResponse(DataFrame dataFrame) {}
74+
75+
@Override
76+
public void onFailure(Exception e) {}
77+
});
78+
79+
}
80+
81+
@Test
82+
public void testDataFrameInputDataset() {
83+
DataFrame testDataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {
84+
{
85+
put("key1", 2.0D);
86+
}
87+
}));
88+
DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder()
89+
.dataFrame(testDataFrame)
90+
.build();
91+
DataFrame result = mlInputDatasetHandler.parseDataFrameInput(dataFrameInputDataset);
92+
Assert.assertEquals(testDataFrame, result);
93+
}
94+
95+
@Test
96+
public void testDataFrameInputDatasetWrongType() {
97+
expectedEx.expect(IllegalArgumentException.class);
98+
expectedEx.expectMessage("Input dataset is not DATA_FRAME type.");
99+
SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder()
100+
.indices(Arrays.asList("index1"))
101+
.searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()))
102+
.build();
103+
DataFrame result = mlInputDatasetHandler.parseDataFrameInput(searchQueryInputDataset);
104+
}
105+
106+
107+
@Test
108+
@SuppressWarnings("unchecked")
109+
public void testSearchQueryInputDatasetWithHits() {
110+
searchResponse = mock(SearchResponse.class);
111+
BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}");
112+
SearchHit hit = new SearchHit( 1 );
113+
hit.sourceRef(bytesArray);
114+
SearchHits hits = new SearchHits(new SearchHit[] {hit}, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f);
115+
when(searchResponse.getHits()).thenReturn(hits);
116+
doAnswer(invocation -> {
117+
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocation.getArguments() [1];
118+
listener.onResponse(searchResponse);
119+
return null;
120+
}).when(client).search(any(), any());
121+
122+
SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder()
123+
.indices(Arrays.asList("index1"))
124+
.searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()))
125+
.build();
126+
mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener);
127+
ArgumentCaptor<DataFrame> captor = ArgumentCaptor.forClass(DataFrame.class);
128+
verify(listener, times(1)).onResponse(captor.capture());
129+
Assert.assertEquals(captor.getAllValues().size(), 1);
130+
}
131+
132+
@Test
133+
@SuppressWarnings("unchecked")
134+
public void testSearchQueryInputDatasetWithoutHits() {
135+
searchResponse = mock(SearchResponse.class);
136+
SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f);
137+
when(searchResponse.getHits()).thenReturn(hits);
138+
doAnswer(invocation -> {
139+
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocation.getArguments() [1];
140+
listener.onResponse(searchResponse);
141+
return null;
142+
}).when(client).search(any(), any());
143+
144+
SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder()
145+
.indices(Arrays.asList("index1"))
146+
.searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()))
147+
.build();
148+
mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener);
149+
verify(listener, times(1)).onFailure(any());
150+
}
151+
152+
@Test
153+
public void testSearchQueryInputDatasetWrongType() {
154+
expectedEx.expect(IllegalArgumentException.class);
155+
expectedEx.expectMessage("Input dataset is not SEARCH_QUERY type.");
156+
DataFrame testDataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {
157+
{
158+
put("key1", 2.0D);
159+
}
160+
}));
161+
DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder()
162+
.dataFrame(testDataFrame)
163+
.build();
164+
mlInputDatasetHandler.parseSearchQueryInput(dataFrameInputDataset, listener);
165+
}
166+
167+
}

0 commit comments

Comments
 (0)