Skip to content

Commit

Permalink
[controller] Enforce ACL checks on gRPC createStore API (#1443)
Browse files Browse the repository at this point in the history
- Use access controller manager to enforce ACL checks in gRPC createStore API.
- Add support for running E2E tests with a custom access controller.
- Introduce `GrpcControllerClientDetails` to encapsulate gRPC session details for 
  authorization checks.
  • Loading branch information
sushantmane authored Jan 17, 2025
1 parent d090edd commit fbbd13b
Show file tree
Hide file tree
Showing 14 changed files with 243 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class NoOpDynamicAccessController implements DynamicAccessController {

public static final NoOpDynamicAccessController INSTANCE = new NoOpDynamicAccessController();

private NoOpDynamicAccessController() {
protected NoOpDynamicAccessController() {
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public static ClusterStoreGrpcInfo getClusterStoreGrpcInfo(ControllerResponse re
*
* @param code The gRPC status code representing the error (e.g., {@link io.grpc.Status.Code}).
* @param errorType The specific controller error type represented by {@link ControllerGrpcErrorType}.
* @param e The exception containing the error message.
* @param errorMessage The error message to be included in the response.
* @param clusterName The name of the cluster associated with the error (can be null).
* @param storeName The name of the store associated with the error (can be null).
* @param responseObserver The {@link StreamObserver} to send the error response back to the client.
Expand Down Expand Up @@ -62,14 +62,14 @@ public static ClusterStoreGrpcInfo getClusterStoreGrpcInfo(ControllerResponse re
public static void sendErrorResponse(
Code code,
ControllerGrpcErrorType errorType,
Exception e,
String errorMessage,
String clusterName,
String storeName,
StreamObserver<?> responseObserver) {
VeniceControllerGrpcErrorInfo.Builder errorInfoBuilder =
VeniceControllerGrpcErrorInfo.newBuilder().setStatusCode(code.value()).setErrorType(errorType);
if (e.getMessage() != null) {
errorInfoBuilder.setErrorMessage(e.getMessage());
if (errorMessage != null) {
errorInfoBuilder.setErrorMessage(errorMessage);
}
if (clusterName != null) {
errorInfoBuilder.setClusterName(clusterName);
Expand All @@ -85,6 +85,22 @@ public static void sendErrorResponse(
responseObserver.onError(StatusProto.toStatusRuntimeException(status));
}

public static void sendErrorResponse(
Code code,
ControllerGrpcErrorType errorType,
Exception exception,
String clusterName,
String storeName,
StreamObserver<?> responseObserver) {
sendErrorResponse(
code,
errorType,
exception != null ? exception.getMessage() : "",
clusterName,
storeName,
responseObserver);
}

/**
* Parses a {@link StatusRuntimeException} to extract a {@link VeniceControllerGrpcErrorInfo} object.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ enum ControllerGrpcErrorType {
BAD_REQUEST = 8;
CONCURRENT_BATCH_PUSH = 9;
RESOURCE_STILL_EXISTS = 10;
UNAUTHORIZED = 11;
}

message VeniceControllerGrpcErrorInfo {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.linkedin.venice.controller;

import static com.linkedin.venice.controller.server.grpc.ControllerGrpcSslSessionInterceptor.CLIENT_CERTIFICATE_CONTEXT_KEY;
import static com.linkedin.venice.controller.server.grpc.ControllerGrpcSslSessionInterceptor.GRPC_CONTROLLER_CLIENT_DETAILS;
import static org.testng.Assert.assertEquals;

import com.linkedin.venice.controller.server.grpc.ControllerGrpcSslSessionInterceptor;
import com.linkedin.venice.controller.server.grpc.GrpcControllerClientDetails;
import com.linkedin.venice.grpc.GrpcUtils;
import com.linkedin.venice.grpc.VeniceGrpcServer;
import com.linkedin.venice.grpc.VeniceGrpcServerConfig;
Expand All @@ -18,7 +19,6 @@
import io.grpc.Context;
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import java.security.cert.X509Certificate;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -57,10 +57,13 @@ public static class VeniceControllerGrpcSecureServiceTestImpl
public void discoverClusterForStore(
DiscoverClusterGrpcRequest request,
io.grpc.stub.StreamObserver<DiscoverClusterGrpcResponse> responseObserver) {
X509Certificate clientCert = CLIENT_CERTIFICATE_CONTEXT_KEY.get(Context.current());
if (clientCert == null) {
GrpcControllerClientDetails clientDetails = GRPC_CONTROLLER_CLIENT_DETAILS.get(Context.current());
if (clientDetails.getClientCertificate() == null) {
throw new RuntimeException("Client cert is null");
}
if (clientDetails.getClientAddress() == null) {
throw new RuntimeException("Client address is null");
}
DiscoverClusterGrpcResponse discoverClusterGrpcResponse =
DiscoverClusterGrpcResponse.newBuilder().setClusterName("test-cluster").build();
responseObserver.onNext(discoverClusterGrpcResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import static com.linkedin.venice.ConfigKeys.CONTROLLER_GRPC_SERVER_ENABLED;
import static com.linkedin.venice.integration.utils.VeniceClusterWrapper.DEFAULT_KEY_SCHEMA;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;

import com.linkedin.venice.acl.NoOpDynamicAccessController;
import com.linkedin.venice.authorization.Method;
import com.linkedin.venice.controllerapi.StoreResponse;
import com.linkedin.venice.grpc.GrpcUtils;
import com.linkedin.venice.integration.utils.ServiceFactory;
import com.linkedin.venice.integration.utils.VeniceClusterCreateOptions;
import com.linkedin.venice.integration.utils.VeniceClusterWrapper;
Expand All @@ -18,27 +22,40 @@
import com.linkedin.venice.protocols.controller.LeaderControllerGrpcResponse;
import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc;
import com.linkedin.venice.protocols.controller.VeniceControllerGrpcServiceGrpc.VeniceControllerGrpcServiceBlockingStub;
import com.linkedin.venice.security.SSLFactory;
import com.linkedin.venice.utils.SslUtils;
import com.linkedin.venice.utils.TestUtils;
import com.linkedin.venice.utils.Utils;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import java.security.cert.X509Certificate;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;


public class TestControllerGrpcEndpoints {
private VeniceClusterWrapper veniceCluster;
private SSLFactory sslFactory;
private MockDynamicAccessController mockDynamicAccessController;

@BeforeClass(alwaysRun = true)
public void setUp() {
mockDynamicAccessController = new MockDynamicAccessController();
sslFactory = SslUtils.getVeniceLocalSslFactory();
Properties properties = new Properties();
properties.put(CONTROLLER_GRPC_SERVER_ENABLED, true);
VeniceClusterCreateOptions options = new VeniceClusterCreateOptions.Builder().numberOfControllers(1)
.numberOfRouters(1)
.numberOfServers(1)
.accessController(mockDynamicAccessController)
.extraProperties(properties)
.build();
veniceCluster = ServiceFactory.getVeniceCluster(options);
Expand Down Expand Up @@ -97,4 +114,63 @@ public void testGrpcEndpointsWithGrpcClient() {
assertEquals(discoverClusterGrpcResponse.getStoreName(), storeName);
assertEquals(discoverClusterGrpcResponse.getClusterName(), veniceCluster.getClusterName());
}

@Test
public void testCreateStoreOverSecureGrpcChannel() {
String storeName = Utils.getUniqueString("test_grpc_store");
String controllerSecureGrpcUrl = veniceCluster.getLeaderVeniceController().getControllerSecureGrpcUrl();
ChannelCredentials credentials = GrpcUtils.buildChannelCredentials(sslFactory);
ManagedChannel channel = Grpc.newChannelBuilder(controllerSecureGrpcUrl, credentials).build();
VeniceControllerGrpcServiceBlockingStub blockingStub = VeniceControllerGrpcServiceGrpc.newBlockingStub(channel);

CreateStoreGrpcRequest createStoreGrpcRequest = CreateStoreGrpcRequest.newBuilder()
.setClusterStoreInfo(
ClusterStoreGrpcInfo.newBuilder()
.setClusterName(veniceCluster.getClusterName())
.setStoreName(storeName)
.build())
.setOwner("owner")
.setKeySchema(DEFAULT_KEY_SCHEMA)
.setValueSchema("\"string\"")
.build();

// Case 1: User not in allowlist for the resource
mockDynamicAccessController.removeResourceFromAllowList(storeName);
assertFalse(
mockDynamicAccessController.isAllowlistUsers(null, storeName, Method.GET.name()),
"User should not be in allowlist");
StatusRuntimeException exception =
Assert.expectThrows(StatusRuntimeException.class, () -> blockingStub.createStore(createStoreGrpcRequest));
assertEquals(exception.getStatus().getCode(), io.grpc.Status.Code.PERMISSION_DENIED);

// Case 2: Allowlist user
mockDynamicAccessController.addResourceToAllowList(storeName);
CreateStoreGrpcResponse okResponse = blockingStub.createStore(createStoreGrpcRequest);
assertNotNull(okResponse, "Response should not be null");
assertNotNull(okResponse.getClusterStoreInfo(), "ClusterStoreInfo should not be null");
assertEquals(okResponse.getClusterStoreInfo().getClusterName(), veniceCluster.getClusterName());
assertEquals(okResponse.getClusterStoreInfo().getStoreName(), storeName);

veniceCluster.useControllerClient(controllerClient -> {
StoreResponse storeResponse = TestUtils.assertCommand(controllerClient.getStore(storeName));
assertNotNull(storeResponse.getStore(), "Store should not be null");
});
}

private static class MockDynamicAccessController extends NoOpDynamicAccessController {
private final Set<String> resourcesInAllowList = ConcurrentHashMap.newKeySet();

@Override
public boolean isAllowlistUsers(X509Certificate clientCert, String resource, String method) {
return resourcesInAllowList.contains(resource);
}

public void addResourceToAllowList(String resource) {
resourcesInAllowList.add(resource);
}

public void removeResourceFromAllowList(String resource) {
resourcesInAllowList.remove(resource);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static com.linkedin.venice.integration.utils.VeniceClusterWrapperConstants.DEFAULT_SSL_TO_STORAGE_NODES;
import static com.linkedin.venice.integration.utils.VeniceClusterWrapperConstants.STANDALONE_REGION_NAME;

import com.linkedin.venice.acl.DynamicAccessController;
import com.linkedin.venice.utils.Utils;
import java.util.Collections;
import java.util.Map;
Expand Down Expand Up @@ -44,6 +45,7 @@ public class VeniceClusterCreateOptions {
private final ZkServerWrapper zkServerWrapper;
private final String veniceZkBasePath;
private final PubSubBrokerWrapper pubSubBrokerWrapper;
private final DynamicAccessController accessController;

private VeniceClusterCreateOptions(Builder builder) {
this.clusterName = builder.clusterName;
Expand Down Expand Up @@ -71,6 +73,7 @@ private VeniceClusterCreateOptions(Builder builder) {
this.zkServerWrapper = builder.zkServerWrapper;
this.veniceZkBasePath = builder.veniceZkBasePath;
this.pubSubBrokerWrapper = builder.pubSubBrokerWrapper;
this.accessController = builder.accessController;
}

public String getClusterName() {
Expand Down Expand Up @@ -173,6 +176,10 @@ public PubSubBrokerWrapper getKafkaBrokerWrapper() {
return pubSubBrokerWrapper;
}

public DynamicAccessController getAccessController() {
return accessController;
}

@Override
public String toString() {
return new StringBuilder().append("VeniceClusterCreateOptions - ")
Expand Down Expand Up @@ -279,6 +286,7 @@ public static class Builder {
private ZkServerWrapper zkServerWrapper;
private String veniceZkBasePath = "/";
private PubSubBrokerWrapper pubSubBrokerWrapper;
private DynamicAccessController accessController;

public Builder clusterName(String clusterName) {
this.clusterName = clusterName;
Expand Down Expand Up @@ -409,6 +417,11 @@ public Builder kafkaBrokerWrapper(PubSubBrokerWrapper pubSubBrokerWrapper) {
return this;
}

public Builder accessController(DynamicAccessController accessController) {
this.accessController = accessController;
return this;
}

private void verifyAndAddDefaults() {
if (clusterName == null) {
clusterName = Utils.getUniqueString("venice-cluster");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ static ServiceProvider<VeniceClusterWrapper> generateService(VeniceClusterCreate
.d2Enabled(true)
.regionName(options.getRegionName())
.extraProperties(options.getExtraProperties())
.dynamicAccessController(options.getAccessController())
.build());
LOGGER.info(
"[{}][{}] Created child controller on port {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static com.linkedin.venice.integration.utils.VeniceClusterWrapperConstants.DEFAULT_PARTITION_SIZE_BYTES;
import static com.linkedin.venice.integration.utils.VeniceClusterWrapperConstants.DEFAULT_REPLICATION_FACTOR;

import com.linkedin.venice.acl.DynamicAccessController;
import com.linkedin.venice.authorization.AuthorizerService;
import java.util.Arrays;
import java.util.Map;
Expand Down Expand Up @@ -38,6 +39,7 @@ public class VeniceControllerCreateOptions {
private final Properties extraProperties;
private final AuthorizerService authorizerService;
private final String regionName;
private final DynamicAccessController dynamicAccessController;

private VeniceControllerCreateOptions(Builder builder) {
multiRegion = builder.multiRegion;
Expand All @@ -59,6 +61,7 @@ private VeniceControllerCreateOptions(Builder builder) {
authorizerService = builder.authorizerService;
isParent = builder.childControllers != null && builder.childControllers.length != 0;
regionName = builder.regionName;
dynamicAccessController = builder.dynamicAccessController;
}

@Override
Expand Down Expand Up @@ -201,6 +204,10 @@ public AuthorizerService getAuthorizerService() {
return authorizerService;
}

public DynamicAccessController getDynamicAccessController() {
return dynamicAccessController;
}

public String getRegionName() {
return regionName;
}
Expand All @@ -224,6 +231,7 @@ public static class Builder {
private Properties extraProperties = new Properties();
private AuthorizerService authorizerService;
private String regionName;
private DynamicAccessController dynamicAccessController;

public Builder(String[] clusterNames, ZkServerWrapper zkServer, PubSubBrokerWrapper kafkaBroker) {
this.clusterNames = Objects.requireNonNull(clusterNames, "clusterNames cannot be null when creating controller");
Expand Down Expand Up @@ -315,6 +323,11 @@ public Builder regionName(String regionName) {
return this;
}

public Builder dynamicAccessController(DynamicAccessController dynamicAccessController) {
this.dynamicAccessController = dynamicAccessController;
return this;
}

private void verifyAndAddParentControllerSpecificDefaults() {
extraProperties.setProperty(LOCAL_REGION_NAME, DEFAULT_PARENT_DATA_CENTER_REGION_NAME);
if (!extraProperties.containsKey(CONTROLLER_AUTO_MATERIALIZE_META_SYSTEM_STORE)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ static StatefulServiceProvider<VeniceControllerWrapper> generateService(VeniceCo
.setD2Client(d2Client)
.setRouterClientConfig(consumerClientConfig.orElse(null))
.setExternalSupersetSchemaGenerator(supersetSchemaGenerator.orElse(null))
.setAccessController(options.getDynamicAccessController())
.build();
VeniceController veniceController = new VeniceController(ctx);
return new VeniceControllerWrapper(
Expand Down
Loading

0 comments on commit fbbd13b

Please sign in to comment.