Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc)!: close database explicitly #1460

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/driver/duckdb.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ ADBC support in DuckDB requires the driver manager.
if err != nil {
// handle error
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/source/driver/flight_sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ the :cpp:class:`AdbcDatabase`.
if err != nil {
// do something with the error
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/source/driver/postgresql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ the :cpp:class:`AdbcDatabase`. This should be a `connection URI
if err != nil {
// handle error
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions docs/source/driver/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ constructing the :cpp::class:`AdbcDatabase`.
if err != nil {
// handle error
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down Expand Up @@ -241,6 +242,7 @@ a listing).
if err != nil {
// handle error
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/source/driver/sqlite.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ shared across all connections.
if err != nil {
// handle error
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions go/adbc/adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ type Driver interface {
type Database interface {
SetOptions(map[string]string) error
Open(ctx context.Context) (Connection, error)

// Close closes this database and releases any associated resources.
Close() error
}

type InfoCode uint32
Expand Down
5 changes: 5 additions & 0 deletions go/adbc/driver/driverbase/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type DatabaseImpl interface {
adbc.GetSetOptions
Base() *DatabaseImplBase
Open(context.Context) (adbc.Connection, error)
Close() error
SetOptions(map[string]string) error
}

Expand Down Expand Up @@ -134,6 +135,10 @@ func (db *database) Open(ctx context.Context) (adbc.Connection, error) {
return db.impl.Open(ctx)
}

func (db *database) Close() error {
return db.impl.Close()
}

func (db *database) SetLogger(logger *slog.Logger) {
if logger != nil {
db.impl.Base().Logger = logger
Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/driverbase/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type DriverImpl interface {
NewDatabase(opts map[string]string) (adbc.Database, error)
}

// DatabaseImplBase is a struct that provides default implementations of the
// DriverImplBase is a struct that provides default implementations of the
// DriverImpl interface. It is meant to be used as a composite struct for a
// driver's DriverImpl implementation.
type DriverImplBase struct {
Expand All @@ -56,7 +56,7 @@ type driver struct {
impl DriverImpl
}

// NewDatabase wraps a DriverImpl to create an adbc.Driver.
// NewDriver wraps a DriverImpl to create an adbc.Driver.
func NewDriver(impl DriverImpl) adbc.Driver {
return &driver{impl}
}
Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (suite *ServerBasedTests) TearDownTest() {
}

func (suite *ServerBasedTests) TearDownSuite() {
suite.NoError(suite.db.Close())
suite.db = nil
suite.s.Shutdown()
}
Expand Down
9 changes: 9 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ func (suite *DefaultDialOptionsTests) SetupSuite() {

func (suite *DefaultDialOptionsTests) TearDownSuite() {
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand All @@ -361,6 +362,7 @@ func (suite *DefaultDialOptionsTests) TestMaxIncomingMessageSizeDefault() {
opts["adbc.flight.sql.client_option.with_max_msg_size"] = "1000000"
db, err := suite.Driver.NewDatabase(opts)
suite.NoError(err)
defer suite.NoError(db.Close())

cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
Expand Down Expand Up @@ -505,6 +507,7 @@ func (suite *PartitionTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand Down Expand Up @@ -558,6 +561,7 @@ func (suite *StatementTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand Down Expand Up @@ -639,6 +643,7 @@ func (suite *HeaderTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand Down Expand Up @@ -842,6 +847,7 @@ func (suite *TLSTests) TearDownTest() {
suite.Require().NoError(suite.Cnxn.Close())
suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
suite.Cnxn = nil
suite.NoError(suite.DB.Close())
suite.DB = nil
suite.Driver = nil
}
Expand All @@ -863,6 +869,7 @@ func (suite *TLSTests) TestInvalidOptions() {
"adbc.flight.sql.client_option.tls_skip_verify": "false",
})
suite.Require().NoError(err)
defer suite.NoError(db.Close())

cnxn, err := db.Open(suite.ctx)
suite.Require().NoError(err)
Expand Down Expand Up @@ -912,6 +919,7 @@ func (suite *ConnectionTests) SetupSuite() {
}

func (suite *ConnectionTests) TearDownSuite() {
suite.NoError(suite.DB.Close())
suite.server.Shutdown()
suite.alloc.AssertSize(suite.T(), 0)
}
Expand Down Expand Up @@ -1009,6 +1017,7 @@ func (suite *DomainSocketTests) SetupSuite() {
func (suite *DomainSocketTests) TearDownSuite() {
suite.Require().NoError(suite.Stmt.Close())
suite.Require().NoError(suite.Cnxn.Close())
suite.NoError(suite.DB.Close())
suite.server.Shutdown()
suite.alloc.AssertSize(suite.T(), 0)
}
Expand Down
20 changes: 12 additions & 8 deletions go/adbc/driver/flightsql/flightsql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
return d.DatabaseImplBase.SetOptionDouble(key, value)
}

func (d *databaseImpl) Close() error {
return nil
}

func getFlightClient(ctx context.Context, loc string, d *databaseImpl) (*flightsql.Client, error) {
authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
middleware := []flight.ClientMiddleware{
Expand Down Expand Up @@ -396,8 +400,8 @@ type support struct {
transactions bool
}

func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
cl, err := getFlightClient(ctx, impl.uri.String(), impl)
func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
cl, err := getFlightClient(ctx, d.uri.String(), d)
if err != nil {
return nil, err
}
Expand All @@ -410,12 +414,12 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
return nil, adbc.Error{Msg: fmt.Sprintf("Location must be a string, got %#v", uri), Code: adbc.StatusInternal}
}

cl, err := getFlightClient(context.Background(), uri, impl)
cl, err := getFlightClient(context.Background(), uri, d)
if err != nil {
return nil, err
}

cl.Alloc = impl.Alloc
cl.Alloc = d.Alloc
return cl, nil
}).
EvictedFunc(func(_, client interface{}) {
Expand All @@ -425,13 +429,13 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {

var cnxnSupport support

info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, impl.timeout)
info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout)
// ignore this if it fails
if err == nil {
const int32code = 3

for _, endpoint := range info.Endpoint {
rdr, err := doGet(ctx, cl, endpoint, cache, impl.timeout)
rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout)
if err != nil {
continue
}
Expand Down Expand Up @@ -465,8 +469,8 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
}
}

return &cnxn{cl: cl, db: impl, clientCache: cache,
hdrs: make(metadata.MD), timeouts: impl.timeout,
return &cnxn{cl: cl, db: d, clientCache: cache,
hdrs: make(metadata.MD), timeouts: d.timeout,
supportInfo: cnxnSupport}, nil
}

Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
if err := db.SetOptions(opts); err != nil {
return nil, err
}

return driverbase.NewDatabase(db), nil
}
5 changes: 5 additions & 0 deletions go/adbc/driver/panicdummy/panicdummy_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
return &cnxn{}, nil
}

func (d *database) Close() error {
maybePanic("DatabaseClose")
return nil
}

type cnxn struct{}

func (c *cnxn) SetOption(key, value string) error {
Expand Down
6 changes: 3 additions & 3 deletions go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,16 +714,16 @@ func prepareTablesSQL(matchingCatalogNames []string, catalog *string, dbSchema *

func prepareColumnsSQL(matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (string, []interface{}) {
prefixQuery := ""
for _, catalog_name := range matchingCatalogNames {
for _, catalogName := range matchingCatalogNames {
if prefixQuery != "" {
prefixQuery += " UNION ALL "
}
prefixQuery += `SELECT T.table_type,
C.*
FROM
"` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
"` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T
JOIN
"` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
"` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C
ON
T.table_catalog = C.table_catalog
AND T.table_schema = C.table_schema
Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
if err := db.SetOptions(opts); err != nil {
return nil, err
}

return driverbase.NewDatabase(db), nil
}
62 changes: 33 additions & 29 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ func (suite *SnowflakeTests) TearDownTest() {
}

func (suite *SnowflakeTests) TearDownSuite() {
suite.NoError(suite.db.Close())
suite.db = nil
}

Expand Down Expand Up @@ -464,21 +465,21 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() {
xdbcDateTimeSub []string
}{
{
"BASIC", //name
[]string{"int64s", "strings"}, //colNames
[]string{"1", "2"}, //positions
[]string{"NUMBER", "TEXT"}, //dataTypes
[]string{"", ""}, //comments
[]string{"9", "13"}, //xdbcDataType
[]string{"NUMBER", "TEXT"}, //xdbcTypeName
[]string{"-5", "12"}, //xdbcSqlDataType
[]string{"1", "1"}, //xdbcNullable
[]string{"YES", "YES"}, //xdbcIsNullable
[]string{"0", "0"}, //xdbcScale
[]string{"10", "0"}, //xdbcNumPrecRadix
[]string{"38", "16777216"}, //xdbcCharMaxLen (xdbcPrecision)
[]string{"0", "16777216"}, //xdbcCharOctetLen
[]string{"-5", "12", "0"}, //xdbcDateTimeSub
"BASIC", // name
[]string{"int64s", "strings"}, // colNames
[]string{"1", "2"}, // positions
[]string{"NUMBER", "TEXT"}, // dataTypes
[]string{"", ""}, // comments
[]string{"9", "13"}, // xdbcDataType
[]string{"NUMBER", "TEXT"}, // xdbcTypeName
[]string{"-5", "12"}, // xdbcSqlDataType
[]string{"1", "1"}, // xdbcNullable
[]string{"YES", "YES"}, // xdbcIsNullable
[]string{"0", "0"}, // xdbcScale
[]string{"10", "0"}, // xdbcNumPrecRadix
[]string{"38", "16777216"}, // xdbcCharMaxLen (xdbcPrecision)
[]string{"0", "16777216"}, // xdbcCharOctetLen
[]string{"-5", "12", "0"}, // xdbcDateTimeSub
},
}

Expand Down Expand Up @@ -576,20 +577,20 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() {

suite.False(rdr.Next())
suite.True(foundExpected)
suite.Equal(tt.colnames, colnames) //colNames
suite.Equal(tt.positions, positions) //positions
suite.Equal(tt.comments, comments) //comments
suite.Equal(tt.xdbcDataType, xdbcDataTypes) //xdbcDataType
suite.Equal(tt.dataTypes, dataTypes) //dataTypes
suite.Equal(tt.xdbcTypeName, xdbcTypeNames) //xdbcTypeName
suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) //xdbcCharMaxLen
suite.Equal(tt.xdbcScale, xdbcScales) //xdbcScale
suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) //xdbcNumPrecRadix
suite.Equal(tt.xdbcNullable, xdbcNullables) //xdbcNullable
suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) //xdbcSqlDataType
suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) //xdbcDateTimeSub
suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) //xdbcCharOctetLen
suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) //xdbcIsNullable
suite.Equal(tt.colnames, colnames) // colNames
suite.Equal(tt.positions, positions) // positions
suite.Equal(tt.comments, comments) // comments
suite.Equal(tt.xdbcDataType, xdbcDataTypes) // xdbcDataType
suite.Equal(tt.dataTypes, dataTypes) // dataTypes
suite.Equal(tt.xdbcTypeName, xdbcTypeNames) // xdbcTypeName
suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) // xdbcCharMaxLen
suite.Equal(tt.xdbcScale, xdbcScales) // xdbcScale
suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) // xdbcNumPrecRadix
suite.Equal(tt.xdbcNullable, xdbcNullables) // xdbcNullable
suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) // xdbcSqlDataType
suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) // xdbcDateTimeSub
suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) // xdbcCharOctetLen
suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) // xdbcIsNullable

})
}
Expand All @@ -605,6 +606,7 @@ func (suite *SnowflakeTests) TestNewDatabaseGetSetOptions() {
})
suite.NoError(err)
suite.NotNil(db)
defer suite.NoError(db.Close())

getSetDB, ok := db.(adbc.GetSetOptions)
suite.True(ok)
Expand Down Expand Up @@ -862,6 +864,7 @@ func ConnectWithJwt(uri, keyValue, passcode string) {
if err != nil {
panic(err)
}
defer db.Close()

cnxn, err := db.Open(context.Background())
if err != nil {
Expand Down Expand Up @@ -912,6 +915,7 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() {
opts[driver.OptionJwtPrivateKey] = keyFile
db, err := suite.driver.NewDatabase(opts)
suite.NoError(err)
defer db.Close()
cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
defer cnxn.Close()
Expand Down
4 changes: 4 additions & 0 deletions go/adbc/driver/snowflake/snowflake_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
useHighPrecision: d.useHighPrecision,
}, nil
}

func (d *databaseImpl) Close() error {
return nil
}
Loading
Loading