Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit abc32a8

Browse files
committedFeb 2, 2024
Added tests
Signed-off-by: Vacha Shah <vachshah@amazon.com>
1 parent 288a943 commit abc32a8

File tree

6 files changed

+230
-6
lines changed

6 files changed

+230
-6
lines changed
 

‎server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java

+1
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,5 @@ public QueryFetchSearchResult(QueryFetchSearchResultProto.QueryFetchSearchResult
136136
this.queryResult = new QuerySearchResult(queryFetchSearchResult.getQueryResult());
137137
this.fetchResult = new FetchSearchResult(queryFetchSearchResult.getFetchResult());
138138
}
139+
139140
}

‎server/src/main/java/org/opensearch/search/query/QuerySearchResult.java

+6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.opensearch.server.proto.QuerySearchResultProto;
5858

5959
import java.io.IOException;
60+
import java.io.OutputStream;
6061
import java.util.ArrayList;
6162
import java.util.List;
6263

@@ -487,6 +488,11 @@ public void writeTo(StreamOutput out) throws IOException {
487488
}
488489
}
489490

491+
@Override
492+
public void writeTo(OutputStream out) throws IOException {
493+
out.write(this.querySearchResultProto.toByteArray());
494+
}
495+
490496
public void writeToNoId(StreamOutput out) throws IOException {
491497
out.writeVInt(from);
492498
out.writeVInt(size);

‎server/src/main/java/org/opensearch/transport/InboundHandler.java

-2
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,7 @@ private void messageReceivedProtobuf(TcpChannel channel, NodeToNodeMessage messa
234234
long requestId = header.getRequestId();
235235
TransportResponseHandler<? extends TransportResponse> handler = responseHandlers.onResponseReceived(requestId, messageListener);
236236
if (handler != null) {
237-
// if (handler.toString().contains("Protobuf")) {
238237
handleProtobufResponse(requestId, remoteAddress, message, handler);
239-
// }
240238
}
241239
} finally {
242240
final long took = threadPool.relativeTimeInMillis() - startTime;

‎server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java

+24
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@
5454
import org.opensearch.search.internal.ShardSearchContextId;
5555
import org.opensearch.search.internal.ShardSearchRequest;
5656
import org.opensearch.search.suggest.SuggestTests;
57+
import org.opensearch.server.proto.QuerySearchResultProto;
5758
import org.opensearch.test.OpenSearchTestCase;
5859

60+
import java.io.ByteArrayOutputStream;
61+
5962
import static java.util.Collections.emptyList;
6063

6164
public class QuerySearchResultTests extends OpenSearchTestCase {
@@ -125,4 +128,25 @@ public void testNullResponse() throws Exception {
125128
QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new, Version.CURRENT);
126129
assertEquals(querySearchResult.isNull(), deserialized.isNull());
127130
}
131+
132+
public void testProtobufSerialization() throws Exception {
133+
QuerySearchResult querySearchResult = createTestInstance();
134+
ByteArrayOutputStream stream = new ByteArrayOutputStream();
135+
querySearchResult.writeTo(stream);
136+
QuerySearchResult deserialized = new QuerySearchResult(stream.toByteArray());
137+
QuerySearchResultProto.QuerySearchResult querySearchResultProto = deserialized.response();
138+
assertNotNull(querySearchResultProto);
139+
assertEquals(querySearchResult.getContextId().getId(), querySearchResultProto.getContextId().getId());
140+
assertEquals(
141+
querySearchResult.getSearchShardTarget().getShardId().getIndex().getUUID(),
142+
querySearchResultProto.getSearchShardTarget().getShardId().getIndexUUID()
143+
);
144+
assertEquals(querySearchResult.topDocs().maxScore, querySearchResultProto.getTopDocsAndMaxScore().getMaxScore(), 0f);
145+
assertEquals(
146+
querySearchResult.topDocs().topDocs.totalHits.value,
147+
querySearchResultProto.getTopDocsAndMaxScore().getTopDocs().getTotalHits().getValue()
148+
);
149+
assertEquals(querySearchResult.from(), querySearchResultProto.getFrom());
150+
assertEquals(querySearchResult.size(), querySearchResultProto.getSize());
151+
}
128152
}

‎server/src/test/java/org/opensearch/transport/InboundHandlerTests.java

+91
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,23 @@
3737
import org.apache.lucene.util.BytesRef;
3838
import org.opensearch.OpenSearchException;
3939
import org.opensearch.Version;
40+
import org.opensearch.common.SuppressForbidden;
4041
import org.opensearch.common.bytes.ReleasableBytesReference;
4142
import org.opensearch.common.collect.Tuple;
4243
import org.opensearch.common.io.stream.BytesStreamOutput;
4344
import org.opensearch.common.settings.Settings;
4445
import org.opensearch.common.unit.TimeValue;
4546
import org.opensearch.common.util.BigArrays;
47+
import org.opensearch.common.util.FeatureFlags;
4648
import org.opensearch.core.action.ActionListener;
4749
import org.opensearch.core.common.bytes.BytesArray;
4850
import org.opensearch.core.common.bytes.BytesReference;
4951
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
5052
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
5153
import org.opensearch.core.common.io.stream.StreamInput;
54+
import org.opensearch.search.fetch.FetchSearchResult;
55+
import org.opensearch.search.fetch.QueryFetchSearchResult;
56+
import org.opensearch.search.query.QuerySearchResult;
5257
import org.opensearch.tasks.TaskManager;
5358
import org.opensearch.telemetry.tracing.noop.NoopTracer;
5459
import org.opensearch.test.MockLogAppender;
@@ -248,6 +253,92 @@ public TestResponse read(byte[] in) throws IOException {
248253
}
249254
}
250255

