diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go index 22a928adb7..8ce65c9f72 100644 --- a/go/adbc/driver/flightsql/cmd/testserver/main.go +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -134,15 +134,32 @@ func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.Fligh return nil, err } - srv.pollingStatus[val.Value]-- - progress := srv.pollingStatus[val.Value] - ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value)) if err != nil { return nil, err } - endpoints := make([]*flight.FlightEndpoint, 5-progress) + if val.Value == "forever" { + srv.pollingStatus[val.Value]++ + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc), + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + AppMetadata: []byte("app metadata"), + }, + FlightDescriptor: desc, + Progress: proto.Float64(float64(srv.pollingStatus[val.Value]) / 100.0), + }, nil + } + + srv.pollingStatus[val.Value]-- + progress := srv.pollingStatus[val.Value] + + numEndpoints := 5 - progress + endpoints := make([]*flight.FlightEndpoint, numEndpoints) for i := range endpoints { endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}} } diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 66e94da44e..78ae01b441 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -546,6 +546,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc * return nil, status.Errorf(codes.NotFound, "Query ID not found") } + if query.query == "infinite" { + query.nextIndex++ + + descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId}) + if err != nil { + return nil, err + } + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + Schema: nil, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{ + Ticket: []byte{}, + }, + }}, + AppMetadata: []byte("app metadata"), + }, + FlightDescriptor: &flight.FlightDescriptor{ + Type: flight.DescriptorCMD, + Cmd: descriptor, + }, + // always makes a bit of progress, never gets anywhere + Progress: proto.Float64(float64(query.nextIndex) / 100.0), + }, nil + } + testCase, ok := srv.testCases[query.query] if !ok { if query.query == "unavailable" { @@ -581,6 +607,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Contex } return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], queryId) + } else if query.GetQuery() == "infinite" { + srv.queries[queryId] = &IncrementalQuery{ + query: query.GetQuery(), + nextIndex: 0, + } + + descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId}) + if err != nil { + return nil, err + } + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + Schema: nil, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{ + Ticket: []byte{}, + }, + }}, + AppMetadata: []byte("app metadata"), + }, + FlightDescriptor: &flight.FlightDescriptor{ + Type: flight.DescriptorCMD, + Cmd: descriptor, + }, + Progress: proto.Float64(0), + }, nil } testCase, ok := srv.testCases[query.GetQuery()] @@ -790,6 +842,48 @@ func (ts *IncrementalPollTests) TestOptionValue() { ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code) } +func (ts *IncrementalPollTests) TestAppMetadata() { + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled)) + + ts.NoError(stmt.SetSqlQuery("infinite")) + _, partitions, _, err := stmt.ExecutePartitions(ctx) + ts.NoError(err) + ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions) + + progress := 0.0 + go func() { + var err error + var info []byte + for { + // While the below is stuck, we should be able to get the app metadata and progress + progress, err = stmt.(adbc.GetSetOptions).GetOptionDouble(adbc.OptionKeyProgress) + ts.NoError(err) + + info, err = stmt.(adbc.GetSetOptions).GetOptionBytes(driver.OptionLastFlightInfo) + ts.NoError(err) + var flightInfo flight.FlightInfo + ts.NoError(proto.Unmarshal(info, &flightInfo)) + ts.Equal([]byte("app metadata"), flightInfo.AppMetadata) + + if progress > 0.03 { + break + } + } + cancel() + }() + + // will get stuck forever, but will "make progress" + _, _, _, err = stmt.ExecutePartitions(ctx) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + ts.Equal(adbc.StatusCancelled, adbcErr.Code) +} + func (ts *IncrementalPollTests) TestUnavailable() { // An error from the server should not tear down all the state. We // should be able to retry the request. diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index 4914ad1cba..df1ae688b4 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -60,6 +60,7 @@ const ( OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update" OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header." OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware" + OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info" infoDriverName = "ADBC Flight SQL Driver - Go" ) diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index a1e33fd3e7..d78b653c81 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -166,6 +166,8 @@ type statement struct { timeouts timeoutOption incrementalState *incrementalState progress float64 + // may seem redundant, but incrementalState isn't locked + lastInfo atomic.Pointer[flight.FlightInfo] } func (s *statement) closePreparedStatement() error { @@ -184,6 +186,7 @@ func (s *statement) clearIncrementalQuery() error { } } s.incrementalState = &incrementalState{} + s.lastInfo.Store(nil) } return nil } @@ -249,6 +252,21 @@ func (s *statement) GetOption(key string) (string, error) { } } func (s *statement) GetOptionBytes(key string) ([]byte, error) { + switch key { + case OptionLastFlightInfo: + info := s.lastInfo.Load() + if info == nil { + return []byte{}, nil + } + serialized, err := proto.Marshal(info) + if err != nil { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Could not serialize result for '%s': %s", key, err.Error()), + Code: adbc.StatusInternal, + } + } + return serialized, nil + } return nil, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotFound, @@ -594,6 +612,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. // Reset the statement for reuse s.incrementalState = &incrementalState{} atomicStoreFloat64(&s.progress, 0.0) + s.lastInfo.Store(nil) return schema, adbc.Partitions{}, totalRecords, nil } @@ -628,6 +647,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. s.incrementalState.previousInfo = poll.GetInfo() s.incrementalState.retryDescriptor = poll.GetFlightDescriptor() atomicStoreFloat64(&s.progress, poll.GetProgress()) + s.lastInfo.Store(poll.GetInfo()) if s.incrementalState.retryDescriptor == nil { // Query is finished @@ -651,6 +671,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. if s.incrementalState.complete && len(info.Endpoint) == 0 { s.incrementalState = &incrementalState{} atomicStoreFloat64(&s.progress, 0.0) + s.lastInfo.Store(nil) } } else if s.prepared != nil { info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py index 1b9adf319e..1af0c199d3 100644 --- a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py +++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py @@ -104,6 +104,15 @@ class ConnectionOptions(enum.Enum): class StatementOptions(enum.Enum): """Statement options specific to the Flight SQL driver.""" + #: The latest FlightInfo value. + #: + #: Thread-safe. Mostly useful when using incremental execution, where an + #: advanced client may want to inspect the latest FlightInfo from the + #: service, but without waiting for execute_partitions to return. (The + #: service may send an updated FlightInfo with progress/app_metadata + #: values, but execute_partitions will only return if there are new + #: endpoints.) + LAST_FLIGHT_INFO = "adbc.flight.sql.statement.exec.last_flight_info" #: The number of batches to queue per partition. Defaults to 5. #: #: This controls how much we read ahead on result sets. diff --git a/python/adbc_driver_flightsql/tests/test_incremental.py b/python/adbc_driver_flightsql/tests/test_incremental.py index 8f47fd3923..285c18a831 100644 --- a/python/adbc_driver_flightsql/tests/test_incremental.py +++ b/python/adbc_driver_flightsql/tests/test_incremental.py @@ -16,13 +16,16 @@ # under the License. import re +import threading import google.protobuf.any_pb2 as any_pb2 import google.protobuf.wrappers_pb2 as wrappers_pb2 import pyarrow +import pyarrow.flight import pytest import adbc_driver_manager +from adbc_driver_flightsql import StatementOptions as FlightSqlStatementOptions from adbc_driver_manager import StatementOptions SCHEMA = pyarrow.schema([("ints", "int32")]) @@ -106,6 +109,54 @@ def test_incremental_error_poll(test_dbapi) -> None: assert partitions == [] +def test_incremental_cancel(test_dbapi) -> None: + with test_dbapi.cursor() as cur: + assert ( + cur.adbc_statement.get_option_bytes( + FlightSqlStatementOptions.LAST_FLIGHT_INFO.value + ) + == b"" + ) + + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "true", + } + ) + partitions, schema = cur.adbc_execute_partitions("forever") + assert len(partitions) == 1 + + passed = False + + def _bg(): + nonlocal passed + while True: + progress = cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) + # XXX: upstream PyArrow never bothered exposing app_metadata + raw_info = cur.adbc_statement.get_option_bytes( + FlightSqlStatementOptions.LAST_FLIGHT_INFO.value + ) + + # check that it's a valid info + pyarrow.flight.FlightInfo.deserialize(raw_info) + passed = b"app metadata" in raw_info + + if progress > 0.07: + break + cur.adbc_cancel() + + t = threading.Thread(target=_bg, daemon=True) + t.start() + + with pytest.raises(test_dbapi.OperationalError, match="(?i)cancelled"): + cur.adbc_execute_partitions("forever") + + t.join() + assert passed + + def test_incremental_immediately(test_dbapi) -> None: with test_dbapi.cursor() as cur: cur.adbc_statement.set_options(