Skip to content

Commit

Permalink
pass logger to connection and add/refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joellubi committed Mar 15, 2024
1 parent f7b4dcc commit 822c8ad
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 33 deletions.
3 changes: 3 additions & 0 deletions go/adbc/driver/driverbase/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
"github.com/apache/arrow/go/v16/arrow/memory"
"golang.org/x/exp/slog"
)

const (
Expand Down Expand Up @@ -102,6 +103,7 @@ type ConnectionImplBase struct {
Alloc memory.Allocator
ErrorHelper ErrorHelper
DriverInfo *DriverInfo
Logger *slog.Logger

Autocommit bool
Closed bool
Expand All @@ -116,6 +118,7 @@ func NewConnectionImplBase(database *DatabaseImplBase) ConnectionImplBase {
Alloc: database.Alloc,
ErrorHelper: database.ErrorHelper,
DriverInfo: database.DriverInfo,
Logger: database.Logger,
Autocommit: true,
Closed: false,
}
Expand Down
174 changes: 141 additions & 33 deletions go/adbc/driver/driverbase/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,24 @@ const (
OptionKeyUnrecognized = "unrecognized"
)

type MockedHandler struct {
mock.Mock
}

func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool { return true }
func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return h }
func (h *MockedHandler) WithGroup(name string) slog.Handler { return h }
func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error {
// We only care to assert the message value, and want to isolate nondetermistic behavior (e.g. timestamp)
args := h.Called(ctx, r.Message)
return args.Error(0)
}

func NewDriver(alloc memory.Allocator, useHelpers bool) adbc.Driver {
// NewDriver creates a new adbc.Driver for testing. In addition to a memory.Allocator, it takes
// a slog.Handler to use for all structured logging as well as a useHelpers flag to determine whether
// the test should register helper methods or use the default driverbase implementation.
func NewDriver(alloc memory.Allocator, handler slog.Handler, useHelpers bool) adbc.Driver {
info := driverbase.DefaultDriverInfo("MockDriver")
_ = info.RegisterInfoCode(adbc.InfoCode(10_001), "my custom info")
return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc), useHelpers: useHelpers})
return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc), handler: handler, useHelpers: useHelpers})
}

func TestDefaultDriver(t *testing.T) {
var handler MockedHandler
handler.On("Handle", mock.Anything, mock.Anything).Return(nil)

ctx := context.TODO()
alloc := memory.DefaultAllocator
drv := NewDriver(alloc, false) // Do not use helper implementations; only default behavior
alloc := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer alloc.AssertSize(t, 0)

drv := NewDriver(alloc, &handler, false) // Do not use helper implementations; only default behavior

db, err := drv.NewDatabase(nil)
require.NoError(t, err)
Expand Down Expand Up @@ -90,6 +85,8 @@ func TestDefaultDriver(t *testing.T) {

info, err := cnxn.GetInfo(ctx, nil)
require.NoError(t, err)
getInfoTable := tableFromRecordReader(info)
defer getInfoTable.Release()

// This is what the driverbase provided GetInfo result should look like out of the box,
// with one custom setting registered at initialization
Expand Down Expand Up @@ -128,9 +125,7 @@ func TestDefaultDriver(t *testing.T) {
}
]`})
require.NoError(t, err)

getInfoTable := tableFromRecordReader(info)
defer getInfoTable.Release()
defer expectedGetInfoTable.Release()

require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable)

Expand All @@ -157,12 +152,41 @@ func TestDefaultDriver(t *testing.T) {
err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "test_catalog")
require.Error(t, err)
require.Equal(t, "Not Implemented: [MockDriver] Unknown connection option 'adbc.connection.catalog'", err.Error())

// We passed a mock handler into the driver to use for logs, so we can check actual messages logged
expectedLogMessages := []logMessage{
{Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "false"}},
}

logMessages := make([]logMessage, 0, len(handler.Calls))
for _, call := range handler.Calls {
sr, ok := call.Arguments.Get(1).(slog.Record)
require.True(t, ok)
logMessages = append(logMessages, newLogMessage(sr))
}

for _, expected := range expectedLogMessages {
var found bool
for _, message := range logMessages {
if messagesEqual(message, expected) {
found = true
break
}
}
require.Truef(t, found, "expected message was never logged: %v", expected)
}

}

func TestCustomizedDriver(t *testing.T) {
var handler MockedHandler
handler.On("Handle", mock.Anything, mock.Anything).Return(nil)

ctx := context.TODO()
alloc := memory.DefaultAllocator
drv := NewDriver(alloc, true) // Use helper implementations
alloc := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer alloc.AssertSize(t, 0)

drv := NewDriver(alloc, &handler, true) // Use helper implementations

db, err := drv.NewDatabase(nil)
require.NoError(t, err)
Expand All @@ -188,6 +212,8 @@ func TestCustomizedDriver(t *testing.T) {

info, err := cnxn.GetInfo(ctx, nil)
require.NoError(t, err)
getInfoTable := tableFromRecordReader(info)
defer getInfoTable.Release()

// This is the arrow table representation of GetInfo produced by merging:
// - the default DriverInfo set at initialization
Expand Down Expand Up @@ -232,14 +258,14 @@ func TestCustomizedDriver(t *testing.T) {
}
]`})
require.NoError(t, err)

