Skip to content

Commit

Permalink
feat(go/adbc/driver/flightsql): allow passing arbitrary grpc dial opt…
Browse files Browse the repository at this point in the history
…ions in NewDatabase

Add a new database constructor NewDatabaseWithOptions to allow passing arbitrary user-specified grpc dial options.

This is useful for constructs like
```go
driver := flightsql.NewDriver(memory.DefaultAllocator)
driver.NewDatabaseWithOptions(map[string]string{
		"uri": uri,
	},
	grpc.WithStatsHandler(otelgrpc.NewClientHandler())
)
```

which allows usage of the opentelemetry grpc instrumentation for example.

This also provides an escape valve for users that need the ability to use less commonly used grpc client features that we have not exposed via the string map args.
  • Loading branch information
mariusvniekerk committed Feb 26, 2025
1 parent 470b209 commit 50ffc95
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
8 changes: 5 additions & 3 deletions go/adbc/driver/flightsql/flightsql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ type databaseImpl struct {
dialOpts dbDialOpts
enableCookies bool
options map[string]string
userDialOpts []grpc.DialOption
}

func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
Expand Down Expand Up @@ -337,7 +338,7 @@ func (d *databaseImpl) Close() error {
return nil
}

func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddle *bearerAuthMiddleware, cookies flight.CookieMiddleware) (*flightsql.Client, error) {
func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddle *bearerAuthMiddleware, cookies flight.CookieMiddleware, userGrpcDialOpts ...grpc.DialOption) (*flightsql.Client, error) {
middleware := []flight.ClientMiddleware{
{
Unary: makeUnaryLoggingInterceptor(d.Logger),
Expand Down Expand Up @@ -371,6 +372,7 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
dv, _ := d.DatabaseImplBase.DriverInfo.GetInfoForInfoCode(adbc.InfoDriverVersion)
driverVersion := dv.(string)
dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion))
dialOpts = append(dialOpts, userGrpcDialOpts...)

d.Logger.DebugContext(ctx, "new client", "location", loc)
cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
Expand Down Expand Up @@ -414,7 +416,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
cookies = flight.NewCookieMiddleware()
}

cl, err := getFlightClient(ctx, d.uri.String(), d, authMiddle, cookies)
cl, err := getFlightClient(ctx, d.uri.String(), d, authMiddle, cookies, d.userDialOpts...)
if err != nil {
return nil, err
}
Expand All @@ -435,7 +437,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
}
// use the existing auth token if there is one
cl, err := getFlightClient(context.Background(), uri, d,
&bearerAuthMiddleware{hdrs: authMiddle.hdrs.Copy()}, cookieMiddleware)
&bearerAuthMiddleware{hdrs: authMiddle.hdrs.Copy()}, cookieMiddleware, d.userDialOpts...)
if err != nil {
return nil, err
}
Expand Down
14 changes: 12 additions & 2 deletions go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow-go/v18/arrow/memory"
"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

Expand Down Expand Up @@ -77,12 +78,15 @@ type driverImpl struct {
}

// NewDriver creates a new Flight SQL driver using the given Arrow allocator.
//
// It optionally accepts gRPC dial options to be used when connecting to the
func NewDriver(alloc memory.Allocator) adbc.Driver {
info := driverbase.DefaultDriverInfo("Flight SQL")
return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)})
}

func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
// NewDatabase creates a new Flight SQL database using the given options.
func (d *driverImpl) NewDatabaseWithOptions(opts map[string]string, userDialOpts ...grpc.DialOption) (adbc.Database, error) {
opts = maps.Clone(opts)
uri, ok := opts[adbc.OptionKeyURI]
if !ok {
Expand All @@ -99,7 +103,8 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
// Match gRPC default
connectTimeout: time.Second * 20,
},
hdrs: make(metadata.MD),
hdrs: make(metadata.MD),
userDialOpts: userDialOpts,
}

var err error
Expand All @@ -118,3 +123,8 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)

return driverbase.NewDatabase(db), nil
}

// NewDatabase creates a new Flight SQL database using the given options.
func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
return d.NewDatabaseWithOptions(opts)
}

0 comments on commit 50ffc95

Please sign in to comment.