256+
@SuppressForbidden(reason = "manipulates system properties for testing")
257+
public void testProtobufResponse() throws Exception {
258+
String action = "test-request";
259+
int headerSize = TcpHeader.headerSize(version);
260+
AtomicReference<TestRequest> requestCaptor = new AtomicReference<>();
261+
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
262+
AtomicReference<QueryFetchSearchResult> responseCaptor = new AtomicReference<>();
263+
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
264+
265+
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<QueryFetchSearchResult>() {
266+
@Override
267+
public void handleResponse(QueryFetchSearchResult response) {
268+
responseCaptor.set(response);
269+
}
270+
271+
@Override
272+
public void handleException(TransportException exp) {
273+
exceptionCaptor.set(exp);
274+
}
275+
276+
@Override
277+
public String executor() {
278+
return ThreadPool.Names.SAME;
279+
}
280+
281+
@Override
282+
public QueryFetchSearchResult read(StreamInput in) throws IOException {
283+
throw new UnsupportedOperationException("Unimplemented method 'read'");
284+
}
285+
286+
@Override
287+
public QueryFetchSearchResult read(byte[] in) throws IOException {
288+
return new QueryFetchSearchResult(in);
289+
}
290+
}, null, action));
291+
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
292+
action,
293+
TestRequest::new,
294+
taskManager,
295+
(request, channel, task) -> {
296+
channelCaptor.set(channel);
297+
requestCaptor.set(request);
298+
},
299+
ThreadPool.Names.SAME,
300+
false,
301+
true
302+
);
303+
requestHandlers.registerHandler(registry);
304+
String requestValue = randomAlphaOfLength(10);
305+
OutboundMessage.Request request = new OutboundMessage.Request(
306+
threadPool.getThreadContext(),
307+
new String[0],
308+
new TestRequest(requestValue),
309+
version,
310+
action,
311+
requestId,
312+
false,
313+
false
314+
);
315+
316+
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
317+
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
318+
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
319+
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
320+
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
321+
handler.inboundMessage(channel, requestMessage);
322+
323+
TransportChannel transportChannel = channelCaptor.get();
324+
assertEquals(Version.CURRENT, transportChannel.getVersion());
325+
assertEquals("transport", transportChannel.getChannelType());
326+
assertEquals(requestValue, requestCaptor.get().value);
327+
328+
QuerySearchResult queryResult = OutboundHandlerTests.createQuerySearchResult();
329+
FetchSearchResult fetchResult = OutboundHandlerTests.createFetchSearchResult();
330+
QueryFetchSearchResult response = new QueryFetchSearchResult(queryResult, fetchResult);
331+
System.setProperty(FeatureFlags.PROTOBUF, "true");
332+
transportChannel.sendResponse(response);
333+
334+
BytesReference fullResponseBytes = channel.getMessageCaptor().get();
335+
NodeToNodeMessage nodeToNodeMessage = new NodeToNodeMessage(fullResponseBytes.toBytesRef().bytes);
336+
handler.inboundMessage(channel, nodeToNodeMessage);
337+
QueryFetchSearchResult result = responseCaptor.get();
338+
assertNotNull(result);
339+
assertEquals(queryResult.getMaxScore(), result.queryResult().getMaxScore(), 0.0);
340+
}
341+
251342
public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exception {
252343
// Nodes use their minimum compatibility version for the TCP handshake, so a node from v(major-1).x will report its version as
253344
// v(major-2).last in the TCP handshake, with which we are not really compatible. We put extra effort into making sure that if

‎server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java

+108-4
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,45 @@
3434

3535
import org.opensearch.OpenSearchException;
3636
import org.opensearch.Version;
37+
import org.opensearch.action.OriginalIndices;
38+
import org.opensearch.action.OriginalIndicesTests;
39+
import org.opensearch.action.search.SearchRequest;
3740
import org.opensearch.cluster.node.DiscoveryNode;
41+
import org.opensearch.common.SuppressForbidden;
42+
import org.opensearch.common.UUIDs;
3843
import org.opensearch.common.bytes.ReleasableBytesReference;
3944
import org.opensearch.common.collect.Tuple;
4045
import org.opensearch.common.io.stream.BytesStreamOutput;
4146
import org.opensearch.common.unit.TimeValue;
4247
import org.opensearch.common.util.BigArrays;
48+
import org.opensearch.common.util.FeatureFlags;
4349
import org.opensearch.common.util.PageCacheRecycler;
4450
import org.opensearch.common.util.concurrent.ThreadContext;
4551
import org.opensearch.common.util.io.Streams;
4652
import org.opensearch.core.action.ActionListener;
53+
import org.opensearch.core.common.Strings;
4754
import org.opensearch.core.common.breaker.CircuitBreaker;
4855
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
4956
import org.opensearch.core.common.bytes.BytesArray;
5057
import org.opensearch.core.common.bytes.BytesReference;
5158
import org.opensearch.core.common.transport.TransportAddress;
59+
import org.opensearch.core.index.shard.ShardId;
5260
import org.opensearch.core.transport.TransportResponse;
61+
import org.opensearch.search.SearchShardTarget;
62+
import org.opensearch.search.fetch.FetchSearchResult;
63+
import org.opensearch.search.fetch.QueryFetchSearchResult;
64+
import org.opensearch.search.internal.AliasFilter;
65+
import org.opensearch.search.internal.ShardSearchContextId;
66+
import org.opensearch.search.internal.ShardSearchRequest;
67+
import org.opensearch.search.query.QuerySearchResult;
5368
import org.opensearch.test.OpenSearchTestCase;
5469
import org.opensearch.threadpool.TestThreadPool;
5570
import org.opensearch.threadpool.ThreadPool;
5671
import org.junit.After;
5772
import org.junit.Before;
5873

5974
import java.io.IOException;
75+
import java.nio.ByteBuffer;
6076
import java.nio.charset.StandardCharsets;
6177
import java.util.Collections;
6278
import java.util.concurrent.TimeUnit;
@@ -76,7 +92,7 @@ public class OutboundHandlerTests extends OpenSearchTestCase {
7692
private final TestThreadPool threadPool = new TestThreadPool(getClass().getName());
7793
private final TransportRequestOptions options = TransportRequestOptions.EMPTY;
7894
private final AtomicReference<Tuple<Header, BytesReference>> message = new AtomicReference<>();
79-
// private final AtomicReference<BytesReference> protobufMessage = new AtomicReference<>();
95+
private final AtomicReference<BytesReference> protobufMessage = new AtomicReference<>();
8096
private InboundPipeline pipeline;
8197
private OutboundHandler handler;
8298
private FakeTcpChannel channel;
@@ -104,9 +120,7 @@ public void setUp() throws Exception {
104120
} catch (IOException e) {
105121
throw new AssertionError(e);
106122
}
107-
}
108-
// , (c, m) -> { protobufMessage.set(m); }
109-
);
123+
});
110124
}
111125

