Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc/driver/flightsql): expose FlightInfo during polling #1582

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions go/adbc/driver/flightsql/cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,32 @@ func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.Fligh
return nil, err
}

srv.pollingStatus[val.Value]--
progress := srv.pollingStatus[val.Value]

ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
if err != nil {
return nil, err
}

endpoints := make([]*flight.FlightEndpoint, 5-progress)
if val.Value == "forever" {
srv.pollingStatus[val.Value]++
return &flight.PollInfo{
Info: &flight.FlightInfo{
Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc),
Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
FlightDescriptor: desc,
TotalRecords: -1,
TotalBytes: -1,
AppMetadata: []byte("app metadata"),
},
FlightDescriptor: desc,
Progress: proto.Float64(float64(srv.pollingStatus[val.Value]) / 100.0),
}, nil
}

srv.pollingStatus[val.Value]--
progress := srv.pollingStatus[val.Value]

numEndpoints := 5 - progress
endpoints := make([]*flight.FlightEndpoint, numEndpoints)
for i := range endpoints {
endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
}
Expand Down
94 changes: 94 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *
return nil, status.Errorf(codes.NotFound, "Query ID not found")
}

if query.query == "infinite" {
query.nextIndex++

descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId})
if err != nil {
return nil, err
}
return &flight.PollInfo{
Info: &flight.FlightInfo{
Schema: nil,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{
Ticket: []byte{},
},
}},
AppMetadata: []byte("app metadata"),
},
FlightDescriptor: &flight.FlightDescriptor{
Type: flight.DescriptorCMD,
Cmd: descriptor,
},
// always makes a bit of progress, never gets anywhere
Progress: proto.Float64(float64(query.nextIndex) / 100.0),
}, nil
}

testCase, ok := srv.testCases[query.query]
if !ok {
if query.query == "unavailable" {
Expand Down Expand Up @@ -518,6 +544,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Contex
}

return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], queryId)
} else if query.GetQuery() == "infinite" {
srv.queries[queryId] = &IncrementalQuery{
query: query.GetQuery(),
nextIndex: 0,
}

descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId})
if err != nil {
return nil, err
}
return &flight.PollInfo{
Info: &flight.FlightInfo{
Schema: nil,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{
Ticket: []byte{},
},
}},
AppMetadata: []byte("app metadata"),
},
FlightDescriptor: &flight.FlightDescriptor{
Type: flight.DescriptorCMD,
Cmd: descriptor,
},
Progress: proto.Float64(0),
}, nil
}

testCase, ok := srv.testCases[query.GetQuery()]
Expand Down Expand Up @@ -727,6 +779,48 @@ func (ts *IncrementalPollTests) TestOptionValue() {
ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
}

func (ts *IncrementalPollTests) TestAppMetadata() {
ctx, cancel := context.WithCancel(context.Background())
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled))

ts.NoError(stmt.SetSqlQuery("infinite"))
_, partitions, _, err := stmt.ExecutePartitions(ctx)
ts.NoError(err)
ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)

progress := 0.0
go func() {
var err error
var info []byte
for {
// While the below is stuck, we should be able to get the app metadata and progress
progress, err = stmt.(adbc.GetSetOptions).GetOptionDouble(adbc.OptionKeyProgress)
ts.NoError(err)

info, err = stmt.(adbc.GetSetOptions).GetOptionBytes(driver.OptionLastFlightInfo)
ts.NoError(err)
var flightInfo flight.FlightInfo
ts.NoError(proto.Unmarshal(info, &flightInfo))
ts.Equal([]byte("app metadata"), flightInfo.AppMetadata)

if progress > 0.03 {
break
}
}
cancel()
}()

// will get stuck forever, but will "make progress"
_, _, _, err = stmt.ExecutePartitions(ctx)
var adbcErr adbc.Error
ts.ErrorAs(err, &adbcErr)
ts.Equal(adbc.StatusCancelled, adbcErr.Code)
}

