From 62b53b067fe34ec38c1f051604fbf37fe4ccc510 Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 24 Jan 2024 14:53:40 -0800 Subject: [PATCH] Bring connection options to parity with Go - Propagate bearer tokens to spawned connections. - Allow for cookies and propagate to spawned connections. - Implement support for user-specified headers. - Implement TLS options. --- .../driver/flightsql/FlightInfoReader.java | 10 +- .../FlightSqlClientWithCallOptions.java | 289 ++++++++++++++++++ .../driver/flightsql/FlightSqlConnection.java | 262 ++++++++++++---- .../FlightSqlConnectionProperties.java | 39 +++ .../driver/flightsql/FlightSqlDatabase.java | 15 +- .../driver/flightsql/FlightSqlDriver.java | 3 +- .../driver/flightsql/FlightSqlStatement.java | 15 +- .../driver/flightsql/InfoMetadataBuilder.java | 5 +- 8 files changed, 555 insertions(+), 83 deletions(-) create mode 100644 java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java create mode 100644 java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java index 9b0cda91dc..6850be0639 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightInfoReader.java @@ -24,12 +24,10 @@ import java.util.Objects; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatusCode; -import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; @@ -40,8 +38,8 @@ /** An ArrowReader that wraps a FlightInfo. */ public class FlightInfoReader extends ArrowReader { private final Schema schema; - private final FlightSqlClient client; - private final LoadingCache clientCache; + private final FlightSqlClientWithCallOptions client; + private final LoadingCache clientCache; private final List flightEndpoints; private int nextEndpointIndex; private FlightStream currentStream; @@ -49,8 +47,8 @@ public class FlightInfoReader extends ArrowReader { FlightInfoReader( BufferAllocator allocator, - FlightSqlClient client, - LoadingCache clientCache, + FlightSqlClientWithCallOptions client, + LoadingCache clientCache, List flightEndpoints) throws AdbcException { super(allocator); diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java new file mode 100644 index 0000000000..0ef55091ff --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql; + +import static org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; +import static org.apache.arrow.flight.sql.FlightSqlClient.Savepoint; +import static org.apache.arrow.flight.sql.FlightSqlClient.SubstraitPlan; +import static org.apache.arrow.flight.sql.FlightSqlClient.Transaction; + +import java.util.List; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CancelFlightInfoRequest; +import org.apache.arrow.flight.CancelFlightInfoResult; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.RenewFlightEndpointRequest; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.CancelResult; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.util.AutoCloseables; + +/** A wrapper around FlightSqlClient which automatically adds CallOptions to each RPC call. */ +public class FlightSqlClientWithCallOptions implements AutoCloseable { + private final FlightSqlClient client; + private final CallOption[] connectionOptions; + + public FlightSqlClientWithCallOptions(FlightSqlClient client, CallOption... options) { + this.client = client; + this.connectionOptions = options; + } + + public FlightInfo execute(String query, CallOption... options) { + return client.execute(query, combine(options)); + } + + public FlightInfo execute(String query, Transaction transaction, CallOption... options) { + return client.execute(query, transaction, combine(options)); + } + + public FlightInfo executeSubstrait(SubstraitPlan plan, CallOption... options) { + return client.executeSubstrait(plan, combine(options)); + } + + public FlightInfo executeSubstrait( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.executeSubstrait(plan, transaction, combine(options)); + } + + public SchemaResult getExecuteSchema( + String query, Transaction transaction, CallOption... options) { + return client.getExecuteSchema(query, transaction, combine(options)); + } + + public SchemaResult getExecuteSchema(String query, CallOption... options) { + return client.getExecuteSchema(query, combine(options)); + } + + public SchemaResult getExecuteSubstraitSchema( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.getExecuteSubstraitSchema(plan, transaction, combine(options)); + } + + public SchemaResult getExecuteSubstraitSchema( + SubstraitPlan substraitPlan, CallOption... options) { + return client.getExecuteSubstraitSchema(substraitPlan, combine(options)); + } + + public long executeUpdate(String query, CallOption... options) { + return client.executeUpdate(query, combine(options)); + } + + public long executeUpdate(String query, Transaction transaction, CallOption... options) { + return client.executeUpdate(query, transaction, combine(options)); + } + + public long executeSubstraitUpdate(SubstraitPlan plan, CallOption... options) { + return client.executeSubstraitUpdate(plan, combine(options)); + } + + public long executeSubstraitUpdate( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.executeSubstraitUpdate(plan, transaction, combine(options)); + } + + public FlightInfo getCatalogs(CallOption... options) { + return client.getCatalogs(options); + } + + public SchemaResult getCatalogsSchema(CallOption... options) { + return client.getCatalogsSchema(options); + } + + public FlightInfo getSchemas( + String catalog, String dbSchemaFilterPattern, CallOption... options) { + return client.getSchemas(catalog, dbSchemaFilterPattern, combine(options)); + } + + public SchemaResult getSchemasSchema(CallOption... options) { + return client.getSchemasSchema(options); + } + + public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { + return client.getSchema(descriptor, combine(options)); + } + + public FlightStream getStream(Ticket ticket, CallOption... options) { + return client.getStream(ticket, combine(options)); + } + + public FlightInfo getSqlInfo(FlightSql.SqlInfo... info) { + return client.getSqlInfo(info); + } + + public FlightInfo getSqlInfo(FlightSql.SqlInfo[] info, CallOption... options) { + return client.getSqlInfo(info, combine(options)); + } + + public FlightInfo getSqlInfo(int[] info, CallOption... options) { + return client.getSqlInfo(info, combine(options)); + } + + public FlightInfo getSqlInfo(Iterable info, CallOption... options) { + return client.getSqlInfo(info, combine(options)); + } + + public SchemaResult getSqlInfoSchema(CallOption... options) { + return client.getSqlInfoSchema(options); + } + + public FlightInfo getXdbcTypeInfo(int dataType, CallOption... options) { + return client.getXdbcTypeInfo(dataType, combine(options)); + } + + public FlightInfo getXdbcTypeInfo(CallOption... options) { + return client.getXdbcTypeInfo(options); + } + + public SchemaResult getXdbcTypeInfoSchema(CallOption... options) { + return client.getXdbcTypeInfoSchema(options); + } + + public FlightInfo getTables( + String catalog, + String dbSchemaFilterPattern, + String tableFilterPattern, + List tableTypes, + boolean includeSchema, + CallOption... options) { + return client.getTables( + catalog, + dbSchemaFilterPattern, + tableFilterPattern, + tableTypes, + includeSchema, + combine(options)); + } + + public SchemaResult getTablesSchema(boolean includeSchema, CallOption... options) { + return client.getTablesSchema(includeSchema, combine(options)); + } + + public FlightInfo getPrimaryKeys(TableRef tableRef, CallOption... options) { + return client.getPrimaryKeys(tableRef, combine(options)); + } + + public SchemaResult getPrimaryKeysSchema(CallOption... options) { + return client.getPrimaryKeysSchema(options); + } + + public FlightInfo getExportedKeys(TableRef tableRef, CallOption... options) { + return client.getExportedKeys(tableRef, combine(options)); + } + + public SchemaResult getExportedKeysSchema(CallOption... options) { + return client.getExportedKeysSchema(options); + } + + public FlightInfo getImportedKeys(TableRef tableRef, CallOption... options) { + return client.getImportedKeys(tableRef, combine(options)); + } + + public SchemaResult getImportedKeysSchema(CallOption... options) { + return client.getImportedKeysSchema(options); + } + + public FlightInfo getCrossReference( + TableRef pkTableRef, TableRef fkTableRef, CallOption... options) { + return client.getCrossReference(pkTableRef, fkTableRef, combine(options)); + } + + public SchemaResult getCrossReferenceSchema(CallOption... options) { + return client.getCrossReferenceSchema(options); + } + + public FlightInfo getTableTypes(CallOption... options) { + return client.getTableTypes(options); + } + + public SchemaResult getTableTypesSchema(CallOption... options) { + return client.getTableTypesSchema(options); + } + + public PreparedStatement prepare(String query, CallOption... options) { + return client.prepare(query, combine(options)); + } + + public PreparedStatement prepare( + String query, FlightSqlClient.Transaction transaction, CallOption... options) { + return client.prepare(query, transaction, combine(options)); + } + + public PreparedStatement prepare(SubstraitPlan plan, CallOption... options) { + return client.prepare(plan, combine(options)); + } + + public PreparedStatement prepare( + SubstraitPlan plan, Transaction transaction, CallOption... options) { + return client.prepare(plan, transaction, combine(options)); + } + + public Transaction beginTransaction(CallOption... options) { + return client.beginTransaction(options); + } + + public Savepoint beginSavepoint(Transaction transaction, String name, CallOption... options) { + return client.beginSavepoint(transaction, name, combine(options)); + } + + public void commit(Transaction transaction, CallOption... options) { + client.commit(transaction, combine(options)); + } + + public void release(Savepoint savepoint, CallOption... options) { + client.release(savepoint, combine(options)); + } + + public void rollback(Transaction transaction, CallOption... options) { + client.rollback(transaction, combine(options)); + } + + public void rollback(Savepoint savepoint, CallOption... options) { + client.rollback(savepoint, combine(options)); + } + + public CancelFlightInfoResult cancelFlightInfo( + CancelFlightInfoRequest request, CallOption... options) { + return client.cancelFlightInfo(request, combine(options)); + } + + public CancelResult cancelQuery(FlightInfo info, CallOption... options) { + return client.cancelQuery(info, combine(options)); + } + + public FlightEndpoint renewFlightEndpoint( + RenewFlightEndpointRequest request, CallOption... options) { + return client.renewFlightEndpoint(request, combine(options)); + } + + public void close() throws Exception { + AutoCloseables.close(client); + } + + private CallOption[] combine(CallOption... options) { + final CallOption[] result = new CallOption[connectionOptions.length + options.length]; + System.arraycopy(connectionOptions, 0, result, 0, connectionOptions.length); + System.arraycopy(options, 0, result, connectionOptions.length, options.length); + return options; + } +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java index 6370296cb5..0bab9c9e72 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java @@ -20,26 +20,34 @@ import com.github.benmanes.caffeine.cache.LoadingCache; import com.github.benmanes.caffeine.cache.RemovalCause; import com.google.protobuf.InvalidProtocolBufferException; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.URISyntaxException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Collections; -import java.util.Objects; +import java.util.Map; +import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDriver; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.AdbcStatusCode; import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.sql.SqlQuirks; -import org.apache.arrow.flight.CallHeaders; -import org.apache.arrow.flight.CallInfo; -import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightCallHeaders; import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightClientMiddleware; import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.HeaderCallOption; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.auth2.Auth2Constants; import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter; +import org.apache.arrow.flight.client.ClientCookieMiddleware; import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.sql.FlightSqlClient; @@ -50,76 +58,57 @@ public class FlightSqlConnection implements AdbcConnection { private final BufferAllocator allocator; - private final FlightSqlClient client; + private final AtomicInteger counter; + private final FlightSqlClientWithCallOptions client; private final SqlQuirks quirks; - private final LoadingCache clientCache; + private final Map parameters; + private final LoadingCache clientCache; + + // Cached data to use across additional connections. + private ClientCookieMiddleware.Factory cookieMiddlewareFactory; + private CallOption[] callOptions; + + // Used to cache the InputStream content as a byte array since + // subsequent connections may need to use it but it is supplied as a stream. + private byte[] mtlsCertChainBytes; + private byte[] mtlsPrivateKeyBytes; + private byte[] tlsRootCertsBytes; FlightSqlConnection( BufferAllocator allocator, SqlQuirks quirks, Location location, - String username, - String password) { + Map parameters) + throws AdbcException { this.allocator = allocator; + this.counter = new AtomicInteger(0); this.quirks = quirks; + this.parameters = parameters; + this.client = + new FlightSqlClientWithCallOptions(new FlightSqlClient(createInitialConnection(location))); this.clientCache = Caffeine.newBuilder() .expireAfterAccess(5, TimeUnit.MINUTES) .removalListener( - (Location key, FlightClient value, RemovalCause cause) -> { + (Location key, FlightSqlClientWithCallOptions value, RemovalCause cause) -> { if (value == null) return; try { value.close(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); + } catch (Exception ex) { + if (ex instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException(ex); } }) .build( loc -> { - FlightClient flightClient = - FlightClient.builder(allocator, loc) - .intercept( - new FlightClientMiddleware.Factory() { - final String[] bearerValue = {null}; - - @Override - public FlightClientMiddleware onCallStarted(CallInfo info) { - return new FlightClientMiddleware() { - @Override - public void onBeforeSendingHeaders( - CallHeaders outgoingHeaders) { - if (bearerValue[0] != null) { - outgoingHeaders.insert( - Auth2Constants.AUTHORIZATION_HEADER, bearerValue[0]); - } - } - - @Override - public void onHeadersReceived(CallHeaders incomingHeaders) { - if (bearerValue[0] == null) { - bearerValue[0] = - incomingHeaders.get( - Auth2Constants.AUTHORIZATION_HEADER); - } - } - - @Override - public void onCallCompleted(CallStatus status) {} - }; - } - }) - .build(); - - if (username != null) { - flightClient.handshake( - new CredentialCallOption( - new BasicAuthCredentialWriter(username, password))); - } - return flightClient; + FlightClient client = buildClient(loc); + client.handshake(callOptions); + return new FlightSqlClientWithCallOptions( + new FlightSqlClient(client), callOptions); }); - - this.client = new FlightSqlClient(Objects.requireNonNull(clientCache.get(location))); + this.clientCache.put(location, this.client); } @Override @@ -193,6 +182,7 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException { @Override public void close() throws Exception { + clientCache.invalidateAll(); AutoCloseables.close(client, allocator); } @@ -200,4 +190,164 @@ public void close() throws Exception { public String toString() { return "FlightSqlConnection{" + "client=" + client + '}'; } + + /** + * Initialize cached data to share between connections and create, test, and authenticate the + * first connection. + */ + private FlightClient createInitialConnection(Location location) throws AdbcException { + // Setup cached pre-connection properties. + try { + final InputStream mtlsCertChain = + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.get(parameters); + if (mtlsCertChain != null) { + this.mtlsCertChainBytes = inputStreamToBytes(mtlsCertChain); + } + + final InputStream mtlsPrivateKey = + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.get(parameters); + if (mtlsPrivateKey != null) { + this.mtlsPrivateKeyBytes = inputStreamToBytes(mtlsPrivateKey); + } + + final InputStream tlsRootCerts = FlightSqlConnectionProperties.TLS_ROOT_CERTS.get(parameters); + if (tlsRootCerts != null) { + this.tlsRootCertsBytes = inputStreamToBytes(tlsRootCerts); + } + } catch (IOException ex) { + throw new AdbcException( + String.format( + "Error reading stream for one of the options %s, %s, %s.", + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), + FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey()), + ex, + AdbcStatusCode.IO, + null, + 0); + } + + final boolean useCookieMiddleware = + Boolean.TRUE.equals(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.get(parameters)); + if (useCookieMiddleware) { + this.cookieMiddlewareFactory = new ClientCookieMiddleware.Factory(); + } + + // Build the client using the above properties. + final FlightClient client = buildClient(location); + + // Add user-specified headers. + ArrayList options = new ArrayList<>(); + final FlightCallHeaders callHeaders = new FlightCallHeaders(); + for (Map.Entry parameter : parameters.entrySet()) { + if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) { + String userHeaderName = + parameter + .getKey() + .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length()); + + if (parameter.getValue() instanceof String) { + callHeaders.insert(userHeaderName, (String) parameter.getValue()); + } else if (parameter.getValue() instanceof byte[]) { + callHeaders.insert(userHeaderName, (byte[]) parameter.getValue()); + } else { + throw new AdbcException( + String.format( + "Header values must be String or byte[]. The header failing was %s.", + parameter.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } + } + } + + options.add(new HeaderCallOption(callHeaders)); + + // Test the connection. + String username = AdbcDriver.PARAM_USERNAME.get(parameters); + String password = AdbcDriver.PARAM_PASSWORD.get(parameters); + if (username != null && password != null) { + Optional bearerToken = + client.authenticateBasicToken(username, password); + options.add( + bearerToken.orElse( + new CredentialCallOption(new BasicAuthCredentialWriter(username, password)))); + this.callOptions = options.toArray(new CallOption[0]); + } else { + this.callOptions = options.toArray(new CallOption[0]); + client.handshake(this.callOptions); + } + + return client; + } + + /** Returns a yet-to-be authenticated FlightClient */ + private FlightClient buildClient(Location location) throws AdbcException { + final FlightClient.Builder builder = + FlightClient.builder() + .allocator( + allocator.newChildAllocator( + "adbc-flightclient-connection-" + counter.getAndIncrement(), + 0, + allocator.getLimit())) + .location(location); + + // Configure TLS options. + if (mtlsCertChainBytes != null && mtlsPrivateKeyBytes != null) { + builder.clientCertificate( + new ByteArrayInputStream(mtlsCertChainBytes), + new ByteArrayInputStream(mtlsPrivateKeyBytes)); + } else if (mtlsCertChainBytes != null) { + throw new AdbcException( + String.format( + "Must provide both %s and %s or neither. %s provided only.", + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } else if (mtlsPrivateKeyBytes != null) { + throw new AdbcException( + String.format( + "Must provide both %s and %s or neither. %s provided only.", + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(), + FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(), + FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } + + if (tlsRootCertsBytes != null) { + builder.trustedCertificates(new ByteArrayInputStream(tlsRootCertsBytes)); + } + + if (Boolean.TRUE.equals(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.get(parameters))) { + builder.verifyServer(false); + } + + String hostnameOverride = FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.get(parameters); + if (hostnameOverride != null) { + builder.overrideHostname(hostnameOverride); + } + + // Setup cookies if needed. + if (cookieMiddlewareFactory != null) { + builder.intercept(cookieMiddlewareFactory); + } + + return builder.build(); + } + + private static byte[] inputStreamToBytes(InputStream stream) throws IOException { + byte[] bytes = new byte[stream.available()]; + DataInputStream dataInputStream = new DataInputStream(stream); + dataInputStream.readFully(bytes); + return bytes; + } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java new file mode 100644 index 0000000000..295bcb83bb --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql; + +import java.io.InputStream; +import org.apache.arrow.adbc.core.TypedKey; + +/** Defines connection options that are used by the FlightSql driver. */ +public interface FlightSqlConnectionProperties { + // + TypedKey MTLS_CERT_CHAIN = + new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class); + TypedKey MTLS_PRIVATE_KEY = + new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class); + TypedKey TLS_OVERRIDE_HOSTNAME = + new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class); + TypedKey TLS_SKIP_VERIFY = + new TypedKey<>("adbc.flight.sql.client_option.tls_skip_verify", Boolean.class); + TypedKey TLS_ROOT_CERTS = + new TypedKey<>("adbc.flight.sql.client_option.tls_root_certs", InputStream.class); + TypedKey WITH_COOKIE_MIDDLEWARE = + new TypedKey<>("adbc.flight.sql.rpc.with_cookie_middleware", Boolean.class); + String RPC_CALL_HEADER_PREFIX = "adbc.flight.sql.rpc.call_header."; +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java index 89c89b4400..4468db2e19 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java @@ -17,6 +17,7 @@ package org.apache.arrow.adbc.driver.flightsql; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcDatabase; @@ -31,33 +32,29 @@ public final class FlightSqlDatabase implements AdbcDatabase { private final Location location; private final SqlQuirks quirks; private final AtomicInteger counter; - private final String username; - private final String password; + private final Map parameters; FlightSqlDatabase( BufferAllocator allocator, Location location, SqlQuirks quirks, - String username, - String password) + Map parameters) throws AdbcException { this.allocator = allocator; this.location = location; this.quirks = quirks; this.counter = new AtomicInteger(); - this.username = username; - this.password = password; + this.parameters = parameters; } @Override public AdbcConnection connect() throws AdbcException { final int count = counter.getAndIncrement(); return new FlightSqlConnection( - allocator.newChildAllocator("adbc-jdbc-connection-" + count, 0, allocator.getLimit()), + allocator.newChildAllocator("adbc-flight-connection-" + count, 0, allocator.getLimit()), quirks, location, - username, - password); + parameters); } @Override diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java index 25a8591adb..ac8339631a 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java @@ -75,7 +75,6 @@ public AdbcDatabase open(Map parameters) throws AdbcException { allocator, location, (SqlQuirks) quirks, - PARAM_USERNAME.get(parameters), - PARAM_PASSWORD.get(parameters)); + parameters); } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java index e64508b4bf..77cb2622d1 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java @@ -29,7 +29,6 @@ import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.core.PartitionDescriptor; import org.apache.arrow.adbc.sql.SqlQuirks; -import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; @@ -44,8 +43,8 @@ public class FlightSqlStatement implements AdbcStatement { private final BufferAllocator allocator; - private final FlightSqlClient client; - private final LoadingCache clientCache; + private final FlightSqlClientWithCallOptions client; + private final LoadingCache clientCache; private final SqlQuirks quirks; // State for SQL queries @@ -57,8 +56,8 @@ public class FlightSqlStatement implements AdbcStatement { FlightSqlStatement( BufferAllocator allocator, - FlightSqlClient client, - LoadingCache clientCache, + FlightSqlClientWithCallOptions client, + LoadingCache clientCache, SqlQuirks quirks) { this.allocator = allocator; this.client = client; @@ -69,8 +68,8 @@ public class FlightSqlStatement implements AdbcStatement { static FlightSqlStatement ingestRoot( BufferAllocator allocator, - FlightSqlClient client, - LoadingCache clientCache, + FlightSqlClientWithCallOptions client, + LoadingCache clientCache, SqlQuirks quirks, String targetTableName, BulkIngestMode mode) { @@ -188,7 +187,7 @@ interface Execute { private R execute( Execute doPrepared, - Execute doRegular) + Execute doRegular) throws AdbcException { try { if (preparedStatement != null) { diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java index 318405d6ce..38c026253d 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/InfoMetadataBuilder.java @@ -47,7 +47,7 @@ final class InfoMetadataBuilder implements AutoCloseable { private static final Map SUPPORTED_CODES = new HashMap<>(); private final Collection requestedCodes; - private final FlightSqlClient client; + private final FlightSqlClientWithCallOptions client; private VectorSchemaRoot root; private final UInt4Vector infoCodes; @@ -80,7 +80,8 @@ interface AddInfo { }); } - InfoMetadataBuilder(BufferAllocator allocator, FlightSqlClient client, int[] infoCodes) { + InfoMetadataBuilder( + BufferAllocator allocator, FlightSqlClientWithCallOptions client, int[] infoCodes) { if (infoCodes == null) { this.requestedCodes = new ArrayList<>(SUPPORTED_CODES.keySet()); this.requestedCodes.add(AdbcInfoCode.DRIVER_NAME.getValue());