getInfoTable := tableFromRecordReader(info)
defer getInfoTable.Release()
defer expectedGetInfoTable.Release()

require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable)

dbObjects, err := cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil, nil)
require.NoError(t, err)
dbObjectsTable := tableFromRecordReader(dbObjects)
defer dbObjectsTable.Release()

// This is the arrow table representation of the GetObjects output we get by implementing
// the simplified TableTypeLister interface
Expand Down Expand Up @@ -289,14 +315,14 @@ func TestCustomizedDriver(t *testing.T) {
}
]`})
require.NoError(t, err)

dbObjectsTable := tableFromRecordReader(dbObjects)
defer dbObjectsTable.Release()
defer expectedDbObjectsTable.Release()

require.Truef(t, array.TableEqual(expectedDbObjectsTable, dbObjectsTable), "expected: %s\ngot: %s", expectedDbObjectsTable, dbObjectsTable)

tableTypes, err := cnxn.GetTableTypes(ctx)
require.NoError(t, err)
tableTypeTable := tableFromRecordReader(tableTypes)
defer tableTypeTable.Release()

// This is the arrow table representation of the GetTableTypes output we get by implementing
// the simplified TableTypeLister interface
Expand All @@ -305,9 +331,7 @@ func TestCustomizedDriver(t *testing.T) {
{ "table_type": "VIEW" }
]`})
require.NoError(t, err)

tableTypeTable := tableFromRecordReader(tableTypes)
defer tableTypeTable.Release()
defer expectedTableTypesTable.Release()

require.Truef(t, array.TableEqual(expectedTableTypesTable, tableTypeTable), "expected: %s\ngot: %s", expectedTableTypesTable, tableTypeTable)

Expand Down Expand Up @@ -349,22 +373,48 @@ func TestCustomizedDriver(t *testing.T) {
currentDbSchema, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
require.NoError(t, err)
require.Equal(t, "test_schema", currentDbSchema)

// We passed a mock handler into the driver to use for logs, so we can check actual messages logged
expectedLogMessages := []logMessage{
{Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "true"}},
{Message: "SetAutocommit", Level: "DEBUG", Attrs: map[string]string{"enabled": "false"}},
{Message: "SetCurrentCatalog", Level: "DEBUG", Attrs: map[string]string{"val": "test_catalog"}},
{Message: "SetCurrentDbSchema", Level: "DEBUG", Attrs: map[string]string{"val": "test_schema"}},
}

logMessages := make([]logMessage, 0, len(handler.Calls))
for _, call := range handler.Calls {
sr, ok := call.Arguments.Get(1).(slog.Record)
require.True(t, ok)
logMessages = append(logMessages, newLogMessage(sr))
}

