From d9629615155b5bfebf288fd56c1f6272ea22c3bb Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 1 Mar 2024 18:20:59 -0500
Subject: [PATCH] feat(go/adbc/driver/flightsql): expose FlightInfo during
polling (#1582)
Fixes #1571.
---
.../driver/flightsql/cmd/testserver/main.go | 25 ++++-
.../flightsql/flightsql_adbc_server_test.go | 94 +++++++++++++++++++
go/adbc/driver/flightsql/flightsql_driver.go | 1 +
.../driver/flightsql/flightsql_statement.go | 21 +++++
.../adbc_driver_flightsql/__init__.py | 9 ++
.../tests/test_incremental.py | 51 ++++++++++
6 files changed, 197 insertions(+), 4 deletions(-)
diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 22a928adb7..8ce65c9f72 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -134,15 +134,32 @@ func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.Fligh
return nil, err
}
- srv.pollingStatus[val.Value]--
- progress := srv.pollingStatus[val.Value]
-
ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
if err != nil {
return nil, err
}
- endpoints := make([]*flight.FlightEndpoint, 5-progress)
+ if val.Value == "forever" {
+ srv.pollingStatus[val.Value]++
+ return &flight.PollInfo{
+ Info: &flight.FlightInfo{
+ Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc),
+ Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
+ FlightDescriptor: desc,
+ TotalRecords: -1,
+ TotalBytes: -1,
+ AppMetadata: []byte("app metadata"),
+ },
+ FlightDescriptor: desc,
+ Progress: proto.Float64(float64(srv.pollingStatus[val.Value]) / 100.0),
+ }, nil
+ }
+
+ srv.pollingStatus[val.Value]--
+ progress := srv.pollingStatus[val.Value]
+
+ numEndpoints := 5 - progress
+ endpoints := make([]*flight.FlightEndpoint, numEndpoints)
for i := range endpoints {
endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 66e94da44e..78ae01b441 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -546,6 +546,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *
return nil, status.Errorf(codes.NotFound, "Query ID not found")
}
+ if query.query == "infinite" {
+ query.nextIndex++
+
+ descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId})
+ if err != nil {
+ return nil, err
+ }
+ return &flight.PollInfo{
+ Info: &flight.FlightInfo{
+ Schema: nil,
+ Endpoint: []*flight.FlightEndpoint{{
+ Ticket: &flight.Ticket{
+ Ticket: []byte{},
+ },
+ }},
+ AppMetadata: []byte("app metadata"),
+ },
+ FlightDescriptor: &flight.FlightDescriptor{
+ Type: flight.DescriptorCMD,
+ Cmd: descriptor,
+ },
+ // always makes a bit of progress, never gets anywhere
+ Progress: proto.Float64(float64(query.nextIndex) / 100.0),
+ }, nil
+ }
+
testCase, ok := srv.testCases[query.query]
if !ok {
if query.query == "unavailable" {
@@ -581,6 +607,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Contex
}
return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], queryId)
+ } else if query.GetQuery() == "infinite" {
+ srv.queries[queryId] = &IncrementalQuery{
+ query: query.GetQuery(),
+ nextIndex: 0,
+ }
+
+ descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId})
+ if err != nil {
+ return nil, err
+ }
+ return &flight.PollInfo{
+ Info: &flight.FlightInfo{
+ Schema: nil,
+ Endpoint: []*flight.FlightEndpoint{{
+ Ticket: &flight.Ticket{
+ Ticket: []byte{},
+ },
+ }},
+ AppMetadata: []byte("app metadata"),
+ },
+ FlightDescriptor: &flight.FlightDescriptor{
+ Type: flight.DescriptorCMD,
+ Cmd: descriptor,
+ },
+ Progress: proto.Float64(0),
+ }, nil
}
testCase, ok := srv.testCases[query.GetQuery()]
@@ -790,6 +842,48 @@ func (ts *IncrementalPollTests) TestOptionValue() {
ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
}
+func (ts *IncrementalPollTests) TestAppMetadata() {
+ ctx, cancel := context.WithCancel(context.Background())
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled))
+
+ ts.NoError(stmt.SetSqlQuery("infinite"))
+ _, partitions, _, err := stmt.ExecutePartitions(ctx)
+ ts.NoError(err)
+ ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)
+
+ progress := 0.0
+ go func() {
+ var err error
+ var info []byte
+ for {
+ // While the below is stuck, we should be able to get the app metadata and progress
+ progress, err = stmt.(adbc.GetSetOptions).GetOptionDouble(adbc.OptionKeyProgress)
+ ts.NoError(err)
+
+ info, err = stmt.(adbc.GetSetOptions).GetOptionBytes(driver.OptionLastFlightInfo)
+ ts.NoError(err)
+ var flightInfo flight.FlightInfo
+ ts.NoError(proto.Unmarshal(info, &flightInfo))
+ ts.Equal([]byte("app metadata"), flightInfo.AppMetadata)
+
+ if progress > 0.03 {
+ break
+ }
+ }
+ cancel()
+ }()
+
+ // will get stuck forever, but will "make progress"
+ _, _, _, err = stmt.ExecutePartitions(ctx)
+ var adbcErr adbc.Error
+ ts.ErrorAs(err, &adbcErr)
+ ts.Equal(adbc.StatusCancelled, adbcErr.Code)
+}
+
func (ts *IncrementalPollTests) TestUnavailable() {
// An error from the server should not tear down all the state. We
// should be able to retry the request.
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 4914ad1cba..df1ae688b4 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -60,6 +60,7 @@ const (
OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
+ OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info"
infoDriverName = "ADBC Flight SQL Driver - Go"
)
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index a1e33fd3e7..d78b653c81 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -166,6 +166,8 @@ type statement struct {
timeouts timeoutOption
incrementalState *incrementalState
progress float64
+ // may seem redundant, but incrementalState isn't locked
+ lastInfo atomic.Pointer[flight.FlightInfo]
}
func (s *statement) closePreparedStatement() error {
@@ -184,6 +186,7 @@ func (s *statement) clearIncrementalQuery() error {
}
}
s.incrementalState = &incrementalState{}
+ s.lastInfo.Store(nil)
}
return nil
}
@@ -249,6 +252,21 @@ func (s *statement) GetOption(key string) (string, error) {
}
}
func (s *statement) GetOptionBytes(key string) ([]byte, error) {
+ switch key {
+ case OptionLastFlightInfo:
+ info := s.lastInfo.Load()
+ if info == nil {
+ return []byte{}, nil
+ }
+ serialized, err := proto.Marshal(info)
+ if err != nil {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] Could not serialize result for '%s': %s", key, err.Error()),
+ Code: adbc.StatusInternal,
+ }
+ }
+ return serialized, nil
+ }
return nil, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
Code: adbc.StatusNotFound,
@@ -594,6 +612,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
// Reset the statement for reuse
s.incrementalState = &incrementalState{}
atomicStoreFloat64(&s.progress, 0.0)
+ s.lastInfo.Store(nil)
return schema, adbc.Partitions{}, totalRecords, nil
}
@@ -628,6 +647,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
s.incrementalState.previousInfo = poll.GetInfo()
s.incrementalState.retryDescriptor = poll.GetFlightDescriptor()
atomicStoreFloat64(&s.progress, poll.GetProgress())
+ s.lastInfo.Store(poll.GetInfo())
if s.incrementalState.retryDescriptor == nil {
// Query is finished
@@ -651,6 +671,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
if s.incrementalState.complete && len(info.Endpoint) == 0 {
s.incrementalState = &incrementalState{}
atomicStoreFloat64(&s.progress, 0.0)
+ s.lastInfo.Store(nil)
}
} else if s.prepared != nil {
info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts)
diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
index 1b9adf319e..1af0c199d3 100644
--- a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
+++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
@@ -104,6 +104,15 @@ class ConnectionOptions(enum.Enum):
class StatementOptions(enum.Enum):
"""Statement options specific to the Flight SQL driver."""
+ #: The latest FlightInfo value.
+ #:
+ #: Thread-safe. Mostly useful when using incremental execution, where an
+ #: advanced client may want to inspect the latest FlightInfo from the
+ #: service, but without waiting for execute_partitions to return. (The
+ #: service may send an updated FlightInfo with progress/app_metadata
+ #: values, but execute_partitions will only return if there are new
+ #: endpoints.)
+ LAST_FLIGHT_INFO = "adbc.flight.sql.statement.exec.last_flight_info"
#: The number of batches to queue per partition. Defaults to 5.
#:
#: This controls how much we read ahead on result sets.
diff --git a/python/adbc_driver_flightsql/tests/test_incremental.py b/python/adbc_driver_flightsql/tests/test_incremental.py
index 8f47fd3923..285c18a831 100644
--- a/python/adbc_driver_flightsql/tests/test_incremental.py
+++ b/python/adbc_driver_flightsql/tests/test_incremental.py
@@ -16,13 +16,16 @@
# under the License.
import re
+import threading
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
import pyarrow
+import pyarrow.flight
import pytest
import adbc_driver_manager
+from adbc_driver_flightsql import StatementOptions as FlightSqlStatementOptions
from adbc_driver_manager import StatementOptions
SCHEMA = pyarrow.schema([("ints", "int32")])
@@ -106,6 +109,54 @@ def test_incremental_error_poll(test_dbapi) -> None:
assert partitions == []
+def test_incremental_cancel(test_dbapi) -> None:
+ with test_dbapi.cursor() as cur:
+ assert (
+ cur.adbc_statement.get_option_bytes(
+ FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
+ )
+ == b""
+ )
+
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "true",
+ }
+ )
+ partitions, schema = cur.adbc_execute_partitions("forever")
+ assert len(partitions) == 1
+
+ passed = False
+
+ def _bg():
+ nonlocal passed
+ while True:
+ progress = cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ )
+ # XXX: upstream PyArrow never bothered exposing app_metadata
+ raw_info = cur.adbc_statement.get_option_bytes(
+ FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
+ )
+
+ # check that it's a valid info
+ pyarrow.flight.FlightInfo.deserialize(raw_info)
+ passed = b"app metadata" in raw_info
+
+ if progress > 0.07:
+ break
+ cur.adbc_cancel()
+
+ t = threading.Thread(target=_bg, daemon=True)
+ t.start()
+
+ with pytest.raises(test_dbapi.OperationalError, match="(?i)cancelled"):
+ cur.adbc_execute_partitions("forever")
+
+ t.join()
+ assert passed
+
+
def test_incremental_immediately(test_dbapi) -> None:
with test_dbapi.cursor() as cur:
cur.adbc_statement.set_options(