From 1cbfb2f2e457ddde109f5324ed800c99acc6c576 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 11 Mar 2024 11:06:37 -0400
Subject: [PATCH] feat(go/adbc/driver/flightsql): support session options
(#1597)
Fixes #1557.
---
c/validation/adbc_validation_connection.cc | 2 +-
.../flightsql/flightsql_adbc_server_test.go | 217 +++++++++++++
.../driver/flightsql/flightsql_connection.go | 293 +++++++++++++++++-
go/adbc/driver/flightsql/flightsql_driver.go | 39 ++-
.../adbc_driver_flightsql/__init__.py | 12 +
5 files changed, 538 insertions(+), 25 deletions(-)
diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc
index 7a438f1001..4ed1d0e098 100644
--- a/c/validation/adbc_validation_connection.cc
+++ b/c/validation/adbc_validation_connection.cc
@@ -151,7 +151,7 @@ void ConnectionTest::TestMetadataCurrentCatalog() {
ASSERT_THAT(
AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG,
buffer, &buffer_size, &error),
- IsStatus(ADBC_STATUS_NOT_FOUND));
+ IsStatus(ADBC_STATUS_NOT_FOUND, &error));
}
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 3b5761495a..e70694e395 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -21,6 +21,7 @@ package flightsql_test
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"net"
@@ -41,6 +42,7 @@ import (
"github.com/apache/arrow/go/v16/arrow/flight"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref"
+ flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/stretchr/testify/suite"
@@ -134,6 +136,10 @@ func TestMultiTable(t *testing.T) {
suite.Run(t, &MultiTableTests{})
}
+func TestSessionOptions(t *testing.T) {
+ suite.Run(t, &SessionOptionTests{})
+}
+
// ---- AuthN Tests --------------------
type AuthnTestServer struct {
@@ -1654,3 +1660,214 @@ func (suite *MultiTableTests) TestGetTableSchema() {
expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
suite.Equal(expectedSchema, actualSchema)
}
+
+// ---- Session Option Tests --------------------
+
+type SessionOptionTestServer struct {
+ flightsql.BaseServer
+ options map[string]interface{}
+}
+
+func (server *SessionOptionTestServer) GetSessionOptions(ctx context.Context, req *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) {
+ options := make(map[string]*flight.SessionOptionValue)
+ for k, v := range server.options {
+ switch s := v.(type) {
+ case bool:
+ options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_BoolValue{BoolValue: s}}
+ case float64:
+ options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_DoubleValue{DoubleValue: s}}
+ case int64:
+ options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_Int64Value{Int64Value: s}}
+ case string:
+ options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_StringValue{StringValue: s}}
+ case []string:
+ options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_StringListValue_{StringListValue: &flightproto.SessionOptionValue_StringListValue{Values: s}}}
+ case nil:
+ options[k] = &flight.SessionOptionValue{}
+ default:
+ panic("not implemented")
+ }
+ }
+ return &flight.GetSessionOptionsResult{
+ SessionOptions: options,
+ }, nil
+}
+
+func (server *SessionOptionTestServer) SetSessionOptions(ctx context.Context, req *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) {
+ errors := map[string]*flightproto.SetSessionOptionsResult_Error{}
+ for k, v := range req.SessionOptions {
+ switch k {
+ case "bad name":
+ errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_INVALID_NAME}
+ continue
+ case "bad value":
+ errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_INVALID_VALUE}
+ continue
+ case "error":
+ errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_ERROR}
+ continue
+ }
+ switch s := v.GetOptionValue().(type) {
+ case *flightproto.SessionOptionValue_BoolValue:
+ server.options[k] = s.BoolValue
+ case *flightproto.SessionOptionValue_DoubleValue:
+ server.options[k] = s.DoubleValue
+ case *flightproto.SessionOptionValue_Int64Value:
+ server.options[k] = s.Int64Value
+ case *flightproto.SessionOptionValue_StringValue:
+ server.options[k] = s.StringValue
+ case *flightproto.SessionOptionValue_StringListValue_:
+ server.options[k] = s.StringListValue.Values
+ case nil:
+ delete(server.options, k)
+ default:
+ return nil, status.Error(codes.InvalidArgument, "invalid option type")
+ }
+ }
+ return &flight.SetSessionOptionsResult{Errors: errors}, nil
+}
+
+func (server *SessionOptionTestServer) CloseSession(ctx context.Context, req *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) {
+ return &flight.CloseSessionResult{
+ Status: flight.CloseSessionResultClosed,
+ }, nil
+}
+
+type SessionOptionTests struct {
+ ServerBasedTests
+}
+
+func (suite *SessionOptionTests) SetupSuite() {
+ suite.DoSetupSuite(&SessionOptionTestServer{
+ options: map[string]interface{}{
+ "string": "expected",
+ "bool": true,
+ "float64": float64(1.5),
+ "int64": int64(20),
+ "catalog": "main",
+ "schema": "session",
+ "stringlist": []string{"a", "b", "c"},
+ "nilopt": nil,
+ },
+ }, nil, map[string]string{})
+}
+
+func (suite *SessionOptionTests) TestGetAllOptions() {
+ val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(driver.OptionSessionOptions)
+ suite.NoError(err)
+
+ options := make(map[string]interface{})
+ suite.NoError(json.Unmarshal([]byte(val), &options))
+ // XXX: because Go decodes ints to strings by default. Should we use
+ // an alternate representation? What happens to int64max?
+ suite.Equal(float64(20), options["int64"])
+ suite.Equal("expected", options["string"])
+ // Bit of a hack, but lets servers send "this option exists, but is
+ // not set" by returning a nil/unset value
+ suite.Nil(options["nilopt"])
+}
+
+func (suite *SessionOptionTests) TestGetAllOptionsByte() {
+ val, err := suite.cnxn.(adbc.GetSetOptions).GetOptionBytes(driver.OptionSessionOptions)
+ suite.NoError(err)
+
+ options := make(map[string]interface{})
+ // XXX: maybe we can return the underlying proto repr here?
+ suite.NoError(json.Unmarshal(val, &options))
+ suite.Equal(float64(20), options["int64"])
+ suite.Equal("expected", options["string"])
+}
+
+func (suite *SessionOptionTests) TestGetSetCatalog() {
+ val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+ suite.NoError(err)
+ suite.Equal("main", val)
+
+ suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "postgres"))
+ val, err = suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+ suite.NoError(err)
+ suite.Equal("postgres", val)
+}
+
+func (suite *SessionOptionTests) TestGetSetSchema() {
+ val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
+ suite.NoError(err)
+ suite.Equal("session", val)
+
+ suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema, "public"))
+ val, err = suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
+ suite.NoError(err)
+ suite.Equal("public", val)
+}
+
+func (suite *SessionOptionTests) TestGetSetBool() {
+ o := suite.cnxn.(adbc.GetSetOptions)
+ val, err := o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool")
+ suite.NoError(err)
+ suite.Equal("true", val)
+
+ suite.NoError(o.SetOption(driver.OptionBoolSessionOptionPrefix+"bool", "false"))
+ val, err = o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool")
+ suite.NoError(err)
+ suite.Equal("false", val)
+}
+
+func (suite *SessionOptionTests) TestGetSetFloat64() {
+ o := suite.cnxn.(adbc.GetSetOptions)
+ val, err := o.GetOptionDouble(driver.OptionSessionOptionPrefix + "float64")
+ suite.NoError(err)
+ suite.Equal(1.5, val)
+
+ suite.NoError(o.SetOptionDouble(driver.OptionSessionOptionPrefix+"float64", -42.0))
+ val, err = o.GetOptionDouble(driver.OptionSessionOptionPrefix + "float64")
+ suite.NoError(err)
+ suite.Equal(-42.0, val)
+}
+
+func (suite *SessionOptionTests) TestGetSetInt64() {
+ o := suite.cnxn.(adbc.GetSetOptions)
+ val, err := o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64")
+ suite.NoError(err)
+ suite.Equal(int64(20), val)
+
+ suite.NoError(o.SetOptionInt(driver.OptionSessionOptionPrefix+"int64", 128))
+ val, err = o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64")
+ suite.NoError(err)
+ suite.Equal(int64(128), val)
+}
+
+func (suite *SessionOptionTests) TestGetSetString() {
+ o := suite.cnxn.(adbc.GetSetOptions)
+ _, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
+ suite.ErrorContains(err, "unknown session option 'unknown'")
+
+ suite.NoError(o.SetOption(driver.OptionSessionOptionPrefix+"unknown", "42"))
+ val, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
+ suite.NoError(err)
+ suite.Equal("42", val)
+
+ suite.NoError(o.SetOption(driver.OptionEraseSessionOptionPrefix+"unknown", ""))
+ _, err = o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
+ suite.ErrorContains(err, "unknown session option 'unknown'")
+
+ suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad name", ""), "Could not set option(s) 'bad name' (invalid name)")
+ suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad value", ""), "Could not set option(s) 'bad value' (invalid value)")
+ suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"error", ""), "Could not set option(s) 'error' (error setting option)")
+}
+
+func (suite *SessionOptionTests) TestGetSetStringList() {
+ o := suite.cnxn.(adbc.GetSetOptions)
+ val, err := o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist")
+ suite.NoError(err)
+ suite.Equal(`["a","b","c"]`, val)
+
+ suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist", `["foo", "bar"]`))
+ val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist")
+ suite.NoError(err)
+ suite.Equal(`["foo","bar"]`, val)
+
+ suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist", `[]`))
+ val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist")
+ suite.NoError(err)
+ suite.Equal(`[]`, val)
+}
diff --git a/go/adbc/driver/flightsql/flightsql_connection.go b/go/adbc/driver/flightsql/flightsql_connection.go
index d0aa0b02bb..e71ac308df 100644
--- a/go/adbc/driver/flightsql/flightsql_connection.go
+++ b/go/adbc/driver/flightsql/flightsql_connection.go
@@ -20,6 +20,7 @@ package flightsql
import (
"bytes"
"context"
+ "encoding/json"
"fmt"
"io"
"math"
@@ -32,6 +33,7 @@ import (
"github.com/apache/arrow/go/v16/arrow/flight"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref"
+ flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight"
"github.com/apache/arrow/go/v16/arrow/ipc"
"github.com/bluele/gcache"
"google.golang.org/grpc"
@@ -95,6 +97,115 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd
return nil, err
}
+func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, error) {
+ ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
+ var header, trailer metadata.MD
+ rawOptions, err := c.cl.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
+ if err != nil {
+ // We're going to make a bit of a concession to backwards compatibility
+ // here and ignore UNIMPLEMENTED or INVALID_ARGUMENT
+ grpcStatus := grpcstatus.Convert(err)
+ if grpcStatus.Code() == grpccodes.InvalidArgument || grpcStatus.Code() == grpccodes.Unimplemented {
+ return map[string]interface{}{}, nil
+ }
+ return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions")
+ }
+
+ options := make(map[string]interface{}, len(rawOptions.SessionOptions))
+ for k, rawValue := range rawOptions.SessionOptions {
+ switch v := rawValue.OptionValue.(type) {
+ case *flightproto.SessionOptionValue_BoolValue:
+ options[k] = v.BoolValue
+ case *flightproto.SessionOptionValue_DoubleValue:
+ options[k] = v.DoubleValue
+ case *flightproto.SessionOptionValue_Int64Value:
+ options[k] = v.Int64Value
+ case *flightproto.SessionOptionValue_StringValue:
+ options[k] = v.StringValue
+ case *flightproto.SessionOptionValue_StringListValue_:
+ if v.StringListValue.Values == nil {
+ options[k] = make([]string, 0)
+ } else {
+ options[k] = v.StringListValue.Values
+ }
+ case nil:
+ options[k] = nil
+ default:
+ return nil, adbc.Error{
+ Code: adbc.StatusNotImplemented,
+ Msg: fmt.Sprintf("[FlightSQL] Unknown session option type %#v", rawValue),
+ }
+ }
+ }
+ return options, nil
+}
+
+func (c *cnxn) setSessionOptions(ctx context.Context, key string, val interface{}) error {
+ req := flight.SetSessionOptionsRequest{}
+
+ var err error
+ req.SessionOptions, err = flight.NewSessionOptionValues(map[string]any{key: val})
+ if err != nil {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Invalid session option %s=%#v: %s", key, val, err.Error()),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ var header, trailer metadata.MD
+ errors, err := c.cl.SetSessionOptions(ctx, &req, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
+ if err != nil {
+ return adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions")
+ }
+ if len(errors.Errors) > 0 {
+ msg := strings.Builder{}
+ fmt.Fprint(&msg, "[Flight SQL] Could not set option(s) ")
+
+ first := true
+ for k, v := range errors.Errors {
+ if !first {
+ fmt.Fprint(&msg, ", ")
+ }
+ first = false
+
+ errmsg := "unknown error"
+ switch v.Value {
+ case flightproto.SetSessionOptionsResult_INVALID_NAME:
+ errmsg = "invalid name"
+ case flightproto.SetSessionOptionsResult_INVALID_VALUE:
+ errmsg = "invalid value"
+ case flightproto.SetSessionOptionsResult_ERROR:
+ errmsg = "error setting option"
+ }
+ fmt.Fprintf(&msg, "'%s' (%s)", k, errmsg)
+ }
+
+ return adbc.Error{
+ Msg: msg.String(),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+ return nil
+}
+
+func getSessionOption[T any](options map[string]interface{}, key string, defaultVal T, valueType string) (T, error) {
+ rawValue, ok := options[key]
+ if !ok {
+ return defaultVal, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] unknown session option '%s'", key),
+ Code: adbc.StatusNotFound,
+ }
+ }
+ value, ok := rawValue.(T)
+ if !ok {
+ return defaultVal, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] session option %s=%#v is not %s value", key, rawValue, valueType),
+ Code: adbc.StatusNotFound,
+ }
+ }
+ return value, nil
+}
+
func (c *cnxn) GetOption(key string) (string, error) {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
@@ -124,16 +235,96 @@ func (c *cnxn) GetOption(key string) (string, error) {
return adbc.OptionValueEnabled, nil
}
case adbc.OptionKeyCurrentCatalog:
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ if catalog, ok := options["catalog"]; ok {
+ if val, ok := catalog.(string); ok {
+ return val, nil
+ }
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[FlightSQL] Server returned non-string catalog %#v", catalog),
+ Code: adbc.StatusInternal,
+ }
+ }
return "", adbc.Error{
- Msg: "[Flight SQL] current catalog not supported",
+ Msg: "[FlightSQL] current catalog not supported",
Code: adbc.StatusNotFound,
}
case adbc.OptionKeyCurrentDbSchema:
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ if schema, ok := options["schema"]; ok {
+ if val, ok := schema.(string); ok {
+ return val, nil
+ }
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[FlightSQL] Server returned non-string schema %#v", schema),
+ Code: adbc.StatusInternal,
+ }
+ }
return "", adbc.Error{
- Msg: "[Flight SQL] current schema not supported",
+ Msg: "[FlightSQL] current schema not supported",
Code: adbc.StatusNotFound,
}
+ case OptionSessionOptions:
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ encoded, err := json.Marshal(options)
+ if err != nil {
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Could not encode option values: %s", err.Error()),
+ Code: adbc.StatusInternal,
+ }
+ }
+ return string(encoded), nil
+ }
+ switch {
+ case strings.HasPrefix(key, OptionSessionOptionPrefix):
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ name := key[len(OptionSessionOptionPrefix):]
+ return getSessionOption(options, name, "", "a string")
+ case strings.HasPrefix(key, OptionBoolSessionOptionPrefix):
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ name := key[len(OptionBoolSessionOptionPrefix):]
+ v, err := getSessionOption(options, name, false, "a boolean")
+ if err != nil {
+ return "", err
+ }
+ if v {
+ return adbc.OptionValueEnabled, nil
+ }
+ return adbc.OptionValueDisabled, nil
+ case strings.HasPrefix(key, OptionStringListSessionOptionPrefix):
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ name := key[len(OptionStringListSessionOptionPrefix):]
+ v, err := getSessionOption[[]string](options, name, nil, "a string list")
+ if err != nil {
+ return "", err
+ }
+ encoded, err := json.Marshal(v)
+ if err != nil {
+ return "", adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Could not encode option value: %s", err.Error()),
+ Code: adbc.StatusInternal,
+ }
+ }
+ return string(encoded), nil
}
return "", adbc.Error{
@@ -143,6 +334,22 @@ func (c *cnxn) GetOption(key string) (string, error) {
}
func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
+ switch key {
+ case OptionSessionOptions:
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return nil, err
+ }
+ encoded, err := json.Marshal(options)
+ if err != nil {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Could not encode option values: %s", err.Error()),
+ Code: adbc.StatusInternal,
+ }
+ }
+ return encoded, nil
+ }
+
return nil, adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotFound,
@@ -162,6 +369,14 @@ func (c *cnxn) GetOptionInt(key string) (int64, error) {
}
return int64(val), nil
}
+ if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return 0, err
+ }
+ name := key[len(OptionSessionOptionPrefix):]
+ return getSessionOption(options, name, int64(0), "an integer")
+ }
return 0, adbc.Error{
Msg: "[Flight SQL] unknown connection option",
@@ -178,6 +393,14 @@ func (c *cnxn) GetOptionDouble(key string) (float64, error) {
case OptionTimeoutUpdate:
return c.timeouts.updateTimeout.Seconds(), nil
}
+ if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return 0, err
+ }
+ name := key[len(OptionSessionOptionPrefix):]
+ return getSessionOption(options, name, float64(0.0), "a floating-point")
+ }
return 0.0, adbc.Error{
Msg: "[Flight SQL] unknown connection option",
@@ -245,12 +468,47 @@ func (c *cnxn) SetOption(key, value string) error {
}
}
return nil
+ case adbc.OptionKeyCurrentCatalog:
+ return c.setSessionOptions(context.Background(), "catalog", value)
+ case adbc.OptionKeyCurrentDbSchema:
+ return c.setSessionOptions(context.Background(), "schema", value)
+ }
- default:
- return adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotImplemented,
+ switch {
+ case strings.HasPrefix(key, OptionSessionOptionPrefix):
+ name := key[len(OptionSessionOptionPrefix):]
+ return c.setSessionOptions(context.Background(), name, value)
+ case strings.HasPrefix(key, OptionBoolSessionOptionPrefix):
+ name := key[len(OptionBoolSessionOptionPrefix):]
+ switch value {
+ case adbc.OptionValueEnabled:
+ return c.setSessionOptions(context.Background(), name, true)
+ case adbc.OptionValueDisabled:
+ return c.setSessionOptions(context.Background(), name, false)
+ default:
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] invalid boolean session option value %s=%s", name, value),
+ Code: adbc.StatusNotImplemented,
+ }
}
+ case strings.HasPrefix(key, OptionStringListSessionOptionPrefix):
+ name := key[len(OptionStringListSessionOptionPrefix):]
+ stringlist := make([]string, 0)
+ if err := json.Unmarshal([]byte(value), &stringlist); err != nil {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] invalid string list session option value %s=%s: %s", name, value, err.Error()),
+ Code: adbc.StatusNotImplemented,
+ }
+ }
+ return c.setSessionOptions(context.Background(), name, stringlist)
+ case strings.HasPrefix(key, OptionEraseSessionOptionPrefix):
+ name := key[len(OptionEraseSessionOptionPrefix):]
+ return c.setSessionOptions(context.Background(), name, nil)
+ }
+
+ return adbc.Error{
+ Msg: "[Flight SQL] unknown connection option",
+ Code: adbc.StatusNotImplemented,
}
}
@@ -266,6 +524,10 @@ func (c *cnxn) SetOptionInt(key string, value int64) error {
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
return c.timeouts.setTimeout(key, float64(value))
}
+ if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+ name := key[len(OptionSessionOptionPrefix):]
+ return c.setSessionOptions(context.Background(), name, value)
+ }
return adbc.Error{
Msg: "[Flight SQL] unknown connection option",
@@ -282,6 +544,10 @@ func (c *cnxn) SetOptionDouble(key string, value float64) error {
case OptionTimeoutUpdate:
return c.timeouts.setTimeout(key, value)
}
+ if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+ name := key[len(OptionSessionOptionPrefix):]
+ return c.setSessionOptions(context.Background(), name, value)
+ }
return adbc.Error{
Msg: "[Flight SQL] unknown connection option",
@@ -937,7 +1203,20 @@ func (c *cnxn) Close() error {
}
}
- err := c.cl.Close()
+ ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs)
+ var header, trailer metadata.MD
+ _, err := c.cl.CloseSession(ctx, &flight.CloseSessionRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
+ if err != nil {
+ grpcStatus := grpcstatus.Convert(err)
+ // Ignore unimplemented
+ if grpcStatus.Code() != grpccodes.Unimplemented {
+ // Ignore the error since server may not support it and may not properly return UNIMPLEMENTED
+ // TODO(https://github.com/apache/arrow-adbc/issues/1243): log a proper warning
+ c.db.Logger.Debug("failed to close session", "error", err.Error())
+ }
+ }
+
+ err = c.cl.Close()
c.cl = nil
return adbcFromFlightStatus(err, "Close")
}
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index df1ae688b4..d437f0829b 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -45,23 +45,28 @@ import (
)
const (
- OptionAuthority = "adbc.flight.sql.client_option.authority"
- OptionMTLSCertChain = "adbc.flight.sql.client_option.mtls_cert_chain"
- OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key"
- OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname"
- OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"
- OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs"
- OptionWithBlock = "adbc.flight.sql.client_option.with_block"
- OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"
- OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
- OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect"
- OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"
- OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"
- OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
- OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
- OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
- OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info"
- infoDriverName = "ADBC Flight SQL Driver - Go"
+ OptionAuthority = "adbc.flight.sql.client_option.authority"
+ OptionMTLSCertChain = "adbc.flight.sql.client_option.mtls_cert_chain"
+ OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key"
+ OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname"
+ OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"
+ OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs"
+ OptionWithBlock = "adbc.flight.sql.client_option.with_block"
+ OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"
+ OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
+ OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect"
+ OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"
+ OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"
+ OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
+ OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
+ OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
+ OptionSessionOptions = "adbc.flight.sql.session.options"
+ OptionSessionOptionPrefix = "adbc.flight.sql.session.option."
+ OptionEraseSessionOptionPrefix = "adbc.flight.sql.session.optionerase."
+ OptionBoolSessionOptionPrefix = "adbc.flight.sql.session.optionbool."
+ OptionStringListSessionOptionPrefix = "adbc.flight.sql.session.optionstringlist."
+ OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info"
+ infoDriverName = "ADBC Flight SQL Driver - Go"
)
var (
diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
index e50f7c5aaf..7d45adf0b7 100644
--- a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
+++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
@@ -87,6 +87,18 @@ class ConnectionOptions(enum.Enum):
#:
#: Overrides any headers set via the equivalent database option.
RPC_CALL_HEADER_PREFIX = DatabaseOptions.RPC_CALL_HEADER_PREFIX.value
+ #: Get all session options as a JSON key-value blob.
+ OPTION_SESSION_OPTIONS = "adbc.flight.sql.session.options"
+ #: Get or set a session option.
+ OPTION_SESSION_OPTION_PREFIX = "adbc.flight.sql.session.option."
+ #: Erase a session option (use "" as the value).
+ OPTION_ERASE_SESSION_OPTION_PREFIX = "adbc.flight.sql.session.optionerase."
+ #: Get or set a boolean valued session option.
+ OPTION_BOOL_SESSION_OPTION_PREFIX = "adbc.flight.sql.session.optionbool."
+ #: Get or set a string-list-valued session option as a JSON array.
+ OPTION_STRING_LIST_SESSION_OPTION_PREFIX = (
+ "adbc.flight.sql.session.optionstringlist."
+ )
#: Set a timeout on calls that fetch data (in floating-point seconds).
#:
#: This corresponds to Flight RPC DoGet calls.