for _, expected := range expectedLogMessages {
var found bool
for _, message := range logMessages {
if messagesEqual(message, expected) {
found = true
break
}
}
require.Truef(t, found, "expected message was never logged: %v", expected)
}
}

type driverImpl struct {
driverbase.DriverImplBase

handler slog.Handler
useHelpers bool
}

func (drv *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
var handler MockedHandler
db := driverbase.NewDatabase(
&databaseImpl{DatabaseImplBase: driverbase.NewDatabaseImplBase(&drv.DriverImplBase),
drv: drv,
useHelpers: drv.useHelpers,
})
db.SetLogger(slog.New(&handler))
db.SetLogger(slog.New(drv.handler))
return db, nil
}

Expand Down Expand Up @@ -397,6 +447,7 @@ func (d *databaseImpl) SetOption(key, value string) error {
}

func (db *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
db.DatabaseImplBase.Logger.Info("Opening a new connection", "withHelpers", db.useHelpers)
cnxn := &connectionImpl{ConnectionImplBase: driverbase.NewConnectionImplBase(&db.DatabaseImplBase), db: db}
bldr := driverbase.NewConnectionBuilder(cnxn)
if db.useHelpers { // this toggles between the NewDefaultDriver and NewCustomizedDriver scenarios
Expand All @@ -420,6 +471,7 @@ type connectionImpl struct {
}

func (c *connectionImpl) SetAutocommit(enabled bool) error {
c.Base().Logger.Debug("SetAutocommit", "enabled", enabled)
return nil
}

Expand All @@ -438,11 +490,13 @@ func (c *connectionImpl) GetCurrentDbSchema() (string, bool) {
}

func (c *connectionImpl) SetCurrentCatalog(val string) error {
c.Base().Logger.Debug("SetCurrentCatalog", "val", val)
c.currentCatalog = val
return nil
}

func (c *connectionImpl) SetCurrentDbSchema(val string) error {
c.Base().Logger.Debug("SetCurrentDbSchema", "val", val)
c.currentDbSchema = val
return nil
}
Expand Down Expand Up @@ -474,7 +528,61 @@ func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.Object
}, nil
}

// MockedHandler is a mock.Mock that implements the slog.Handler interface.
// It is used to assert specific behavior for loggers it is injected into.
type MockedHandler struct {
mock.Mock
}

func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool { return true }
func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return h }
func (h *MockedHandler) WithGroup(name string) slog.Handler { return h }
func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error {
// We only care to assert the message value, and want to isolate nondetermistic behavior (e.g. timestamp)
args := h.Called(ctx, r)
return args.Error(0)
}

// logMessage is a container for log attributes we would like to compare for equality during tests.
// It intentionally omits timestamps and other sources of nondeterminism.
type logMessage struct {
Message string
Level string
Attrs map[string]string
}

// newLogMessage constructs a logMessage from a slog.Record, containing only deterministic fields.
func newLogMessage(r slog.Record) logMessage {
message := logMessage{Message: r.Message, Level: r.Level.String(), Attrs: make(map[string]string)}
r.Attrs(func(a slog.Attr) bool {
message.Attrs[a.Key] = a.Value.String()
return true
})
return message
}

// messagesEqual compares two logMessages and returns whether they are equal.
func messagesEqual(expected, actual logMessage) bool {
if expected.Message != actual.Message {
return false
}
if expected.Level != actual.Level {
return false
}
if len(expected.Attrs) != len(actual.Attrs) {
return false
}
for k, v := range expected.Attrs {
if actual.Attrs[k] != v {
return false
}
}
return true
}

func tableFromRecordReader(rdr array.RecordReader) arrow.Table {
defer rdr.Release()

recs := make([]arrow.Record, 0)
for rdr.Next() {
rec := rdr.Record()
Expand Down

0 comments on commit 822c8ad

Please sign in to comment.