diff --git a/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/KVOperation.java b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/KVOperation.java
index b27252fa10..6afc6d6e94 100644
--- a/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/KVOperation.java
+++ b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/KVOperation.java
@@ -29,6 +29,8 @@
 
 import lombok.Data;
 
+import org.apache.hugegraph.pd.raft.serializer.HugegraphHessianSerializerFactory;
+
 @Data
 public class KVOperation {
 
@@ -84,6 +86,7 @@
     public static KVOperation fromByteArray(byte[] value) throws IOException {
         try (ByteArrayInputStream bis = new ByteArrayInputStream(value, 1, value.length - 1)) {
             Hessian2Input input = new Hessian2Input(bis);
+            input.setSerializerFactory(HugegraphHessianSerializerFactory.getInstance());
             KVOperation op = new KVOperation();
             op.op = value[0];
             op.key = input.readBytes();
diff --git a/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/RaftEngine.java b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/RaftEngine.java
index 67734d1456..3b5fd9f575 100644
--- a/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/RaftEngine.java
+++ b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/RaftEngine.java
@@ -19,6 +19,7 @@
 import java.io.File;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 
@@ -26,6 +27,11 @@
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+import com.alipay.remoting.ExtendedNettyChannelHandler;
+import com.alipay.remoting.config.BoltServerOption;
+import com.alipay.sofa.jraft.rpc.impl.BoltRpcServer;
+import io.netty.channel.ChannelHandler;
 
 import org.apache.hugegraph.pd.common.PDException;
 import org.apache.hugegraph.pd.config.PDConfig;
@@ -50,8 +56,8 @@
 import com.alipay.sofa.jraft.util.Endpoint;
 import com.alipay.sofa.jraft.util.ThreadId;
 import com.alipay.sofa.jraft.util.internal.ThrowUtil;
-
 import lombok.extern.slf4j.Slf4j;
+import org.apache.hugegraph.pd.raft.auth.IpAuthHandler;
 
 @Slf4j
 public class RaftEngine {
@@ -117,7 +123,7 @@ public boolean init(PDConfig.Raft config) {
         final PeerId serverId = JRaftUtils.getPeerId(config.getAddress());
 
-        rpcServer = createRaftRpcServer(config.getAddress());
+        rpcServer = createRaftRpcServer(config.getAddress(), initConf.getPeers());
 
         // construct raft group and start raft
         this.raftGroupService =
                 new RaftGroupService(groupId, serverId, nodeOptions, rpcServer, true);
@@ -130,14 +136,35 @@ public boolean init(PDConfig.Raft config) {
     /**
     * Create a Raft RPC Server for communication between PDs
     */
-    private RpcServer createRaftRpcServer(String raftAddr) {
+    private RpcServer createRaftRpcServer(String raftAddr, List<PeerId> peers) {
         Endpoint endpoint = JRaftUtils.getEndPoint(raftAddr);
         RpcServer rpcServer = RaftRpcServerFactory.createRaftRpcServer(endpoint);
+        configureRaftServerIpWhitelist(peers, rpcServer);
         RaftRpcProcessor.registerProcessor(rpcServer, this);
         rpcServer.init(null);
         return rpcServer;
     }
 
+    private static void configureRaftServerIpWhitelist(List<PeerId> peers, RpcServer rpcServer) {
+        if (rpcServer instanceof BoltRpcServer) {
+            ((BoltRpcServer) rpcServer).getServer().option(BoltServerOption.EXTENDED_NETTY_CHANNEL_HANDLER,
+                new ExtendedNettyChannelHandler() {
+                    @Override
+                    public List<ChannelHandler> frontChannelHandlers() {
+                        return Collections.singletonList(
+                                IpAuthHandler.getInstance(
+                                        peers.stream()
+                                             .map(PeerId::getIp)
+                                             .collect(Collectors.toSet())));
+                    }
+                    @Override
+                    public List<ChannelHandler> backChannelHandlers() {
+                        return Collections.emptyList();
+                    }
+                });
+        }
+    }
+
     public void shutDown() {
         if (this.raftGroupService != null) {
             this.raftGroupService.shutdown();
diff --git a/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/auth/IpAuthHandler.java b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/auth/IpAuthHandler.java
new file mode 100644
index 0000000000..b9b5d839ac
--- /dev/null
+++ b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/auth/IpAuthHandler.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hugegraph.pd.raft.auth;
+
+import java.net.InetSocketAddress;
+import java.util.Collections;
+import java.util.Set;
+
+import io.netty.channel.ChannelDuplexHandler;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import lombok.extern.slf4j.Slf4j;
+
+@Slf4j
+@ChannelHandler.Sharable
+public class IpAuthHandler extends ChannelDuplexHandler {
+
+    private final Set<String> allowedIps;
+    private static volatile IpAuthHandler instance;
+
+    private IpAuthHandler(Set<String> allowedIps) {
+        this.allowedIps = Collections.unmodifiableSet(allowedIps);
+    }
+
+    public static IpAuthHandler getInstance(Set<String> allowedIps) {
+        if (instance == null) {
+            synchronized (IpAuthHandler.class) {
+                if (instance == null) {
+                    instance = new IpAuthHandler(allowedIps);
+                }
+            }
+        }
+        return instance;
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        String clientIp = getClientIp(ctx);
+        if (!isIpAllowed(clientIp)) {
+            log.warn("Blocked connection from {}", clientIp);
+            ctx.close();
+            return;
+        }
+        super.channelActive(ctx);
+    }
+
+    private static String getClientIp(ChannelHandlerContext ctx) {
+        InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
+        return remoteAddress.getAddress().getHostAddress();
+    }
+
+    private boolean isIpAllowed(String ip) {
+        return allowedIps.isEmpty() || allowedIps.contains(ip);
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        String clientIp = getClientIp(ctx);
+        log.warn("client: {} connection exception: {}", clientIp, cause);
+        if (ctx.channel().isActive()) {
+            ctx.close().addListener(future -> {
+                if (!future.isSuccess()) {
+                    log.warn("client: {} connection close failed: {}", clientIp, future.cause().getMessage());
+                }
+            });
+        }
+    }
+
+}
diff --git a/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/serializer/HugegraphHessianSerializerFactory.java b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/serializer/HugegraphHessianSerializerFactory.java
new file mode 100644
index 0000000000..275159a50e
--- /dev/null
+++ b/hugegraph-pd/hg-pd-core/src/main/java/org/apache/hugegraph/pd/raft/serializer/HugegraphHessianSerializerFactory.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.hugegraph.pd.raft.serializer;
+
+import com.caucho.hessian.io.Deserializer;
+import com.caucho.hessian.io.HessianProtocolException;
+import com.caucho.hessian.io.Serializer;
+import com.caucho.hessian.io.SerializerFactory;
+
+
+import lombok.extern.slf4j.Slf4j;
+
+import java.text.SimpleDateFormat;
+import java.time.format.DateTimeFormatter;
+
+import java.util.ArrayList;
+import java.util.Calendar;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+
+@Slf4j
+public class HugegraphHessianSerializerFactory extends SerializerFactory {
+
+    private static final HugegraphHessianSerializerFactory INSTANCE = new HugegraphHessianSerializerFactory();
+
+    private HugegraphHessianSerializerFactory() {
+        super();
+        initWhitelist();
+    }
+
+    public static HugegraphHessianSerializerFactory getInstance() {
+        return INSTANCE;
+    }
+
+    private final Set<String> whitelist = new HashSet<>();
+
+    private void initWhitelist() {
+        allowBasicType();
+        allowCollections();
+        allowConcurrent();
+        allowTime();
+        allowBusinessClasses();
+    }
+
+    private void allowBasicType() {
+        addToWhitelist(
+                boolean.class, byte.class, char.class, double.class,
+                float.class, int.class, long.class, short.class,
+                Boolean.class, Byte.class, Character.class, Double.class,
+                Float.class, Integer.class, Long.class, Short.class,
+                String.class, Class.class, Number.class
+        );
+    }
+
+    private void allowCollections() {
+        addToWhitelist(
+                List.class, ArrayList.class, LinkedList.class,
+                Set.class, HashSet.class, LinkedHashSet.class, TreeSet.class,
+                Map.class, HashMap.class, LinkedHashMap.class, TreeMap.class
+        );
+    }
+
+    private void allowConcurrent() {
+        addToWhitelist(
+                AtomicBoolean.class, AtomicInteger.class, AtomicLong.class, AtomicReference.class,
+                ConcurrentMap.class, ConcurrentHashMap.class, ConcurrentSkipListMap.class, CopyOnWriteArrayList.class
+        );
+    }
+
+    private void allowTime() {
+        addToWhitelist(
+                Date.class, Calendar.class, TimeUnit.class,
+                SimpleDateFormat.class, DateTimeFormatter.class
+        );
+        tryAddClass("java.time.LocalDate");
+        tryAddClass("java.time.LocalDateTime");
+        tryAddClass("java.time.Instant");
+    }
+
+    private void allowBusinessClasses() {
+        addToWhitelist(
+                org.apache.hugegraph.pd.raft.KVOperation.class,
+                byte[].class
+        );
+    }
+
+    private void addToWhitelist(Class<?>... classes) {
+        for (Class<?> clazz : classes) {
+            whitelist.add(clazz.getName());
+        }
+    }
+
+    private void tryAddClass(String className) {
+        try {
+            Class.forName(className);
+            whitelist.add(className);
+        } catch (ClassNotFoundException e) {
+            log.warn("Failed to load class {}", className);
+        }
+    }
+
+    @Override
+    public Serializer getSerializer(Class cl) throws HessianProtocolException {
+        checkWhitelist(cl);
+        return super.getSerializer(cl);
+    }
+
+    @Override
+    public Deserializer getDeserializer(Class cl) throws HessianProtocolException {
+        checkWhitelist(cl);
+        return super.getDeserializer(cl);
+    }
+
+    private void checkWhitelist(Class cl) {
+        String className = cl.getName();
+        if (!whitelist.contains(className)) {
+            log.warn("Security alert: Blocked unauthorized class [{}] at {}",
+                     className, new Date());
+            throw new SecurityException("hessian serialize unauthorized class: " + className);
+        }
+    }
+}
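
Reviewer note: the sketch below is hypothetical and not part of the patch (the HessianWhitelistDemo class name is invented). It only calls the getSerializer()/getDeserializer() overrides added above, to show the intended behaviour: whitelisted classes resolve through the stock Hessian SerializerFactory, while any other class is rejected with a SecurityException before a (de)serializer is ever built, which is the guarantee that KVOperation.fromByteArray() now relies on via input.setSerializerFactory(...).

package org.apache.hugegraph.pd.raft.serializer;

// Hypothetical demo class, not included in this patch.
public class HessianWhitelistDemo {

    public static void main(String[] args) throws Exception {
        HugegraphHessianSerializerFactory factory = HugegraphHessianSerializerFactory.getInstance();

        // HashMap is whitelisted, so the lookup falls through to the parent SerializerFactory.
        System.out.println(factory.getSerializer(java.util.HashMap.class));

        // UUID is not whitelisted: checkWhitelist() logs a security alert and throws.
        try {
            factory.getDeserializer(java.util.UUID.class);
        } catch (SecurityException expected) {
            System.out.println("blocked: " + expected.getMessage());
        }
    }
}

A related caveat for IpAuthHandler: getInstance() creates a singleton lazily and ignores the allowedIps argument on later calls, so the first peer set wins; and because isIpAllowed() returns true for an empty set, an empty peer list disables the IP check entirely.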