From 50ffc95f17970514d02af7e953e4b939e41b32c0 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Wed, 26 Feb 2025 10:30:54 -0500 Subject: [PATCH] feat(go/adbc/driver/flightsql): allow passing arbitrary grpc dial options in NewDatabase Add a new database constructor NewDatabaseWithOptions to allow passing arbitrary user-specified grpc dial options. This is useful for constructs like ```go driver := flightsql.NewDriver(memory.DefaultAllocator) driver.NewDatabaseWithOptions(map[string]string{ "uri": uri, }, grpc.WithStatsHandler(otelgrpc.NewClientHandler()) ) ``` which allows usage of the opentelemetry grpc instrumentation for example. This also provides an escape valve for users that need the ability to use less commonly used grpc client features that we have not exposed via the string map args. --- go/adbc/driver/flightsql/flightsql_database.go | 8 +++++--- go/adbc/driver/flightsql/flightsql_driver.go | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index a25c839ec4..64d0c00ad5 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -67,6 +67,7 @@ type databaseImpl struct { dialOpts dbDialOpts enableCookies bool options map[string]string + userDialOpts []grpc.DialOption } func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { @@ -337,7 +338,7 @@ func (d *databaseImpl) Close() error { return nil } -func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddle *bearerAuthMiddleware, cookies flight.CookieMiddleware) (*flightsql.Client, error) { +func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddle *bearerAuthMiddleware, cookies flight.CookieMiddleware, userGrpcDialOpts ...grpc.DialOption) (*flightsql.Client, error) { middleware := []flight.ClientMiddleware{ { Unary: makeUnaryLoggingInterceptor(d.Logger), @@ -371,6 +372,7 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl dv, _ := d.DatabaseImplBase.DriverInfo.GetInfoForInfoCode(adbc.InfoDriverVersion) driverVersion := dv.(string) dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion)) + dialOpts = append(dialOpts, userGrpcDialOpts...) d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) @@ -414,7 +416,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { cookies = flight.NewCookieMiddleware() } - cl, err := getFlightClient(ctx, d.uri.String(), d, authMiddle, cookies) + cl, err := getFlightClient(ctx, d.uri.String(), d, authMiddle, cookies, d.userDialOpts...) if err != nil { return nil, err } @@ -435,7 +437,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { } // use the existing auth token if there is one cl, err := getFlightClient(context.Background(), uri, d, - &bearerAuthMiddleware{hdrs: authMiddle.hdrs.Copy()}, cookieMiddleware) + &bearerAuthMiddleware{hdrs: authMiddle.hdrs.Copy()}, cookieMiddleware, d.userDialOpts...) if err != nil { return nil, err } diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index bf4e0b4a80..933e3e99d5 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -39,6 +39,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow-go/v18/arrow/memory" "golang.org/x/exp/maps" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -77,12 +78,15 @@ type driverImpl struct { } // NewDriver creates a new Flight SQL driver using the given Arrow allocator. +// +// It optionally accepts gRPC dial options to be used when connecting to the func NewDriver(alloc memory.Allocator) adbc.Driver { info := driverbase.DefaultDriverInfo("Flight SQL") return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}) } -func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { +// NewDatabase creates a new Flight SQL database using the given options. +func (d *driverImpl) NewDatabaseWithOptions(opts map[string]string, userDialOpts ...grpc.DialOption) (adbc.Database, error) { opts = maps.Clone(opts) uri, ok := opts[adbc.OptionKeyURI] if !ok { @@ -99,7 +103,8 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) // Match gRPC default connectTimeout: time.Second * 20, }, - hdrs: make(metadata.MD), + hdrs: make(metadata.MD), + userDialOpts: userDialOpts, } var err error @@ -118,3 +123,8 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) return driverbase.NewDatabase(db), nil } + +// NewDatabase creates a new Flight SQL database using the given options. +func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { + return d.NewDatabaseWithOptions(opts) +}