func (ts *IncrementalPollTests) TestUnavailable() {
// An error from the server should not tear down all the state. We
// should be able to retry the request.
Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ const (
OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info"
infoDriverName = "ADBC Flight SQL Driver - Go"
)

Expand Down
17 changes: 17 additions & 0 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ type statement struct {
timeouts timeoutOption
incrementalState *incrementalState
progress float64
// may seem redundant, but incrementalState isn't locked
lastInfo atomic.Pointer[flight.FlightInfo]
}

func (s *statement) closePreparedStatement() error {
Expand Down Expand Up @@ -249,6 +251,18 @@ func (s *statement) GetOption(key string) (string, error) {
}
}
func (s *statement) GetOptionBytes(key string) ([]byte, error) {
switch key {
case OptionLastFlightInfo:
info := s.lastInfo.Load()
serialized, err := proto.Marshal(info)
if err != nil {
return nil, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Could not serialize result for '%s': %s", key, err.Error()),
Code: adbc.StatusInternal,
}
}
return serialized, nil
}
return nil, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
Code: adbc.StatusNotFound,
Expand Down Expand Up @@ -594,6 +608,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
// Reset the statement for reuse
s.incrementalState = &incrementalState{}
atomicStoreFloat64(&s.progress, 0.0)
s.lastInfo.Store(nil)
return schema, adbc.Partitions{}, totalRecords, nil
}

Expand Down Expand Up @@ -628,6 +643,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
s.incrementalState.previousInfo = poll.GetInfo()
s.incrementalState.retryDescriptor = poll.GetFlightDescriptor()
atomicStoreFloat64(&s.progress, poll.GetProgress())
s.lastInfo.Store(poll.GetInfo())

if s.incrementalState.retryDescriptor == nil {
// Query is finished
Expand All @@ -651,6 +667,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
if s.incrementalState.complete && len(info.Endpoint) == 0 {
s.incrementalState = &incrementalState{}
atomicStoreFloat64(&s.progress, 0.0)
s.lastInfo.Store(nil)
}
} else if s.prepared != nil {
info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class ConnectionOptions(enum.Enum):
class StatementOptions(enum.Enum):
"""Statement options specific to the Flight SQL driver."""

#: The latest FlightInfo value.
#:
#: Thread-safe. Mostly useful when using incremental execution, where an
#: advanced client may want to inspect the latest FlightInfo from the
#: service, but without waiting for execute_partitions to return. (The
#: service may send an updated FlightInfo with progress/app_metadata
#: values, but execute_partitions will only return if there are new
#: endpoints.)
LAST_FLIGHT_INFO = "adbc.flight.sql.statement.exec.last_flight_info"
#: The number of batches to queue per partition. Defaults to 5.
#:
#: This controls how much we read ahead on result sets.
Expand Down
44 changes: 44 additions & 0 deletions python/adbc_driver_flightsql/tests/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# under the License.

import re
import threading

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
import pyarrow
import pyarrow.flight
import pytest

import adbc_driver_manager
from adbc_driver_flightsql import StatementOptions as FlightSqlStatementOptions
from adbc_driver_manager import StatementOptions

SCHEMA = pyarrow.schema([("ints", "int32")])
Expand Down Expand Up @@ -106,6 +109,47 @@ def test_incremental_error_poll(test_dbapi) -> None:
assert partitions == []


def test_incremental_cancel(test_dbapi) -> None:
with test_dbapi.cursor() as cur:
cur.adbc_statement.set_options(
**{
StatementOptions.INCREMENTAL.value: "true",
}
)
partitions, schema = cur.adbc_execute_partitions("forever")
assert len(partitions) == 1

passed = False

def _bg():
nonlocal passed
while True:
progress = cur.adbc_statement.get_option_float(
StatementOptions.PROGRESS.value
)
# XXX: upstream PyArrow never bothered exposing app_metadata
raw_info = cur.adbc_statement.get_option_bytes(
FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
)

# check that it's a valid info
pyarrow.flight.FlightInfo.deserialize(raw_info)
passed = b"app metadata" in raw_info

if progress > 0.07:
break
cur.adbc_cancel()

t = threading.Thread(target=_bg, daemon=True)
t.start()

with pytest.raises(test_dbapi.OperationalError, match="(?i)cancelled"):
cur.adbc_execute_partitions("forever")

t.join()
assert passed


def test_incremental_immediately(test_dbapi) -> None:
with test_dbapi.cursor() as cur:
cur.adbc_statement.set_options(
Expand Down
Loading