34
34
35
35
import org .opensearch .OpenSearchException ;
36
36
import org .opensearch .Version ;
37
+ import org .opensearch .action .OriginalIndices ;
38
+ import org .opensearch .action .OriginalIndicesTests ;
39
+ import org .opensearch .action .search .SearchRequest ;
37
40
import org .opensearch .cluster .node .DiscoveryNode ;
41
+ import org .opensearch .common .SuppressForbidden ;
42
+ import org .opensearch .common .UUIDs ;
38
43
import org .opensearch .common .bytes .ReleasableBytesReference ;
39
44
import org .opensearch .common .collect .Tuple ;
40
45
import org .opensearch .common .io .stream .BytesStreamOutput ;
41
46
import org .opensearch .common .unit .TimeValue ;
42
47
import org .opensearch .common .util .BigArrays ;
48
+ import org .opensearch .common .util .FeatureFlags ;
43
49
import org .opensearch .common .util .PageCacheRecycler ;
44
50
import org .opensearch .common .util .concurrent .ThreadContext ;
45
51
import org .opensearch .common .util .io .Streams ;
46
52
import org .opensearch .core .action .ActionListener ;
53
+ import org .opensearch .core .common .Strings ;
47
54
import org .opensearch .core .common .breaker .CircuitBreaker ;
48
55
import org .opensearch .core .common .breaker .NoopCircuitBreaker ;
49
56
import org .opensearch .core .common .bytes .BytesArray ;
50
57
import org .opensearch .core .common .bytes .BytesReference ;
51
58
import org .opensearch .core .common .transport .TransportAddress ;
59
+ import org .opensearch .core .index .shard .ShardId ;
52
60
import org .opensearch .core .transport .TransportResponse ;
61
+ import org .opensearch .search .SearchShardTarget ;
62
+ import org .opensearch .search .fetch .FetchSearchResult ;
63
+ import org .opensearch .search .fetch .QueryFetchSearchResult ;
64
+ import org .opensearch .search .internal .AliasFilter ;
65
+ import org .opensearch .search .internal .ShardSearchContextId ;
66
+ import org .opensearch .search .internal .ShardSearchRequest ;
67
+ import org .opensearch .search .query .QuerySearchResult ;
53
68
import org .opensearch .test .OpenSearchTestCase ;
54
69
import org .opensearch .threadpool .TestThreadPool ;
55
70
import org .opensearch .threadpool .ThreadPool ;
56
71
import org .junit .After ;
57
72
import org .junit .Before ;
58
73
59
74
import java .io .IOException ;
75
+ import java .nio .ByteBuffer ;
60
76
import java .nio .charset .StandardCharsets ;
61
77
import java .util .Collections ;
62
78
import java .util .concurrent .TimeUnit ;
@@ -76,7 +92,7 @@ public class OutboundHandlerTests extends OpenSearchTestCase {
76
92
private final TestThreadPool threadPool = new TestThreadPool (getClass ().getName ());
77
93
private final TransportRequestOptions options = TransportRequestOptions .EMPTY ;
78
94
private final AtomicReference <Tuple <Header , BytesReference >> message = new AtomicReference <>();
79
- // private final AtomicReference<BytesReference> protobufMessage = new AtomicReference<>();
95
+ private final AtomicReference <BytesReference > protobufMessage = new AtomicReference <>();
80
96
private InboundPipeline pipeline ;
81
97
private OutboundHandler handler ;
82
98
private FakeTcpChannel channel ;
@@ -104,9 +120,7 @@ public void setUp() throws Exception {
104
120
} catch (IOException e ) {
105
121
throw new AssertionError (e );
106
122
}
107
- }
108
- // , (c, m) -> { protobufMessage.set(m); }
109
- );
123
+ });
110
124
}
111
125
112
126
@ After
@@ -266,6 +280,65 @@ public void onResponseSent(long requestId, String action, TransportResponse resp
266
280
assertEquals ("header_value" , header .getHeaders ().v1 ().get ("header" ));
267
281
}
268
282
283
+ @ SuppressForbidden (reason = "manipulates system properties for testing" )
284
+ public void testSendProtobufResponse () throws IOException {
285
+ ThreadContext threadContext = threadPool .getThreadContext ();
286
+ Version version = Version .CURRENT ;
287
+ String action = "handshake" ;
288
+ long requestId = randomLongBetween (0 , 300 );
289
+ boolean isHandshake = randomBoolean ();
290
+ boolean compress = randomBoolean ();
291
+ threadContext .putHeader ("header" , "header_value" );
292
+ QuerySearchResult queryResult = createQuerySearchResult ();
293
+ FetchSearchResult fetchResult = createFetchSearchResult ();
294
+ QueryFetchSearchResult response = new QueryFetchSearchResult (queryResult , fetchResult );
295
+ System .setProperty (FeatureFlags .PROTOBUF , "true" );
296
+ assertTrue (response .isMessageProtobuf ());
297
+
298
+ AtomicLong requestIdRef = new AtomicLong ();
299
+ AtomicReference <String > actionRef = new AtomicReference <>();
300
+ AtomicReference <TransportResponse > responseRef = new AtomicReference <>();
301
+ handler .setMessageListener (new TransportMessageListener () {
302
+ @ Override
303
+ public void onResponseSent (long requestId , String action , TransportResponse response ) {
304
+ requestIdRef .set (requestId );
305
+ actionRef .set (action );
306
+ responseRef .set (response );
307
+ }
308
+ });
309
+ handler .sendResponse (version , Collections .emptySet (), channel , requestId , action , response , compress , isHandshake );
310
+
311
+ StatsTracker statsTracker = new StatsTracker ();
312
+ final LongSupplier millisSupplier = () -> TimeValue .nsecToMSec (System .nanoTime ());
313
+ final InboundDecoder decoder = new InboundDecoder (Version .CURRENT , PageCacheRecycler .NON_RECYCLING_INSTANCE );
314
+ final Supplier <CircuitBreaker > breaker = () -> new NoopCircuitBreaker ("test" );
315
+ final InboundAggregator aggregator = new InboundAggregator (breaker , (Predicate <String >) requestCanTripBreaker -> true );
316
+ InboundPipeline inboundPipeline = new InboundPipeline (statsTracker , millisSupplier , decoder , aggregator , (c , m ) -> {
317
+ NodeToNodeMessage m1 = (NodeToNodeMessage ) m ;
318
+ protobufMessage .set (BytesReference .fromByteBuffer (ByteBuffer .wrap (m1 .getMessage ().toByteArray ())));
319
+ });
320
+ BytesReference reference = channel .getMessageCaptor ().get ();
321
+ ActionListener <Void > sendListener = channel .getListenerCaptor ().get ();
322
+ if (randomBoolean ()) {
323
+ sendListener .onResponse (null );
324
+ } else {
325
+ sendListener .onFailure (new IOException ("failed" ));
326
+ }
327
+ assertEquals (requestId , requestIdRef .get ());
328
+ assertEquals (action , actionRef .get ());
329
+ assertEquals (response , responseRef .get ());
330
+
331
+ inboundPipeline .handleBytes (channel , new ReleasableBytesReference (reference , () -> {}));
332
+ final BytesReference responseBytes = protobufMessage .get ();
333
+ final NodeToNodeMessage message = new NodeToNodeMessage (responseBytes .toBytesRef ().bytes );
334
+ assertEquals (version .toString (), message .getMessage ().getVersion ());
335
+ assertEquals (requestId , message .getHeader ().getRequestId ());
336
+ assertNotNull (message .getRequestHeaders ());
337
+ assertNotNull (message .getResponseHandlers ());
338
+ assertNotNull (message .getMessage ());
339
+ assertTrue (message .getMessage ().hasQueryFetchSearchResult ());
340
+ }
341
+
269
342
public void testErrorResponse () throws IOException {
270
343
ThreadContext threadContext = threadPool .getThreadContext ();
271
344
Version version = randomFrom (Version .CURRENT , Version .CURRENT .minimumCompatibilityVersion ());
@@ -317,4 +390,35 @@ public void onResponseSent(long requestId, String action, Exception error) {
317
390
318
391
assertEquals ("header_value" , header .getHeaders ().v1 ().get ("header" ));
319
392
}
393
+
394
+ public static QuerySearchResult createQuerySearchResult () {
395
+ ShardId shardId = new ShardId ("index" , "uuid" , randomInt ());
396
+ SearchRequest searchRequest = new SearchRequest ().allowPartialSearchResults (randomBoolean ());
397
+ ShardSearchRequest shardSearchRequest = new ShardSearchRequest (
398
+ OriginalIndicesTests .randomOriginalIndices (),
399
+ searchRequest ,
400
+ shardId ,
401
+ 1 ,
402
+ new AliasFilter (null , Strings .EMPTY_ARRAY ),
403
+ 1.0f ,
404
+ randomNonNegativeLong (),
405
+ null ,
406
+ new String [0 ]
407
+ );
408
+ QuerySearchResult result = new QuerySearchResult (
409
+ new ShardSearchContextId (UUIDs .base64UUID (), randomLong ()),
410
+ new SearchShardTarget ("node" , shardId , null , OriginalIndices .NONE ),
411
+ shardSearchRequest
412
+ );
413
+ return result ;
414
+ }
415
+
416
+ public static FetchSearchResult createFetchSearchResult () {
417
+ ShardId shardId = new ShardId ("index" , "uuid" , randomInt ());
418
+ FetchSearchResult result = new FetchSearchResult (
419
+ new ShardSearchContextId (UUIDs .base64UUID (), randomLong ()),
420
+ new SearchShardTarget ("node" , shardId , null , OriginalIndices .NONE )
421
+ );
422
+ return result ;
423
+ }
320
424
}
0 commit comments