Skip to content

Commit

Permalink
feat(go/adbc)!: close database explicitly (apache#1460)
Browse files Browse the repository at this point in the history
Implicit database release behaves inconsistently on different OS, which
leads to bugs.

BREAKING CHANGE: adds Close to the Database interface.
Closes apache#1306.

---------

Co-authored-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
levakin and zeroshade authored Jan 19, 2024
1 parent 046f8b6 commit 3aa0d12
Show file tree
Hide file tree
Showing 24 changed files with 148 additions and 67 deletions.
1 change: 1 addition & 0 deletions docs/source/driver/duckdb.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ ADBC support in DuckDB requires the driver manager.
if err != nil {
// handle error
}
defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/source/driver/flight_sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ the :cpp:class:`AdbcDatabase`.
if err != nil {
// do something with the error
}
defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/source/driver/postgresql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ the :cpp:class:`AdbcDatabase`. This should be a `connection URI
if err != nil {
// handle error
}
defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions docs/source/driver/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ constructing the :cpp::class:`AdbcDatabase`.
if err != nil {
// handle error
}
defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down Expand Up @@ -241,6 +242,7 @@ a listing).
if err != nil {
// handle error
}
defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/source/driver/sqlite.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ shared across all connections.
if err != nil {
// handle error
}
defer db.Close()
cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions go/adbc/adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ type Driver interface {
type Database interface {
SetOptions(map[string]string) error
Open(ctx context.Context) (Connection, error)

// Close closes this database and releases any associated resources.
Close() error
}

type InfoCode uint32
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/driver/driverbase/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type DatabaseImpl interface {
adbc.GetSetOptions
Base() *DatabaseImplBase
Open(context.Context) (adbc.Connection, error)
Close() error
SetOptions(map[string]string) error
}

Expand Down Expand Up @@ -134,6 +135,10 @@ func (db *database) Open(ctx context.Context) (adbc.Connection, error) {
return db.impl.Open(ctx)
}

func (db *database) Close() error {
return db.impl.Close()
}

func (db *database) SetLogger(logger *slog.Logger) {
if logger != nil {
db.impl.Base().Logger = logger
Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/driverbase/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type DriverImpl interface {
NewDatabase(opts map[string]string) (adbc.Database, error)
}

// DatabaseImplBase is a struct that provides default implementations of the
// DriverImplBase is a struct that provides default implementations of the
// DriverImpl interface. It is meant to be used as a composite struct for a
// driver's DriverImpl implementation.
type DriverImplBase struct {
Expand All @@ -56,7 +56,7 @@ type driver struct {
impl DriverImpl
}

// NewDatabase wraps a DriverImpl to create an adbc.Driver.
// NewDriver wraps a DriverImpl to create an adbc.Driver.
func NewDriver(impl DriverImpl) adbc.Driver {
return &driver{impl}
}
Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (suite *ServerBasedTests) TearDownTest() {
}

func (suite *ServerBasedTests) TearDownSuite() {
suite.NoError(suite.db.Close())
suite.db = nil
suite.s.Shutdown()
}
Expand Down
9 changes: 9 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ func (suite *DefaultDialOptionsTests) SetupSuite() {

func (suite *DefaultDialOptionsTests) TearDownSuite() {
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand All @@ -361,6 +362,7 @@ func (suite *DefaultDialOptionsTests) TestMaxIncomingMessageSizeDefault() {
opts["adbc.flight.sql.client_option.with_max_msg_size"] = "1000000"
db, err := suite.Driver.NewDatabase(opts)
suite.NoError(err)
defer suite.NoError(db.Close())

cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
Expand Down Expand Up @@ -505,6 +507,7 @@ func (suite *PartitionTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand Down Expand Up @@ -558,6 +561,7 @@ func (suite *StatementTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand Down Expand Up @@ -639,6 +643,7 @@ func (suite *HeaderTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand Down Expand Up @@ -842,6 +847,7 @@ func (suite *TLSTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand All @@ -863,6 +869,7 @@ func (suite *TLSTests) TestInvalidOptions() {
"adbc.flight.sql.client_option.tls_skip_verify": "false",
})
suite.Require().NoError(err)
defer suite.NoError(db.Close())

cnxn, err := db.Open(suite.ctx)
suite.Require().NoError(err)
Expand Down Expand Up @@ -912,6 +919,7 @@ func (suite *ConnectionTests) SetupSuite() {
}

func (suite *ConnectionTests) TearDownSuite() {
suite.NoError(suite.DB.Close())
suite.server.Shutdown()
suite.alloc.AssertSize(suite.T(), 0)
}
Expand Down Expand Up @@ -1009,6 +1017,7 @@ func (suite *DomainSocketTests) SetupSuite() {
func (suite *DomainSocketTests) TearDownSuite() {
suite.Require().NoError(suite.Stmt.Close())
suite.Require().NoError(suite.Cnxn.Close())
suite.NoError(suite.DB.Close())
suite.server.Shutdown()
suite.alloc.AssertSize(suite.T(), 0)
}
Expand Down
20 changes: 12 additions & 8 deletions go/adbc/driver/flightsql/flightsql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
return d.DatabaseImplBase.SetOptionDouble(key, value)
}

func (d *databaseImpl) Close() error {
return nil
}

func getFlightClient(ctx context.Context, loc string, d *databaseImpl) (*flightsql.Client, error) {
authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
middleware := []flight.ClientMiddleware{
Expand Down Expand Up @@ -396,8 +400,8 @@ type support struct {
transactions bool
}

func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
cl, err := getFlightClient(ctx, impl.uri.String(), impl)
func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
cl, err := getFlightClient(ctx, d.uri.String(), d)
if err != nil {
return nil, err
}
Expand All @@ -410,12 +414,12 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
return nil, adbc.Error{Msg: fmt.Sprintf("Location must be a string, got %#v", uri), Code: adbc.StatusInternal}
}

cl, err := getFlightClient(context.Background(), uri, impl)
cl, err := getFlightClient(context.Background(), uri, d)
if err != nil {
return nil, err
}

cl.Alloc = impl.Alloc
cl.Alloc = d.Alloc
return cl, nil
}).
EvictedFunc(func(_, client interface{}) {
Expand All @@ -425,13 +429,13 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {

var cnxnSupport support

info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, impl.timeout)
info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout)
// ignore this if it fails
if err == nil {
const int32code = 3

for _, endpoint := range info.Endpoint {
rdr, err := doGet(ctx, cl, endpoint, cache, impl.timeout)
rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout)
if err != nil {
continue
}
Expand Down Expand Up @@ -465,8 +469,8 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
}
}

return &cnxn{cl: cl, db: impl, clientCache: cache,
hdrs: make(metadata.MD), timeouts: impl.timeout,
return &cnxn{cl: cl, db: d, clientCache: cache,
hdrs: make(metadata.MD), timeouts: d.timeout,
supportInfo: cnxnSupport}, nil
}

Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
if err := db.SetOptions(opts); err != nil {
return nil, err
}

return driverbase.NewDatabase(db), nil
}
5 changes: 5 additions & 0 deletions go/adbc/driver/panicdummy/panicdummy_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
return &cnxn{}, nil
}

func (d *database) Close() error {
maybePanic("DatabaseClose")
return nil
}

type cnxn struct{}

func (c *cnxn) SetOption(key, value string) error {
Expand Down
6 changes: 3 additions & 3 deletions go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,16 +714,16 @@ func prepareTablesSQL(matchingCatalogNames []string, catalog *string, dbSchema *

func prepareColumnsSQL(matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (string, []interface{}) {
prefixQuery := ""
for _, catalog_name := range matchingCatalogNames {
for _, catalogName := range matchingCatalogNames {
if prefixQuery != "" {
prefixQuery += " UNION ALL "
}
prefixQuery += `SELECT T.table_type,
C.*
FROM
"` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
"` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
JOIN
"` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
"` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
ON
T.table_catalog = C.table_catalog
AND T.table_schema = C.table_schema
Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
if err := db.SetOptions(opts); err != nil {
return nil, err
}

return driverbase.NewDatabase(db), nil
}
62 changes: 33 additions & 29 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ func (suite *SnowflakeTests) TearDownTest() {
}

