Skip to content

Commit 2ddf5ae

Browse files
committed
Create a custom thrift protocol to ensure client/server compatibility
1 parent f99847c commit 2ddf5ae

File tree

4 files changed

+169
-72
lines changed

4 files changed

+169
-72
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* https://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.accumulo.core.rpc;
20+
21+
import org.apache.accumulo.core.trace.TraceUtil;
22+
import org.apache.thrift.TException;
23+
import org.apache.thrift.protocol.TCompactProtocol;
24+
import org.apache.thrift.protocol.TMessage;
25+
import org.apache.thrift.protocol.TProtocol;
26+
import org.apache.thrift.transport.TTransport;
27+
28+
import io.opentelemetry.api.trace.Span;
29+
import io.opentelemetry.context.Scope;
30+
31+
public class AccumuloProtocolFactory extends TCompactProtocol.Factory {
32+
33+
private static final long serialVersionUID = 1L;
34+
35+
private final boolean isClient;
36+
37+
public static class AccumuloProtocol extends TCompactProtocol {
38+
39+
private static final int MAGIC_NUMBER = 0x41434355; // "ACCU" in ASCII
40+
private static final byte PROTOCOL_VERSION = 1;
41+
42+
private final boolean isClient;
43+
44+
private Span span = null;
45+
private Scope scope = null;
46+
47+
public AccumuloProtocol(TTransport transport, boolean isClient) {
48+
super(transport);
49+
this.isClient = isClient;
50+
}
51+
52+
@Override
53+
public void writeMessageBegin(TMessage message) throws TException {
54+
if (isClient) {
55+
span = TraceUtil.startClientRpcSpan(this.getClass(), message.name);
56+
scope = span.makeCurrent();
57+
58+
try {
59+
this.writeHeader();
60+
} catch (TException e) {
61+
if (scope != null) {
62+
scope.close();
63+
}
64+
if (span != null) {
65+
span.end();
66+
}
67+
throw e;
68+
}
69+
}
70+
71+
super.writeMessageBegin(message);
72+
}
73+
74+
/**
75+
* Writes the Accumulo protocol header containing version and identification info
76+
*/
77+
private void writeHeader() throws TException {
78+
super.writeI32(MAGIC_NUMBER);
79+
super.writeByte(PROTOCOL_VERSION);
80+
}
81+
82+
@Override
83+
public TMessage readMessageBegin() throws TException {
84+
if (!isClient) {
85+
this.validateHeader();
86+
}
87+
88+
return super.readMessageBegin();
89+
}
90+
91+
/**
92+
* Checks if the given version is compatible with the current protocol version
93+
*/
94+
private boolean isCompatibleVersion(byte version) {
95+
return version == PROTOCOL_VERSION;
96+
}
97+
98+
/**
99+
* Reads and validates the Accumulo protocol header
100+
*
101+
* @throws TException if the header is invalid or incompatible
102+
*/
103+
private void validateHeader() throws TException {
104+
final int magic = super.readI32();
105+
if (magic != MAGIC_NUMBER) {
106+
throw new TException("Invalid Accumulo protocol: magic number mismatch. " + "Expected: 0x"
107+
+ Integer.toHexString(MAGIC_NUMBER) + ", got: 0x" + Integer.toHexString(magic));
108+
}
109+
110+
final byte version = super.readByte();
111+
if (!isCompatibleVersion(version)) {
112+
throw new TException("Incompatible protocol version. Client version: " + version
113+
+ ", Server version: " + PROTOCOL_VERSION);
114+
}
115+
}
116+
117+
@Override
118+
public void writeMessageEnd() throws TException {
119+
super.writeMessageEnd();
120+
121+
if (this.isClient && scope != null) {
122+
scope.close();
123+
span.end();
124+
}
125+
}
126+
}
127+
128+
@Override
129+
public TProtocol getProtocol(TTransport trans) {
130+
return new AccumuloProtocol(trans, isClient);
131+
}
132+
133+
/**
134+
* Creates a factory for producing AccumuloProtocol instances
135+
*
136+
* @param isClient true if this factory produces protocols for the client side, false for the
137+
* server side
138+
*/
139+
private AccumuloProtocolFactory(boolean isClient) {
140+
this.isClient = isClient;
141+
}
142+
143+
/**
144+
* Creates a client-side factory for use in clients making RPC calls
145+
*/
146+
public static AccumuloProtocolFactory clientFactory() {
147+
return new AccumuloProtocolFactory(true);
148+
}
149+
150+
/**
151+
* Creates a server-side factory for use in servers receiving RPC calls
152+
*/
153+
public static AccumuloProtocolFactory serverFactory() {
154+
return new AccumuloProtocolFactory(false);
155+
}
156+
}

core/src/main/java/org/apache/accumulo/core/rpc/ThriftUtil.java

+12-5
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ public class ThriftUtil {
5353

5454
private static final Logger log = LoggerFactory.getLogger(ThriftUtil.class);
5555

56-
private static final TraceProtocolFactory protocolFactory = new TraceProtocolFactory();
56+
private static final AccumuloProtocolFactory clientAccumuloProtocolFactory =
57+
AccumuloProtocolFactory.clientFactory();
58+
private static final AccumuloProtocolFactory serverAccumuloProtocolFactory =
59+
AccumuloProtocolFactory.serverFactory();
5760
private static final AccumuloTFramedTransportFactory transportFactory =
5861
new AccumuloTFramedTransportFactory(Integer.MAX_VALUE);
5962
private static final Map<Integer,TTransportFactory> factoryCache = new HashMap<>();
@@ -63,12 +66,16 @@ public class ThriftUtil {
6366
private static final int RELOGIN_MAX_BACKOFF = 5000;
6467

6568
/**
66-
* An instance of {@link TraceProtocolFactory}
69+
* An instance of {@link AccumuloProtocolFactory}
6770
*
6871
* @return The default Thrift TProtocolFactory for RPC
6972
*/
70-
public static TProtocolFactory protocolFactory() {
71-
return protocolFactory;
73+
public static TProtocolFactory clientProtocolFactory() {
74+
return clientAccumuloProtocolFactory;
75+
}
76+
77+
public static TProtocolFactory serverProtocolFactory() {
78+
return serverAccumuloProtocolFactory;
7279
}
7380

7481
/**
@@ -85,7 +92,7 @@ public static TTransportFactory transportFactory() {
8592
*/
8693
public static <T extends TServiceClient> T createClient(ThriftClientTypes<T> type,
8794
TTransport transport) {
88-
return type.getClient(protocolFactory.getProtocol(transport));
95+
return type.getClient(clientProtocolFactory().getProtocol(transport));
8996
}
9097

9198
/**

core/src/main/java/org/apache/accumulo/core/rpc/TraceProtocolFactory.java

-66
This file was deleted.

server/base/src/main/java/org/apache/accumulo/server/rpc/TServerUtils.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ private static ServerAddress startTServer(ThriftServerType serverType, TimedProc
595595
SslConnectionParams sslParams, SaslServerConnectionParams saslParams,
596596
long serverSocketTimeout, int backlog, boolean portSearch, HostAndPort... addresses)
597597
throws TTransportException {
598-
TProtocolFactory protocolFactory = ThriftUtil.protocolFactory();
598+
TProtocolFactory protocolFactory = ThriftUtil.serverProtocolFactory();
599599
// This is presently not supported. It's hypothetically possible, I believe, to work, but it
600600
// would require changes in how the transports
601601
// work at the Thrift layer to ensure that both the SSL and SASL handshakes function. SASL's

0 commit comments

Comments
 (0)