36
36
import org .apache .lucene .search .SortField ;
37
37
import org .apache .lucene .search .TotalHits ;
38
38
import org .apache .lucene .search .TotalHits .Relation ;
39
+ import org .apache .lucene .util .SuppressForbidden ;
40
+ import org .opensearch .OpenSearchException ;
39
41
import org .opensearch .common .Nullable ;
40
42
import org .opensearch .common .annotation .PublicApi ;
41
43
import org .opensearch .common .lucene .Lucene ;
50
52
import org .opensearch .server .proto .FetchSearchResultProto ;
51
53
import org .opensearch .server .proto .QuerySearchResultProto ;
52
54
55
+ import java .io .ByteArrayInputStream ;
56
+ import java .io .ByteArrayOutputStream ;
53
57
import java .io .IOException ;
58
+ import java .io .InputStream ;
59
+ import java .io .ObjectInputStream ;
60
+ import java .io .ObjectOutputStream ;
54
61
import java .io .OutputStream ;
55
62
import java .util .ArrayList ;
56
63
import java .util .Arrays ;
@@ -93,6 +100,7 @@ public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScor
93
100
this (hits , totalHits , maxScore , null , null , null );
94
101
}
95
102
103
+ @ SuppressForbidden (reason = "serialization of object to protobuf" )
96
104
public SearchHits (
97
105
SearchHit [] hits ,
98
106
@ Nullable TotalHits totalHits ,
@@ -109,18 +117,18 @@ public SearchHits(
109
117
this .collapseValues = collapseValues ;
110
118
if (FeatureFlags .isEnabled (FeatureFlags .PROTOBUF_SETTING )) {
111
119
List <FetchSearchResultProto .SearchHit > searchHitList = new ArrayList <>();
112
- for (int i = 0 ; i < hits . length ; i ++ ) {
120
+ for (SearchHit hit : hits ) {
113
121
FetchSearchResultProto .SearchHit .Builder searchHitBuilder = FetchSearchResultProto .SearchHit .newBuilder ();
114
- if (hits [ i ] .getIndex () != null ) {
115
- searchHitBuilder .setIndex (hits [ i ] .getIndex ());
122
+ if (hit .getIndex () != null ) {
123
+ searchHitBuilder .setIndex (hit .getIndex ());
116
124
}
117
- searchHitBuilder .setId (hits [ i ] .getId ());
118
- searchHitBuilder .setScore (hits [ i ] .getScore ());
119
- searchHitBuilder .setSeqNo (hits [ i ] .getSeqNo ());
120
- searchHitBuilder .setPrimaryTerm (hits [ i ] .getPrimaryTerm ());
121
- searchHitBuilder .setVersion (hits [ i ] .getVersion ());
122
- if (hits [ i ] .getSourceRef () != null ) {
123
- searchHitBuilder .setSource (ByteString .copyFrom (hits [ i ] .getSourceRef ().toBytesRef ().bytes ));
125
+ searchHitBuilder .setId (hit .getId ());
126
+ searchHitBuilder .setScore (hit .getScore ());
127
+ searchHitBuilder .setSeqNo (hit .getSeqNo ());
128
+ searchHitBuilder .setPrimaryTerm (hit .getPrimaryTerm ());
129
+ searchHitBuilder .setVersion (hit .getVersion ());
130
+ if (hit .getSourceRef () != null ) {
131
+ searchHitBuilder .setSource (ByteString .copyFrom (hit .getSourceRef ().toBytesRef ().bytes ));
124
132
}
125
133
searchHitList .add (searchHitBuilder .build ());
126
134
}
@@ -135,8 +143,25 @@ public SearchHits(
135
143
searchHitsBuilder .setMaxScore (maxScore );
136
144
searchHitsBuilder .addAllHits (searchHitList );
137
145
searchHitsBuilder .setTotalHits (totalHitsBuilder .build ());
146
+ if (sortFields != null && sortFields .length > 0 ) {
147
+ for (SortField sortField : sortFields ) {
148
+ FetchSearchResultProto .SortField .Builder sortFieldBuilder = FetchSearchResultProto .SortField .newBuilder ();
149
+ sortFieldBuilder .setField (sortField .getField ());
150
+ sortFieldBuilder .setType (FetchSearchResultProto .SortField .Type .valueOf (sortField .getType ().name ()));
151
+ searchHitsBuilder .addSortFields (sortFieldBuilder .build ());
152
+ }
153
+ }
138
154
if (collapseField != null ) {
139
155
searchHitsBuilder .setCollapseField (collapseField );
156
+ for (Object value : collapseValues ) {
157
+ ByteArrayOutputStream bos = new ByteArrayOutputStream ();
158
+ try (ObjectOutputStream stream = new ObjectOutputStream (bos )) {
159
+ stream .writeObject (value );
160
+ searchHitsBuilder .addCollapseValues (ByteString .copyFrom (bos .toByteArray ()));
161
+ } catch (IOException e ) {
162
+ throw new OpenSearchException (e );
163
+ }
164
+ }
140
165
}
141
166
this .searchHitsProto = searchHitsBuilder .build ();
142
167
}
@@ -164,6 +189,7 @@ public SearchHits(StreamInput in) throws IOException {
164
189
collapseValues = in .readOptionalArray (Lucene ::readSortValue , Object []::new );
165
190
}
166
191
192
+ @ SuppressForbidden (reason = "serialization of object to protobuf" )
167
193
public SearchHits (byte [] in ) throws IOException {
168
194
this .searchHitsProto = org .opensearch .server .proto .FetchSearchResultProto .SearchHits .parseFrom (in );
169
195
this .hits = new SearchHit [this .searchHitsProto .getHitsCount ()];
@@ -175,10 +201,21 @@ public SearchHits(byte[] in) throws IOException {
175
201
Relation .valueOf (this .searchHitsProto .getTotalHits ().getRelation ().toString ())
176
202
);
177
203
this .maxScore = this .searchHitsProto .getMaxScore ();
204
+ this .sortFields = this .searchHitsProto .getSortFieldsList ()
205
+ .stream ()
206
+ .map (sortField -> new SortField (sortField .getField (), SortField .Type .valueOf (sortField .getType ().toString ())))
207
+ .toArray (SortField []::new );
178
208
this .collapseField = this .searchHitsProto .getCollapseField ();
179
- // Below fields are set to null currently, support to be added in the future
180
- this .collapseValues = null ;
181
- this .sortFields = null ;
209
+ this .collapseValues = new Object [this .searchHitsProto .getCollapseValuesCount ()];
210
+ for (int i = 0 ; i < this .searchHitsProto .getCollapseValuesCount (); i ++) {
211
+ ByteString collapseValue = this .searchHitsProto .getCollapseValues (i );
212
+ InputStream is = new ByteArrayInputStream (collapseValue .toByteArray ());
213
+ try (ObjectInputStream ois = new ObjectInputStream (is )) {
214
+ this .collapseValues [i ] = ois .readObject ();
215
+ } catch (ClassNotFoundException e ) {
216
+ throw new OpenSearchException (e );
217
+ }
218
+ }
182
219
}
183
220
184
221
@ Override
0 commit comments