func (suite *SnowflakeTests) TearDownSuite() {
suite.NoError(suite.db.Close())
suite.db = nil
}

Expand Down Expand Up @@ -464,21 +465,21 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() {
xdbcDateTimeSub []string
}{
{
"BASIC", //name
[]string{"int64s", "strings"}, //colNames
[]string{"1", "2"}, //positions
[]string{"NUMBER", "TEXT"}, //dataTypes
[]string{"", ""}, //comments
[]string{"9", "13"}, //xdbcDataType
[]string{"NUMBER", "TEXT"}, //xdbcTypeName
[]string{"-5", "12"}, //xdbcSqlDataType
[]string{"1", "1"}, //xdbcNullable
[]string{"YES", "YES"}, //xdbcIsNullable
[]string{"0", "0"}, //xdbcScale
[]string{"10", "0"}, //xdbcNumPrecRadix
[]string{"38", "16777216"}, //xdbcCharMaxLen (xdbcPrecision)
[]string{"0", "16777216"}, //xdbcCharOctetLen
[]string{"-5", "12", "0"}, //xdbcDateTimeSub
"BASIC", // name
[]string{"int64s", "strings"}, // colNames
[]string{"1", "2"}, // positions
[]string{"NUMBER", "TEXT"}, // dataTypes
[]string{"", ""}, // comments
[]string{"9", "13"}, // xdbcDataType
[]string{"NUMBER", "TEXT"}, // xdbcTypeName
[]string{"-5", "12"}, // xdbcSqlDataType
[]string{"1", "1"}, // xdbcNullable
[]string{"YES", "YES"}, // xdbcIsNullable
[]string{"0", "0"}, // xdbcScale
[]string{"10", "0"}, // xdbcNumPrecRadix
[]string{"38", "16777216"}, // xdbcCharMaxLen (xdbcPrecision)
[]string{"0", "16777216"}, // xdbcCharOctetLen
[]string{"-5", "12", "0"}, // xdbcDateTimeSub
},
}

