From 822c8ad55d8d92dd961c4f2bef9bce05af7b6950 Mon Sep 17 00:00:00 2001 From: joel Date: Fri, 15 Mar 2024 10:32:36 -0400 Subject: [PATCH] pass logger to connection and add/refactor tests --- go/adbc/driver/driverbase/connection.go | 3 + go/adbc/driver/driverbase/driver_test.go | 174 ++++++++++++++++++----- 2 files changed, 144 insertions(+), 33 deletions(-) diff --git a/go/adbc/driver/driverbase/connection.go b/go/adbc/driver/driverbase/connection.go index e9a0f8b7b0..a861d004e3 100644 --- a/go/adbc/driver/driverbase/connection.go +++ b/go/adbc/driver/driverbase/connection.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow/go/v16/arrow" "github.com/apache/arrow/go/v16/arrow/array" "github.com/apache/arrow/go/v16/arrow/memory" + "golang.org/x/exp/slog" ) const ( @@ -102,6 +103,7 @@ type ConnectionImplBase struct { Alloc memory.Allocator ErrorHelper ErrorHelper DriverInfo *DriverInfo + Logger *slog.Logger Autocommit bool Closed bool @@ -116,6 +118,7 @@ func NewConnectionImplBase(database *DatabaseImplBase) ConnectionImplBase { Alloc: database.Alloc, ErrorHelper: database.ErrorHelper, DriverInfo: database.DriverInfo, + Logger: database.Logger, Autocommit: true, Closed: false, } diff --git a/go/adbc/driver/driverbase/driver_test.go b/go/adbc/driver/driverbase/driver_test.go index 46f70f7a4d..89dadd2b8a 100644 --- a/go/adbc/driver/driverbase/driver_test.go +++ b/go/adbc/driver/driverbase/driver_test.go @@ -38,29 +38,24 @@ const ( OptionKeyUnrecognized = "unrecognized" ) -type MockedHandler struct { - mock.Mock -} - -func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool { return true } -func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return h } -func (h *MockedHandler) WithGroup(name string) slog.Handler { return h } -func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error { - // We only care to assert the message value, and want to isolate nondetermistic behavior (e.g. timestamp) - args := h.Called(ctx, r.Message) - return args.Error(0) -} - -func NewDriver(alloc memory.Allocator, useHelpers bool) adbc.Driver { +// NewDriver creates a new adbc.Driver for testing. In addition to a memory.Allocator, it takes +// a slog.Handler to use for all structured logging as well as a useHelpers flag to determine whether +// the test should register helper methods or use the default driverbase implementation. +func NewDriver(alloc memory.Allocator, handler slog.Handler, useHelpers bool) adbc.Driver { info := driverbase.DefaultDriverInfo("MockDriver") _ = info.RegisterInfoCode(adbc.InfoCode(10_001), "my custom info") - return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc), useHelpers: useHelpers}) + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc), handler: handler, useHelpers: useHelpers}) } func TestDefaultDriver(t *testing.T) { + var handler MockedHandler + handler.On("Handle", mock.Anything, mock.Anything).Return(nil) + ctx := context.TODO() - alloc := memory.DefaultAllocator - drv := NewDriver(alloc, false) // Do not use helper implementations; only default behavior + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + drv := NewDriver(alloc, &handler, false) // Do not use helper implementations; only default behavior db, err := drv.NewDatabase(nil) require.NoError(t, err) @@ -90,6 +85,8 @@ func TestDefaultDriver(t *testing.T) { info, err := cnxn.GetInfo(ctx, nil) require.NoError(t, err) + getInfoTable := tableFromRecordReader(info) + defer getInfoTable.Release() // This is what the driverbase provided GetInfo result should look like out of the box, // with one custom setting registered at initialization @@ -128,9 +125,7 @@ func TestDefaultDriver(t *testing.T) { } ]`}) require.NoError(t, err) - - getInfoTable := tableFromRecordReader(info) - defer getInfoTable.Release() + defer expectedGetInfoTable.Release() require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable) @@ -157,12 +152,41 @@ func TestDefaultDriver(t *testing.T) { err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "test_catalog") require.Error(t, err) require.Equal(t, "Not Implemented: [MockDriver] Unknown connection option 'adbc.connection.catalog'", err.Error()) + + // We passed a mock handler into the driver to use for logs, so we can check actual messages logged + expectedLogMessages := []logMessage{ + {Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "false"}}, + } + + logMessages := make([]logMessage, 0, len(handler.Calls)) + for _, call := range handler.Calls { + sr, ok := call.Arguments.Get(1).(slog.Record) + require.True(t, ok) + logMessages = append(logMessages, newLogMessage(sr)) + } + + for _, expected := range expectedLogMessages { + var found bool + for _, message := range logMessages { + if messagesEqual(message, expected) { + found = true + break + } + } + require.Truef(t, found, "expected message was never logged: %v", expected) + } + } func TestCustomizedDriver(t *testing.T) { + var handler MockedHandler + handler.On("Handle", mock.Anything, mock.Anything).Return(nil) + ctx := context.TODO() - alloc := memory.DefaultAllocator - drv := NewDriver(alloc, true) // Use helper implementations + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + drv := NewDriver(alloc, &handler, true) // Use helper implementations db, err := drv.NewDatabase(nil) require.NoError(t, err) @@ -188,6 +212,8 @@ func TestCustomizedDriver(t *testing.T) { info, err := cnxn.GetInfo(ctx, nil) require.NoError(t, err) + getInfoTable := tableFromRecordReader(info) + defer getInfoTable.Release() // This is the arrow table representation of GetInfo produced by merging: // - the default DriverInfo set at initialization @@ -232,14 +258,14 @@ func TestCustomizedDriver(t *testing.T) { } ]`}) require.NoError(t, err) - - getInfoTable := tableFromRecordReader(info) - defer getInfoTable.Release() + defer expectedGetInfoTable.Release() require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable) dbObjects, err := cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil, nil) require.NoError(t, err) + dbObjectsTable := tableFromRecordReader(dbObjects) + defer dbObjectsTable.Release() // This is the arrow table representation of the GetObjects output we get by implementing // the simplified TableTypeLister interface @@ -289,14 +315,14 @@ func TestCustomizedDriver(t *testing.T) { } ]`}) require.NoError(t, err) - - dbObjectsTable := tableFromRecordReader(dbObjects) - defer dbObjectsTable.Release() + defer expectedDbObjectsTable.Release() require.Truef(t, array.TableEqual(expectedDbObjectsTable, dbObjectsTable), "expected: %s\ngot: %s", expectedDbObjectsTable, dbObjectsTable) tableTypes, err := cnxn.GetTableTypes(ctx) require.NoError(t, err) + tableTypeTable := tableFromRecordReader(tableTypes) + defer tableTypeTable.Release() // This is the arrow table representation of the GetTableTypes output we get by implementing // the simplified TableTypeLister interface @@ -305,9 +331,7 @@ func TestCustomizedDriver(t *testing.T) { { "table_type": "VIEW" } ]`}) require.NoError(t, err) - - tableTypeTable := tableFromRecordReader(tableTypes) - defer tableTypeTable.Release() + defer expectedTableTypesTable.Release() require.Truef(t, array.TableEqual(expectedTableTypesTable, tableTypeTable), "expected: %s\ngot: %s", expectedTableTypesTable, tableTypeTable) @@ -349,22 +373,48 @@ func TestCustomizedDriver(t *testing.T) { currentDbSchema, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) require.NoError(t, err) require.Equal(t, "test_schema", currentDbSchema) + + // We passed a mock handler into the driver to use for logs, so we can check actual messages logged + expectedLogMessages := []logMessage{ + {Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "true"}}, + {Message: "SetAutocommit", Level: "DEBUG", Attrs: map[string]string{"enabled": "false"}}, + {Message: "SetCurrentCatalog", Level: "DEBUG", Attrs: map[string]string{"val": "test_catalog"}}, + {Message: "SetCurrentDbSchema", Level: "DEBUG", Attrs: map[string]string{"val": "test_schema"}}, + } + + logMessages := make([]logMessage, 0, len(handler.Calls)) + for _, call := range handler.Calls { + sr, ok := call.Arguments.Get(1).(slog.Record) + require.True(t, ok) + logMessages = append(logMessages, newLogMessage(sr)) + } + + for _, expected := range expectedLogMessages { + var found bool + for _, message := range logMessages { + if messagesEqual(message, expected) { + found = true + break + } + } + require.Truef(t, found, "expected message was never logged: %v", expected) + } } type driverImpl struct { driverbase.DriverImplBase + handler slog.Handler useHelpers bool } func (drv *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { - var handler MockedHandler db := driverbase.NewDatabase( &databaseImpl{DatabaseImplBase: driverbase.NewDatabaseImplBase(&drv.DriverImplBase), drv: drv, useHelpers: drv.useHelpers, }) - db.SetLogger(slog.New(&handler)) + db.SetLogger(slog.New(drv.handler)) return db, nil } @@ -397,6 +447,7 @@ func (d *databaseImpl) SetOption(key, value string) error { } func (db *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { + db.DatabaseImplBase.Logger.Info("Opening a new connection", "withHelpers", db.useHelpers) cnxn := &connectionImpl{ConnectionImplBase: driverbase.NewConnectionImplBase(&db.DatabaseImplBase), db: db} bldr := driverbase.NewConnectionBuilder(cnxn) if db.useHelpers { // this toggles between the NewDefaultDriver and NewCustomizedDriver scenarios @@ -420,6 +471,7 @@ type connectionImpl struct { } func (c *connectionImpl) SetAutocommit(enabled bool) error { + c.Base().Logger.Debug("SetAutocommit", "enabled", enabled) return nil } @@ -438,11 +490,13 @@ func (c *connectionImpl) GetCurrentDbSchema() (string, bool) { } func (c *connectionImpl) SetCurrentCatalog(val string) error { + c.Base().Logger.Debug("SetCurrentCatalog", "val", val) c.currentCatalog = val return nil } func (c *connectionImpl) SetCurrentDbSchema(val string) error { + c.Base().Logger.Debug("SetCurrentDbSchema", "val", val) c.currentDbSchema = val return nil } @@ -474,7 +528,61 @@ func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.Object }, nil } +// MockedHandler is a mock.Mock that implements the slog.Handler interface. +// It is used to assert specific behavior for loggers it is injected into. +type MockedHandler struct { + mock.Mock +} + +func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool { return true } +func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return h } +func (h *MockedHandler) WithGroup(name string) slog.Handler { return h } +func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error { + // We only care to assert the message value, and want to isolate nondetermistic behavior (e.g. timestamp) + args := h.Called(ctx, r) + return args.Error(0) +} + +// logMessage is a container for log attributes we would like to compare for equality during tests. +// It intentionally omits timestamps and other sources of nondeterminism. +type logMessage struct { + Message string + Level string + Attrs map[string]string +} + +// newLogMessage constructs a logMessage from a slog.Record, containing only deterministic fields. +func newLogMessage(r slog.Record) logMessage { + message := logMessage{Message: r.Message, Level: r.Level.String(), Attrs: make(map[string]string)} + r.Attrs(func(a slog.Attr) bool { + message.Attrs[a.Key] = a.Value.String() + return true + }) + return message +} + +// messagesEqual compares two logMessages and returns whether they are equal. +func messagesEqual(expected, actual logMessage) bool { + if expected.Message != actual.Message { + return false + } + if expected.Level != actual.Level { + return false + } + if len(expected.Attrs) != len(actual.Attrs) { + return false + } + for k, v := range expected.Attrs { + if actual.Attrs[k] != v { + return false + } + } + return true +} + func tableFromRecordReader(rdr array.RecordReader) arrow.Table { + defer rdr.Release() + recs := make([]arrow.Record, 0) for rdr.Next() { rec := rdr.Record()