112126
@After
@@ -266,6 +280,65 @@ public void onResponseSent(long requestId, String action, TransportResponse resp
266280
assertEquals("header_value", header.getHeaders().v1().get("header"));
267281
}
268282

283+
@SuppressForbidden(reason = "manipulates system properties for testing")
284+
public void testSendProtobufResponse() throws IOException {
285+
ThreadContext threadContext = threadPool.getThreadContext();
286+
Version version = Version.CURRENT;
287+
String action = "handshake";
288+
long requestId = randomLongBetween(0, 300);
289+
boolean isHandshake = randomBoolean();
290+
boolean compress = randomBoolean();
291+
threadContext.putHeader("header", "header_value");
292+
QuerySearchResult queryResult = createQuerySearchResult();
293+
FetchSearchResult fetchResult = createFetchSearchResult();
294+
QueryFetchSearchResult response = new QueryFetchSearchResult(queryResult, fetchResult);
295+
System.setProperty(FeatureFlags.PROTOBUF, "true");
296+
assertTrue(response.isMessageProtobuf());
297+
298+
AtomicLong requestIdRef = new AtomicLong();
299+
AtomicReference<String> actionRef = new AtomicReference<>();
300+
AtomicReference<TransportResponse> responseRef = new AtomicReference<>();
301+
handler.setMessageListener(new TransportMessageListener() {
302+
@Override
303+
public void onResponseSent(long requestId, String action, TransportResponse response) {
304+
requestIdRef.set(requestId);
305+
actionRef.set(action);
306+
responseRef.set(response);
307+
}
308+
});
309+
handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake);
310+
311+
StatsTracker statsTracker = new StatsTracker();
312+
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
313+
final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE);
314+
final Supplier<CircuitBreaker> breaker = () -> new NoopCircuitBreaker("test");
315+
final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate<String>) requestCanTripBreaker -> true);
316+
InboundPipeline inboundPipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> {
317+
NodeToNodeMessage m1 = (NodeToNodeMessage) m;
318+
protobufMessage.set(BytesReference.fromByteBuffer(ByteBuffer.wrap(m1.getMessage().toByteArray())));
319+
});
320+
BytesReference reference = channel.getMessageCaptor().get();
321+
ActionListener<Void> sendListener = channel.getListenerCaptor().get();
322+
if (randomBoolean()) {
323+
sendListener.onResponse(null);
324+
} else {
325+
sendListener.onFailure(new IOException("failed"));
326+
}
327+
assertEquals(requestId, requestIdRef.get());
328+
assertEquals(action, actionRef.get());
329+
assertEquals(response, responseRef.get());
330+
331+
inboundPipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {}));
332+
final BytesReference responseBytes = protobufMessage.get();
333+
final NodeToNodeMessage message = new NodeToNodeMessage(responseBytes.toBytesRef().bytes);
334+
assertEquals(version.toString(), message.getMessage().getVersion());
335+
assertEquals(requestId, message.getHeader().getRequestId());
336+
assertNotNull(message.getRequestHeaders());
337+
assertNotNull(message.getResponseHandlers());
338+
assertNotNull(message.getMessage());
339+
assertTrue(message.getMessage().hasQueryFetchSearchResult());
340+
}
341+
269342
public void testErrorResponse() throws IOException {
270343
ThreadContext threadContext = threadPool.getThreadContext();
271344
Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
@@ -317,4 +390,35 @@ public void onResponseSent(long requestId, String action, Exception error) {
317390

318391
assertEquals("header_value", header.getHeaders().v1().get("header"));
319392
}
393+
394+
public static QuerySearchResult createQuerySearchResult() {
395+
ShardId shardId = new ShardId("index", "uuid", randomInt());
396+
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean());
397+
ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
398+
OriginalIndicesTests.randomOriginalIndices(),
399+
searchRequest,
400+
shardId,
401+
1,
402+
new AliasFilter(null, Strings.EMPTY_ARRAY),
403+
1.0f,
404+
randomNonNegativeLong(),
405+
null,
406+
new String[0]
407+
);
408+
QuerySearchResult result = new QuerySearchResult(
409+
new ShardSearchContextId(UUIDs.base64UUID(), randomLong()),
410+
new SearchShardTarget("node", shardId, null, OriginalIndices.NONE),
411+
shardSearchRequest
412+
);
413+
return result;
414+
}
415+
416+
public static FetchSearchResult createFetchSearchResult() {
417+
ShardId shardId = new ShardId("index", "uuid", randomInt());
418+
FetchSearchResult result = new FetchSearchResult(
419+
new ShardSearchContextId(UUIDs.base64UUID(), randomLong()),
420+
new SearchShardTarget("node", shardId, null, OriginalIndices.NONE)
421+
);
422+
return result;
423+
}
320424
}

0 commit comments

Comments
 (0)
Please sign in to comment.