Expand Down Expand Up @@ -576,20 +577,20 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() {

suite.False(rdr.Next())
suite.True(foundExpected)
suite.Equal(tt.colnames, colnames) //colNames
suite.Equal(tt.positions, positions) //positions
suite.Equal(tt.comments, comments) //comments
suite.Equal(tt.xdbcDataType, xdbcDataTypes) //xdbcDataType
suite.Equal(tt.dataTypes, dataTypes) //dataTypes
suite.Equal(tt.xdbcTypeName, xdbcTypeNames) //xdbcTypeName
suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) //xdbcCharMaxLen
suite.Equal(tt.xdbcScale, xdbcScales) //xdbcScale
suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) //xdbcNumPrecRadix
suite.Equal(tt.xdbcNullable, xdbcNullables) //xdbcNullable
suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) //xdbcSqlDataType
suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) //xdbcDateTimeSub
suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) //xdbcCharOctetLen
suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) //xdbcIsNullable
suite.Equal(tt.colnames, colnames) // colNames
suite.Equal(tt.positions, positions) // positions
suite.Equal(tt.comments, comments) // comments
suite.Equal(tt.xdbcDataType, xdbcDataTypes) // xdbcDataType
suite.Equal(tt.dataTypes, dataTypes) // dataTypes
suite.Equal(tt.xdbcTypeName, xdbcTypeNames) // xdbcTypeName
suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) // xdbcCharMaxLen
suite.Equal(tt.xdbcScale, xdbcScales) // xdbcScale
suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) // xdbcNumPrecRadix
suite.Equal(tt.xdbcNullable, xdbcNullables) // xdbcNullable
suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) // xdbcSqlDataType
suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) // xdbcDateTimeSub
suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) // xdbcCharOctetLen
suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) // xdbcIsNullable

})
}
Expand All @@ -605,6 +606,7 @@ func (suite *SnowflakeTests) TestNewDatabaseGetSetOptions() {
})
suite.NoError(err)
suite.NotNil(db)
defer suite.NoError(db.Close())

getSetDB, ok := db.(adbc.GetSetOptions)
suite.True(ok)
Expand Down Expand Up @@ -862,6 +864,7 @@ func ConnectWithJwt(uri, keyValue, passcode string) {
if err != nil {
panic(err)
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down Expand Up @@ -912,6 +915,7 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() {
opts[driver.OptionJwtPrivateKey] = keyFile
db, err := suite.driver.NewDatabase(opts)
suite.NoError(err)
defer db.Close()
cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
defer cnxn.Close()
Expand Down
4 changes: 4 additions & 0 deletions go/adbc/driver/snowflake/snowflake_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
useHighPrecision: d.useHighPrecision,
}, nil
}

func (d *databaseImpl) Close() error {
return nil
}
Loading

0 comments on commit 3aa0d12

Please sign in to comment.