Skip to content

Commit 5663b4a

Browse files
authored
Change abstraction point for transport protocol (opensearch-project#15432)
* Revert "Replacing InboundMessage with NativeInboundMessage for deprecation (opensearch-project#13126)" This reverts commit f5c3ef9. Signed-off-by: Andrew Ross <andrross@amazon.com> * Change abstraction point for transport protocol The previous implementation had a transport switch point in InboundPipeline when the bytes were initially pulled off the wire. There was no implementation for any other protocol as the `canHandleBytes` method was hardcoded to return true. I believe this is the wrong point to switch on the protocol. This change makes NativeInboundBytesHandler protocol agnostic beyond the header. With this change, a complete message is parsed from the stream of bytes, with the header schema being unchanged from what exists today. The protocol switch point will now be at `InboundHandler::inboundMessage`. The header will indicate what protocol was used to serialize the the non-header bytes of the message and then invoke the appropriate handler based on that field. Signed-off-by: Andrew Ross <andrross@amazon.com> --------- Signed-off-by: Andrew Ross <andrross@amazon.com>
1 parent acee2ae commit 5663b4a

18 files changed

+467
-402
lines changed

server/src/main/java/org/opensearch/transport/Header.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public class Header {
5555

5656
private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES";
5757

58+
private final TransportProtocol protocol;
5859
private final int networkMessageSize;
5960
private final Version version;
6061
private final long requestId;
@@ -64,13 +65,18 @@ public class Header {
6465
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
6566
Set<String> features;
6667

67-
Header(int networkMessageSize, long requestId, byte status, Version version) {
68+
Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) {
69+
this.protocol = protocol;
6870
this.networkMessageSize = networkMessageSize;
6971
this.version = version;
7072
this.requestId = requestId;
7173
this.status = status;
7274
}
7375

76+
TransportProtocol getTransportProtocol() {
77+
return protocol;
78+
}
79+
7480
public int getNetworkMessageSize() {
7581
return networkMessageSize;
7682
}
@@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException {
142148
@Override
143149
public String toString() {
144150
return "Header{"
151+
+ protocol
152+
+ "}{"
145153
+ networkMessageSize
146154
+ "}{"
147155
+ version

server/src/main/java/org/opensearch/transport/InboundAggregator.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.opensearch.core.common.bytes.BytesArray;
4141
import org.opensearch.core.common.bytes.BytesReference;
4242
import org.opensearch.core.common.bytes.CompositeBytesReference;
43-
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
4443

4544
import java.io.IOException;
4645
import java.util.ArrayList;
@@ -114,7 +113,7 @@ public void aggregate(ReleasableBytesReference content) {
114113
}
115114
}
116115

117-
public NativeInboundMessage finishAggregation() throws IOException {
116+
public InboundMessage finishAggregation() throws IOException {
118117
ensureOpen();
119118
final ReleasableBytesReference releasableContent;
120119
if (isFirstContent()) {
@@ -128,7 +127,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
128127
}
129128

130129
final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
131-
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
130+
final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl);
132131
boolean success = false;
133132
try {
134133
if (aggregated.getHeader().needsToReadVariableHeader()) {
@@ -143,7 +142,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
143142
if (isShortCircuited()) {
144143
aggregated.close();
145144
success = true;
146-
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
145+
return new InboundMessage(aggregated.getHeader(), aggregationException);
147146
} else {
148147
success = true;
149148
return aggregated;

server/src/main/java/org/opensearch/transport/InboundBytesHandler.java

+126-11
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,139 @@
99
package org.opensearch.transport;
1010

1111
import org.opensearch.common.bytes.ReleasableBytesReference;
12+
import org.opensearch.common.lease.Releasable;
13+
import org.opensearch.common.lease.Releasables;
14+
import org.opensearch.core.common.bytes.CompositeBytesReference;
1215

13-
import java.io.Closeable;
1416
import java.io.IOException;
17+
import java.util.ArrayDeque;
18+
import java.util.ArrayList;
1519
import java.util.function.BiConsumer;
1620

1721
/**
18-
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
22+
* Handler for inbound bytes, using {@link InboundDecoder} to decode headers
23+
* and {@link InboundAggregator} to assemble complete messages to forward to
24+
* the given message handler to parse the message payload.
1925
*/
20-
public interface InboundBytesHandler extends Closeable {
26+
class InboundBytesHandler {
2127

22-
public void doHandleBytes(
23-
TcpChannel channel,
24-
ReleasableBytesReference reference,
25-
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
26-
) throws IOException;
28+
private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);
2729

28-
public boolean canHandleBytes(ReleasableBytesReference reference);
30+
private final ArrayDeque<ReleasableBytesReference> pending;
31+
private final InboundDecoder decoder;
32+
private final InboundAggregator aggregator;
33+
private final StatsTracker statsTracker;
34+
private boolean isClosed = false;
35+
36+
InboundBytesHandler(
37+
ArrayDeque<ReleasableBytesReference> pending,
38+
InboundDecoder decoder,
39+
InboundAggregator aggregator,
40+
StatsTracker statsTracker
41+
) {
42+
this.pending = pending;
43+
this.decoder = decoder;
44+
this.aggregator = aggregator;
45+
this.statsTracker = statsTracker;
46+
}
47+
48+
public void close() {
49+
isClosed = true;
50+
}
51+
52+
public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference, BiConsumer<TcpChannel, InboundMessage> messageHandler)
53+
throws IOException {
54+
final ArrayList<Object> fragments = fragmentList.get();
55+
boolean continueHandling = true;
56+
57+
while (continueHandling && isClosed == false) {
58+
boolean continueDecoding = true;
59+
while (continueDecoding && pending.isEmpty() == false) {
60+
try (ReleasableBytesReference toDecode = getPendingBytes()) {
61+
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
62+
if (bytesDecoded != 0) {
63+
releasePendingBytes(bytesDecoded);
64+
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
65+
continueDecoding = false;
66+
}
67+
} else {
68+
continueDecoding = false;
69+
}
70+
}
71+
}
72+
73+
if (fragments.isEmpty()) {
74+
continueHandling = false;
75+
} else {
76+
try {
77+
forwardFragments(channel, fragments, messageHandler);
78+
} finally {
79+
for (Object fragment : fragments) {
80+
if (fragment instanceof ReleasableBytesReference) {
81+
((ReleasableBytesReference) fragment).close();
82+
}
83+
}
84+
fragments.clear();
85+
}
86+
}
87+
}
88+
}
89+
90+
private ReleasableBytesReference getPendingBytes() {
91+
if (pending.size() == 1) {
92+
return pending.peekFirst().retain();
93+
} else {
94+
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
95+
int index = 0;
96+
for (ReleasableBytesReference pendingReference : pending) {
97+
bytesReferences[index] = pendingReference.retain();
98+
++index;
99+
}
100+
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
101+
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
102+
}
103+
}
104+
105+
private void releasePendingBytes(int bytesConsumed) {
106+
int bytesToRelease = bytesConsumed;
107+
while (bytesToRelease != 0) {
108+
try (ReleasableBytesReference reference = pending.pollFirst()) {
109+
assert reference != null;
110+
if (bytesToRelease < reference.length()) {
111+
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
112+
bytesToRelease -= bytesToRelease;
113+
} else {
114+
bytesToRelease -= reference.length();
115+
}
116+
}
117+
}
118+
}
119+
120+
private boolean endOfMessage(Object fragment) {
121+
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
122+
}
123+
124+
private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments, BiConsumer<TcpChannel, InboundMessage> messageHandler)
125+
throws IOException {
126+
for (Object fragment : fragments) {
127+
if (fragment instanceof Header) {
128+
assert aggregator.isAggregating() == false;
129+
aggregator.headerReceived((Header) fragment);
130+
} else if (fragment == InboundDecoder.PING) {
131+
assert aggregator.isAggregating() == false;
132+
messageHandler.accept(channel, InboundMessage.PING);
133+
} else if (fragment == InboundDecoder.END_CONTENT) {
134+
assert aggregator.isAggregating();
135+
try (InboundMessage aggregated = aggregator.finishAggregation()) {
136+
statsTracker.markMessageReceived();
137+
messageHandler.accept(channel, aggregated);
138+
}
139+
} else {
140+
assert aggregator.isAggregating();
141+
assert fragment instanceof ReleasableBytesReference;
142+
aggregator.aggregate((ReleasableBytesReference) fragment);
143+
}
144+
}
145+
}
29146

30-
@Override
31-
void close();
32147
}

server/src/main/java/org/opensearch/transport/InboundDecoder.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) {
187187
// exposed for use in tests
188188
static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException {
189189
try (StreamInput streamInput = bytesReference.streamInput()) {
190-
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
190+
TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte());
191+
streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE);
191192
long requestId = streamInput.readLong();
192193
byte status = streamInput.readByte();
193194
Version remoteVersion = Version.fromId(streamInput.readInt());
194-
Header header = new Header(networkMessageSize, requestId, status, remoteVersion);
195+
Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion);
195196
final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake());
196197
if (invalidVersion != null) {
197198
throw invalidVersion;

server/src/main/java/org/opensearch/transport/InboundHandler.java

+6-7
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
3939
import org.opensearch.telemetry.tracing.Tracer;
4040
import org.opensearch.threadpool.ThreadPool;
41-
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
4241

4342
import java.io.IOException;
4443
import java.util.Map;
@@ -56,7 +55,7 @@ public class InboundHandler {
5655

5756
private volatile long slowLogThresholdMs = Long.MAX_VALUE;
5857

59-
private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;
58+
private final Map<TransportProtocol, ProtocolMessageHandler> protocolMessageHandlers;
6059

6160
InboundHandler(
6261
String nodeName,
@@ -75,7 +74,7 @@ public class InboundHandler {
7574
) {
7675
this.threadPool = threadPool;
7776
this.protocolMessageHandlers = Map.of(
78-
NativeInboundMessage.NATIVE_PROTOCOL,
77+
TransportProtocol.NATIVE,
7978
new NativeMessageHandler(
8079
nodeName,
8180
version,
@@ -107,16 +106,16 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) {
107106
this.slowLogThresholdMs = slowLogThreshold.getMillis();
108107
}
109108

110-
void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception {
109+
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
111110
final long startTime = threadPool.relativeTimeInMillis();
112111
channel.getChannelStats().markAccessed(startTime);
113112
messageReceivedFromPipeline(channel, message, startTime);
114113
}
115114

116-
private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException {
117-
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol());
115+
private void messageReceivedFromPipeline(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
116+
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol());
118117
if (protocolMessageHandler == null) {
119-
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol());
118+
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol());
120119
}
121120
protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener);
122121
}

0 commit comments

Comments
 (0)