From e7d28a021d6c906cdd7fe8473d34721137c9c058 Mon Sep 17 00:00:00 2001 From: Anton Levakin Date: Sat, 13 Jan 2024 23:33:57 +0100 Subject: [PATCH 1/3] fix typos in comments --- go/adbc/driver/driverbase/driver.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/go/adbc/driver/driverbase/driver.go b/go/adbc/driver/driverbase/driver.go index c47677948f..8b6c71f797 100644 --- a/go/adbc/driver/driverbase/driver.go +++ b/go/adbc/driver/driverbase/driver.go @@ -21,8 +21,9 @@ package driverbase import ( - "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow/go/v14/arrow/memory" + + "github.com/apache/arrow-adbc/go/adbc" ) // DriverImpl is an interface that drivers implement to provide @@ -32,7 +33,7 @@ type DriverImpl interface { NewDatabase(opts map[string]string) (adbc.Database, error) } -// DatabaseImplBase is a struct that provides default implementations of the +// DriverImplBase is a struct that provides default implementations of the // DriverImpl interface. It is meant to be used as a composite struct for a // driver's DriverImpl implementation. type DriverImplBase struct { @@ -56,7 +57,7 @@ type driver struct { impl DriverImpl } -// NewDatabase wraps a DriverImpl to create an adbc.Driver. +// NewDriver wraps a DriverImpl to create an adbc.Driver. func NewDriver(impl DriverImpl) adbc.Driver { return &driver{impl} } From 2bdb8b4d40147a38a1ea02234eebc527af7bbe22 Mon Sep 17 00:00:00 2001 From: Anton Levakin Date: Sun, 14 Jan 2024 00:41:11 +0100 Subject: [PATCH 2/3] feat(go/adbc): Close database explicitly Implicit database release behaves inconsistently on different OS, which leads to bugs. Closes #1306 --- docs/source/driver/duckdb.rst | 1 + docs/source/driver/flight_sql.rst | 1 + docs/source/driver/postgresql.rst | 1 + docs/source/driver/snowflake.rst | 2 + docs/source/driver/sqlite.rst | 1 + go/adbc/adbc.go | 3 + go/adbc/driver/driverbase/database.go | 5 ++ go/adbc/driver/driverbase/driver.go | 3 +- .../flightsql/flightsql_adbc_server_test.go | 1 + .../driver/flightsql/flightsql_adbc_test.go | 9 +++ .../driver/flightsql/flightsql_database.go | 20 +++--- go/adbc/driver/flightsql/flightsql_driver.go | 1 + go/adbc/driver/panicdummy/panicdummy_adbc.go | 5 ++ go/adbc/driver/snowflake/connection.go | 6 +- go/adbc/driver/snowflake/driver.go | 1 + go/adbc/driver/snowflake/driver_test.go | 62 ++++++++++--------- .../driver/snowflake/snowflake_database.go | 4 ++ go/adbc/drivermgr/wrapper.go | 46 +++++++++----- go/adbc/drivermgr/wrapper_sqlite_test.go | 5 ++ go/adbc/pkg/_tmpl/driver.go.tmpl | 1 + go/adbc/pkg/flightsql/driver.go | 1 + go/adbc/pkg/panicdummy/driver.go | 1 + go/adbc/pkg/snowflake/driver.go | 1 + go/adbc/validation/validation.go | 3 + 24 files changed, 126 insertions(+), 58 deletions(-) diff --git a/docs/source/driver/duckdb.rst b/docs/source/driver/duckdb.rst index 410331c39f..94460eb531 100644 --- a/docs/source/driver/duckdb.rst +++ b/docs/source/driver/duckdb.rst @@ -72,6 +72,7 @@ ADBC support in DuckDB requires the driver manager. if err != nil { // handle error } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst index aca95d86cd..7473a7cb4c 100644 --- a/docs/source/driver/flight_sql.rst +++ b/docs/source/driver/flight_sql.rst @@ -152,6 +152,7 @@ the :cpp:class:`AdbcDatabase`. if err != nil { // do something with the error } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { diff --git a/docs/source/driver/postgresql.rst b/docs/source/driver/postgresql.rst index ddf9115d76..c724a2c174 100644 --- a/docs/source/driver/postgresql.rst +++ b/docs/source/driver/postgresql.rst @@ -124,6 +124,7 @@ the :cpp:class:`AdbcDatabase`. This should be a `connection URI if err != nil { // handle error } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst index 04023a62a5..bf44534967 100644 --- a/docs/source/driver/snowflake.rst +++ b/docs/source/driver/snowflake.rst @@ -127,6 +127,7 @@ constructing the :cpp::class:`AdbcDatabase`. if err != nil { // handle error } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { @@ -241,6 +242,7 @@ a listing). if err != nil { // handle error } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { diff --git a/docs/source/driver/sqlite.rst b/docs/source/driver/sqlite.rst index 30e7d32b67..96bd7bbdb8 100644 --- a/docs/source/driver/sqlite.rst +++ b/docs/source/driver/sqlite.rst @@ -140,6 +140,7 @@ shared across all connections. if err != nil { // handle error } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index 3fb61d692d..71a75dafa8 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -329,6 +329,9 @@ type Driver interface { type Database interface { SetOptions(map[string]string) error Open(ctx context.Context) (Connection, error) + + // Close closes this database and releases any associated resources. + Close() error } type InfoCode uint32 diff --git a/go/adbc/driver/driverbase/database.go b/go/adbc/driver/driverbase/database.go index e3a96ff16c..7f32510c8e 100644 --- a/go/adbc/driver/driverbase/database.go +++ b/go/adbc/driver/driverbase/database.go @@ -31,6 +31,7 @@ type DatabaseImpl interface { adbc.GetSetOptions Base() *DatabaseImplBase Open(context.Context) (adbc.Connection, error) + Close() error SetOptions(map[string]string) error } @@ -134,6 +135,10 @@ func (db *database) Open(ctx context.Context) (adbc.Connection, error) { return db.impl.Open(ctx) } +func (db *database) Close() error { + return db.impl.Close() +} + func (db *database) SetLogger(logger *slog.Logger) { if logger != nil { db.impl.Base().Logger = logger diff --git a/go/adbc/driver/driverbase/driver.go b/go/adbc/driver/driverbase/driver.go index 8b6c71f797..acd182f8a1 100644 --- a/go/adbc/driver/driverbase/driver.go +++ b/go/adbc/driver/driverbase/driver.go @@ -21,9 +21,8 @@ package driverbase import ( - "github.com/apache/arrow/go/v14/arrow/memory" - "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow/go/v14/arrow/memory" ) // DriverImpl is an interface that drivers implement to provide diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index a591f1cae2..dfd1f6cfd7 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -90,6 +90,7 @@ func (suite *ServerBasedTests) TearDownTest() { } func (suite *ServerBasedTests) TearDownSuite() { + suite.NoError(suite.db.Close()) suite.db = nil suite.s.Shutdown() } diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 1619f8fabb..dc7d207dd5 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -352,6 +352,7 @@ func (suite *DefaultDialOptionsTests) SetupSuite() { func (suite *DefaultDialOptionsTests) TearDownSuite() { suite.Quirks.TearDownDriver(suite.T(), suite.Driver) + suite.NoError(suite.DB.Close()) suite.DB = nil suite.Driver = nil } @@ -361,6 +362,7 @@ func (suite *DefaultDialOptionsTests) TestMaxIncomingMessageSizeDefault() { opts["adbc.flight.sql.client_option.with_max_msg_size"] = "1000000" db, err := suite.Driver.NewDatabase(opts) suite.NoError(err) + defer suite.NoError(db.Close()) cnxn, err := db.Open(suite.ctx) suite.NoError(err) @@ -505,6 +507,7 @@ func (suite *PartitionTests) TearDownTest() { suite.Require().NoError(suite.Cnxn.Close()) suite.Quirks.TearDownDriver(suite.T(), suite.Driver) suite.Cnxn = nil + suite.NoError(suite.DB.Close()) suite.DB = nil suite.Driver = nil } @@ -558,6 +561,7 @@ func (suite *StatementTests) TearDownTest() { suite.Require().NoError(suite.Cnxn.Close()) suite.Quirks.TearDownDriver(suite.T(), suite.Driver) suite.Cnxn = nil + suite.NoError(suite.DB.Close()) suite.DB = nil suite.Driver = nil } @@ -639,6 +643,7 @@ func (suite *HeaderTests) TearDownTest() { suite.Require().NoError(suite.Cnxn.Close()) suite.Quirks.TearDownDriver(suite.T(), suite.Driver) suite.Cnxn = nil + suite.NoError(suite.DB.Close()) suite.DB = nil suite.Driver = nil } @@ -842,6 +847,7 @@ func (suite *TLSTests) TearDownTest() { suite.Require().NoError(suite.Cnxn.Close()) suite.Quirks.TearDownDriver(suite.T(), suite.Driver) suite.Cnxn = nil + suite.NoError(suite.DB.Close()) suite.DB = nil suite.Driver = nil } @@ -863,6 +869,7 @@ func (suite *TLSTests) TestInvalidOptions() { "adbc.flight.sql.client_option.tls_skip_verify": "false", }) suite.Require().NoError(err) + defer suite.NoError(db.Close()) cnxn, err := db.Open(suite.ctx) suite.Require().NoError(err) @@ -912,6 +919,7 @@ func (suite *ConnectionTests) SetupSuite() { } func (suite *ConnectionTests) TearDownSuite() { + suite.NoError(suite.DB.Close()) suite.server.Shutdown() suite.alloc.AssertSize(suite.T(), 0) } @@ -1009,6 +1017,7 @@ func (suite *DomainSocketTests) SetupSuite() { func (suite *DomainSocketTests) TearDownSuite() { suite.Require().NoError(suite.Stmt.Close()) suite.Require().NoError(suite.Cnxn.Close()) + suite.NoError(suite.DB.Close()) suite.server.Shutdown() suite.alloc.AssertSize(suite.T(), 0) } diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index 8b6ab2ccb9..f9537f5097 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -332,6 +332,10 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error { return d.DatabaseImplBase.SetOptionDouble(key, value) } +func (d *databaseImpl) Close() error { + return nil +} + func getFlightClient(ctx context.Context, loc string, d *databaseImpl) (*flightsql.Client, error) { authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()} middleware := []flight.ClientMiddleware{ @@ -396,8 +400,8 @@ type support struct { transactions bool } -func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { - cl, err := getFlightClient(ctx, impl.uri.String(), impl) +func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { + cl, err := getFlightClient(ctx, d.uri.String(), d) if err != nil { return nil, err } @@ -410,12 +414,12 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { return nil, adbc.Error{Msg: fmt.Sprintf("Location must be a string, got %#v", uri), Code: adbc.StatusInternal} } - cl, err := getFlightClient(context.Background(), uri, impl) + cl, err := getFlightClient(context.Background(), uri, d) if err != nil { return nil, err } - cl.Alloc = impl.Alloc + cl.Alloc = d.Alloc return cl, nil }). EvictedFunc(func(_, client interface{}) { @@ -425,13 +429,13 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { var cnxnSupport support - info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, impl.timeout) + info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout) // ignore this if it fails if err == nil { const int32code = 3 for _, endpoint := range info.Endpoint { - rdr, err := doGet(ctx, cl, endpoint, cache, impl.timeout) + rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout) if err != nil { continue } @@ -465,8 +469,8 @@ func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { } } - return &cnxn{cl: cl, db: impl, clientCache: cache, - hdrs: make(metadata.MD), timeouts: impl.timeout, + return &cnxn{cl: cl, db: d, clientCache: cache, + hdrs: make(metadata.MD), timeouts: d.timeout, supportInfo: cnxnSupport}, nil } diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index 0060c04055..cc58a9e15b 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -145,5 +145,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) if err := db.SetOptions(opts); err != nil { return nil, err } + return driverbase.NewDatabase(db), nil } diff --git a/go/adbc/driver/panicdummy/panicdummy_adbc.go b/go/adbc/driver/panicdummy/panicdummy_adbc.go index f0513cd6f3..95171591d0 100644 --- a/go/adbc/driver/panicdummy/panicdummy_adbc.go +++ b/go/adbc/driver/panicdummy/panicdummy_adbc.go @@ -66,6 +66,11 @@ func (d *database) Open(ctx context.Context) (adbc.Connection, error) { return &cnxn{}, nil } +func (d *database) Close() error { + maybePanic("DatabaseClose") + return nil +} + type cnxn struct{} func (c *cnxn) SetOption(key, value string) error { diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 73c31604de..e2f98487ec 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -714,16 +714,16 @@ func prepareTablesSQL(matchingCatalogNames []string, catalog *string, dbSchema * func prepareColumnsSQL(matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (string, []interface{}) { prefixQuery := "" - for _, catalog_name := range matchingCatalogNames { + for _, catalogName := range matchingCatalogNames { if prefixQuery != "" { prefixQuery += " UNION ALL " } prefixQuery += `SELECT T.table_type, C.* FROM - "` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T + "` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.TABLES AS T JOIN - "` + strings.ReplaceAll(catalog_name, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C + "` + strings.ReplaceAll(catalogName, "\"", "\"\"") + `".INFORMATION_SCHEMA.COLUMNS AS C ON T.table_catalog = C.table_catalog AND T.table_schema = C.table_schema diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index db5efed456..3b9d72cc7d 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -202,5 +202,6 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) if err := db.SetOptions(opts); err != nil { return nil, err } + return driverbase.NewDatabase(db), nil } diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 61e944866d..a69a0b0455 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -336,6 +336,7 @@ func (suite *SnowflakeTests) TearDownTest() { } func (suite *SnowflakeTests) TearDownSuite() { + suite.NoError(suite.db.Close()) suite.db = nil } @@ -464,21 +465,21 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() { xdbcDateTimeSub []string }{ { - "BASIC", //name - []string{"int64s", "strings"}, //colNames - []string{"1", "2"}, //positions - []string{"NUMBER", "TEXT"}, //dataTypes - []string{"", ""}, //comments - []string{"9", "13"}, //xdbcDataType - []string{"NUMBER", "TEXT"}, //xdbcTypeName - []string{"-5", "12"}, //xdbcSqlDataType - []string{"1", "1"}, //xdbcNullable - []string{"YES", "YES"}, //xdbcIsNullable - []string{"0", "0"}, //xdbcScale - []string{"10", "0"}, //xdbcNumPrecRadix - []string{"38", "16777216"}, //xdbcCharMaxLen (xdbcPrecision) - []string{"0", "16777216"}, //xdbcCharOctetLen - []string{"-5", "12", "0"}, //xdbcDateTimeSub + "BASIC", // name + []string{"int64s", "strings"}, // colNames + []string{"1", "2"}, // positions + []string{"NUMBER", "TEXT"}, // dataTypes + []string{"", ""}, // comments + []string{"9", "13"}, // xdbcDataType + []string{"NUMBER", "TEXT"}, // xdbcTypeName + []string{"-5", "12"}, // xdbcSqlDataType + []string{"1", "1"}, // xdbcNullable + []string{"YES", "YES"}, // xdbcIsNullable + []string{"0", "0"}, // xdbcScale + []string{"10", "0"}, // xdbcNumPrecRadix + []string{"38", "16777216"}, // xdbcCharMaxLen (xdbcPrecision) + []string{"0", "16777216"}, // xdbcCharOctetLen + []string{"-5", "12", "0"}, // xdbcDateTimeSub }, } @@ -576,20 +577,20 @@ func (suite *SnowflakeTests) TestMetadataGetObjectsColumnsXdbc() { suite.False(rdr.Next()) suite.True(foundExpected) - suite.Equal(tt.colnames, colnames) //colNames - suite.Equal(tt.positions, positions) //positions - suite.Equal(tt.comments, comments) //comments - suite.Equal(tt.xdbcDataType, xdbcDataTypes) //xdbcDataType - suite.Equal(tt.dataTypes, dataTypes) //dataTypes - suite.Equal(tt.xdbcTypeName, xdbcTypeNames) //xdbcTypeName - suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) //xdbcCharMaxLen - suite.Equal(tt.xdbcScale, xdbcScales) //xdbcScale - suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) //xdbcNumPrecRadix - suite.Equal(tt.xdbcNullable, xdbcNullables) //xdbcNullable - suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) //xdbcSqlDataType - suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) //xdbcDateTimeSub - suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) //xdbcCharOctetLen - suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) //xdbcIsNullable + suite.Equal(tt.colnames, colnames) // colNames + suite.Equal(tt.positions, positions) // positions + suite.Equal(tt.comments, comments) // comments + suite.Equal(tt.xdbcDataType, xdbcDataTypes) // xdbcDataType + suite.Equal(tt.dataTypes, dataTypes) // dataTypes + suite.Equal(tt.xdbcTypeName, xdbcTypeNames) // xdbcTypeName + suite.Equal(tt.xdbcCharMaxLen, xdbcCharMaxLens) // xdbcCharMaxLen + suite.Equal(tt.xdbcScale, xdbcScales) // xdbcScale + suite.Equal(tt.xdbcNumPrecRadix, xdbcNumPrecRadixs) // xdbcNumPrecRadix + suite.Equal(tt.xdbcNullable, xdbcNullables) // xdbcNullable + suite.Equal(tt.xdbcSqlDataType, xdbcSqlDataTypes) // xdbcSqlDataType + suite.Equal(tt.xdbcDateTimeSub, xdbcDateTimeSub) // xdbcDateTimeSub + suite.Equal(tt.xdbcCharOctetLen, xdbcCharOctetLen) // xdbcCharOctetLen + suite.Equal(tt.xdbcIsNullable, xdbcIsNullables) // xdbcIsNullable }) } @@ -605,6 +606,7 @@ func (suite *SnowflakeTests) TestNewDatabaseGetSetOptions() { }) suite.NoError(err) suite.NotNil(db) + defer suite.NoError(db.Close()) getSetDB, ok := db.(adbc.GetSetOptions) suite.True(ok) @@ -862,6 +864,7 @@ func ConnectWithJwt(uri, keyValue, passcode string) { if err != nil { panic(err) } + defer db.Close() cnxn, err := db.Open(context.Background()) if err != nil { @@ -912,6 +915,7 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() { opts[driver.OptionJwtPrivateKey] = keyFile db, err := suite.driver.NewDatabase(opts) suite.NoError(err) + defer db.Close() cnxn, err := db.Open(suite.ctx) suite.NoError(err) defer cnxn.Close() diff --git a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go index 45e3aab467..7b76fa5a5a 100644 --- a/go/adbc/driver/snowflake/snowflake_database.go +++ b/go/adbc/driver/snowflake/snowflake_database.go @@ -466,3 +466,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { useHighPrecision: d.useHighPrecision, }, nil } + +func (d *databaseImpl) Close() error { + return nil +} diff --git a/go/adbc/drivermgr/wrapper.go b/go/adbc/drivermgr/wrapper.go index 07bb94b814..63fb9ee9f4 100644 --- a/go/adbc/drivermgr/wrapper.go +++ b/go/adbc/drivermgr/wrapper.go @@ -39,7 +39,7 @@ package drivermgr import "C" import ( "context" - "runtime" + "sync" "unsafe" "github.com/apache/arrow-adbc/go/adbc" @@ -100,27 +100,15 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) { return nil, errOut } - runtime.SetFinalizer(db, func(db *Database) { - if db.db != nil { - var err C.struct_AdbcError - code := adbc.Status(C.AdbcDatabaseRelease(db.db, &err)) - if code != adbc.StatusOK { - panic(toAdbcError(code, &err)) - } - } - - for _, o := range db.options { - C.free(unsafe.Pointer(o.key)) - C.free(unsafe.Pointer(o.val)) - } - }) - return db, nil } type Database struct { options map[string]option db *C.struct_AdbcDatabase + + mu sync.Mutex // protects following fields + closed bool } func toAdbcError(code adbc.Status, e *C.struct_AdbcError) error { @@ -182,6 +170,32 @@ func (d *Database) Open(context.Context) (adbc.Connection, error) { return &cnxn{conn: &c}, nil } +func (d *Database) Close() error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.closed { + return nil + } + + d.closed = true + + for _, o := range d.options { + C.free(unsafe.Pointer(o.key)) + C.free(unsafe.Pointer(o.val)) + } + + if d.db != nil { + var err C.struct_AdbcError + code := adbc.Status(C.AdbcDatabaseRelease(d.db, &err)) + if code != adbc.StatusOK { + return toAdbcError(code, &err) + } + } + + return nil +} + func getRdr(out *C.struct_ArrowArrayStream) (array.RecordReader, error) { rdr, err := cdata.ImportCRecordReader((*cdata.CArrowArrayStream)(unsafe.Pointer(out)), nil) if err != nil { diff --git a/go/adbc/drivermgr/wrapper_sqlite_test.go b/go/adbc/drivermgr/wrapper_sqlite_test.go index c33adf2792..af307a08d5 100644 --- a/go/adbc/drivermgr/wrapper_sqlite_test.go +++ b/go/adbc/drivermgr/wrapper_sqlite_test.go @@ -74,6 +74,10 @@ func (dm *DriverMgrSuite) SetupSuite() { dm.Equal(int64(1), nrows) } +func (dm *DriverMgrSuite) TearDownSuite() { + dm.NoError(dm.db.Close()) +} + func (dm *DriverMgrSuite) SetupTest() { cnxn, err := dm.db.Open(dm.ctx) dm.Require().NoError(err) @@ -597,6 +601,7 @@ func TestDriverMgrCustomInitFunc(t *testing.T) { cnxn, err := db.Open(context.Background()) assert.NoError(t, err) require.NoError(t, cnxn.Close()) + require.NoError(t, db.Close()) // set invalid entrypoint drv = drivermgr.Driver{} diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl index 4b7008ea9e..a87303e573 100644 --- a/go/adbc/pkg/_tmpl/driver.go.tmpl +++ b/go/adbc/pkg/_tmpl/driver.go.tmpl @@ -591,6 +591,7 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) + cdb.db.Close() cdb.db = nil cdb.opts = nil C.free(unsafe.Pointer(db.private_data)) diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go index 2847274c30..5315d2ee76 100644 --- a/go/adbc/pkg/flightsql/driver.go +++ b/go/adbc/pkg/flightsql/driver.go @@ -594,6 +594,7 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) + cdb.db.Close() cdb.db = nil cdb.opts = nil C.free(unsafe.Pointer(db.private_data)) diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go index 399d0edcdb..fbaa5204cd 100644 --- a/go/adbc/pkg/panicdummy/driver.go +++ b/go/adbc/pkg/panicdummy/driver.go @@ -594,6 +594,7 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErro h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) + cdb.db.Close() cdb.db = nil cdb.opts = nil C.free(unsafe.Pointer(db.private_data)) diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go index b591018106..790887503f 100644 --- a/go/adbc/pkg/snowflake/driver.go +++ b/go/adbc/pkg/snowflake/driver.go @@ -594,6 +594,7 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) + cdb.db.Close() cdb.db = nil cdb.opts = nil C.free(unsafe.Pointer(db.private_data)) diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index 8925f72532..192228075b 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -100,6 +100,7 @@ func (d *DatabaseTests) TestNewDatabase() { d.NoError(err) d.NotNil(db) d.Implements((*adbc.Database)(nil), db) + d.NoError(db.Close()) } type ConnectionTests struct { @@ -121,6 +122,7 @@ func (c *ConnectionTests) SetupTest() { func (c *ConnectionTests) TearDownTest() { c.Quirks.TearDownDriver(c.T(), c.Driver) c.Driver = nil + c.NoError(c.DB.Close()) c.DB = nil } @@ -514,6 +516,7 @@ func (s *StatementTests) TearDownTest() { s.Require().NoError(s.Cnxn.Close()) s.Quirks.TearDownDriver(s.T(), s.Driver) s.Cnxn = nil + s.NoError(s.DB.Close()) s.DB = nil s.Driver = nil } From ff3ae55d3ca8150ea90832ee03ea280ebc20c270 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 18 Jan 2024 16:58:02 -0500 Subject: [PATCH 3/3] don't call close on nil db --- go/adbc/pkg/_tmpl/driver.go.tmpl | 12 ++++++++---- go/adbc/pkg/flightsql/driver.go | 12 ++++++++---- go/adbc/pkg/snowflake/driver.go | 12 ++++++++---- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl index a87303e573..901d164ec0 100644 --- a/go/adbc/pkg/_tmpl/driver.go.tmpl +++ b/go/adbc/pkg/_tmpl/driver.go.tmpl @@ -591,11 +591,15 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) - cdb.db.Close() - cdb.db = nil + if cdb.db != nil { + cdb.db.Close() + cdb.db = nil + } cdb.opts = nil - C.free(unsafe.Pointer(db.private_data)) - db.private_data = nil + if db.private_data != nil { + C.free(unsafe.Pointer(db.private_data)) + db.private_data = nil + } h.Delete() // manually trigger GC for two reasons: // 1. ASAN expects the release callback to be called before diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go index 5315d2ee76..d57a91b7e7 100644 --- a/go/adbc/pkg/flightsql/driver.go +++ b/go/adbc/pkg/flightsql/driver.go @@ -594,11 +594,15 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) - cdb.db.Close() - cdb.db = nil + if cdb.db != nil { + cdb.db.Close() + cdb.db = nil + } cdb.opts = nil - C.free(unsafe.Pointer(db.private_data)) - db.private_data = nil + if db.private_data != nil { + C.free(unsafe.Pointer(db.private_data)) + db.private_data = nil + } h.Delete() // manually trigger GC for two reasons: // 1. ASAN expects the release callback to be called before diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go index 790887503f..6e2d3bac50 100644 --- a/go/adbc/pkg/snowflake/driver.go +++ b/go/adbc/pkg/snowflake/driver.go @@ -594,11 +594,15 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError h := (*(*cgo.Handle)(db.private_data)) cdb := h.Value().(*cDatabase) - cdb.db.Close() - cdb.db = nil + if cdb.db != nil { + cdb.db.Close() + cdb.db = nil + } cdb.opts = nil - C.free(unsafe.Pointer(db.private_data)) - db.private_data = nil + if db.private_data != nil { + C.free(unsafe.Pointer(db.private_data)) + db.private_data = nil + } h.Delete() // manually trigger GC for two reasons: // 1. ASAN expects the release callback to be called before