From 92220c2b9fa341ccc428e630fc353768c5627ca5 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 29 Feb 2024 16:41:22 -0500 Subject: [PATCH] test(python/adbc_driver_flightsql): test incremental execution Fixes #1570. --- docs/source/driver/status.rst | 6 + .../driver/flightsql/cmd/testserver/main.go | 103 ++++++++- .../tests/test_incremental.py | 204 ++++++++++++++++++ .../adbc_driver_manager/_lib.pyi | 4 +- .../adbc_driver_manager/_lib.pyx | 13 +- .../adbc_driver_manager/dbapi.py | 17 +- 6 files changed, 335 insertions(+), 12 deletions(-) create mode 100644 python/adbc_driver_flightsql/tests/test_incremental.py diff --git a/docs/source/driver/status.rst b/docs/source/driver/status.rst index d295bc3f7c..7337dd4e82 100644 --- a/docs/source/driver/status.rst +++ b/docs/source/driver/status.rst @@ -149,6 +149,7 @@ Update Queries :header-rows: 1 * - Driver + - Incremental Queries - Partitioned Data - Parameterized Queries - Prepared Statements @@ -161,8 +162,10 @@ Update Queries - Y - Y - Y + - Y * - Flight SQL (Java) + - N - Y - Y - Y @@ -170,6 +173,7 @@ Update Queries - Y * - JDBC + - N/A - N/A - Y - Y @@ -177,6 +181,7 @@ Update Queries - Y * - PostgreSQL + - N/A - N/A - Y - Y @@ -184,6 +189,7 @@ Update Queries - Y * - SQLite + - N/A - N/A - Y - Y diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go index 987c582404..22a928adb7 100644 --- a/go/adbc/driver/flightsql/cmd/testserver/main.go +++ b/go/adbc/driver/flightsql/cmd/testserver/main.go @@ -30,6 +30,7 @@ import ( "os" "strconv" "strings" + "sync" "github.com/apache/arrow/go/v16/arrow" "github.com/apache/arrow/go/v16/arrow/array" @@ -45,6 +46,9 @@ import ( type ExampleServer struct { flightsql.BaseServer + + mu sync.Mutex + pollingStatus map[string]int } func StatusWithDetail(code codes.Code, message string, details ...proto.Message) error { @@ -120,6 +124,103 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flight }, nil } +func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) { + srv.mu.Lock() + defer srv.mu.Unlock() + + var val wrapperspb.StringValue + var err error + if err = proto.Unmarshal(desc.Cmd, &val); err != nil { + return nil, err + } + + srv.pollingStatus[val.Value]-- + progress := srv.pollingStatus[val.Value] + + ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value)) + if err != nil { + return nil, err + } + + endpoints := make([]*flight.FlightEndpoint, 5-progress) + for i := range endpoints { + endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}} + } + + var schema []byte + if progress < 3 { + schema = flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc) + } + if progress == 0 { + desc = nil + } + + if val.Value == "error_poll_later" && progress == 3 { + return nil, StatusWithDetail(codes.Unavailable, "expected error (PollFlightInfo)") + } + + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + Schema: schema, + Endpoint: endpoints, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, + FlightDescriptor: desc, + Progress: proto.Float64(1.0 - (float64(progress) / 5.0)), + }, nil +} + +func (srv *ExampleServer) PollFlightInfoPreparedStatement(ctx context.Context, query flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.PollInfo, error) { + srv.mu.Lock() + defer srv.mu.Unlock() + + switch string(query.GetPreparedStatementHandle()) { + case "error_poll": + detail1 := wrapperspb.String("detail1") + detail2 := wrapperspb.String("detail2") + return nil, StatusWithDetail(codes.InvalidArgument, "expected error (PollFlightInfo)", detail1, detail2) + case "finish_immediately": + ticket, err := flightsql.CreateStatementQueryTicket(query.GetPreparedStatementHandle()) + if err != nil { + return nil, err + } + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc), + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, + FlightDescriptor: nil, + Progress: proto.Float64(1.0), + }, nil + } + + descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: string(query.GetPreparedStatementHandle())}) + if err != nil { + return nil, err + } + + srv.pollingStatus[string(query.GetPreparedStatementHandle())] = 5 + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + Schema: nil, + Endpoint: []*flight.FlightEndpoint{}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, + FlightDescriptor: &flight.FlightDescriptor{ + Type: flight.DescriptorCMD, + Cmd: descriptor, + }, + Progress: proto.Float64(0.0), + }, nil +} + func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle()) switch string(cmd.GetPreparedStatementHandle()) { @@ -226,7 +327,7 @@ func main() { flag.Parse() - srv := &ExampleServer{} + srv := &ExampleServer{pollingStatus: make(map[string]int)} srv.Alloc = memory.DefaultAllocator server := flight.NewServerWithMiddleware(nil) diff --git a/python/adbc_driver_flightsql/tests/test_incremental.py b/python/adbc_driver_flightsql/tests/test_incremental.py new file mode 100644 index 0000000000..8f47fd3923 --- /dev/null +++ b/python/adbc_driver_flightsql/tests/test_incremental.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re + +import google.protobuf.any_pb2 as any_pb2 +import google.protobuf.wrappers_pb2 as wrappers_pb2 +import pyarrow +import pytest + +import adbc_driver_manager +from adbc_driver_manager import StatementOptions + +SCHEMA = pyarrow.schema([("ints", "int32")]) + + +def test_incremental_error(test_dbapi) -> None: + with test_dbapi.cursor() as cur: + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "true", + } + ) + with pytest.raises( + test_dbapi.ProgrammingError, + match=re.escape("[FlightSQL] expected error (PollFlightInfo)"), + ) as exc_info: + cur.adbc_execute_partitions("error_poll") + + found = set() + for _, detail in exc_info.value.details: + anyproto = any_pb2.Any() + anyproto.ParseFromString(detail) + string = wrappers_pb2.StringValue() + anyproto.Unpack(string) + found.add(string.value) + assert found == {"detail1", "detail2"} + + # After an error, we can execute a different query. + partitions, schema = cur.adbc_execute_partitions("finish_immediately") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(1.0) + + +def test_incremental_error_poll(test_dbapi) -> None: + with test_dbapi.cursor() as cur: + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "true", + } + ) + partitions, schema = cur.adbc_execute_partitions("error_poll_later") + assert len(partitions) == 1 + assert schema is None + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.2) + + # An error can be retried. + with pytest.raises( + test_dbapi.OperationalError, + match=re.escape("[FlightSQL] expected error (PollFlightInfo)"), + ) as excinfo: + partitions, schema = cur.adbc_execute_partitions("error_poll_later") + assert excinfo.value.status_code == adbc_driver_manager.AdbcStatusCode.IO + + partitions, schema = cur.adbc_execute_partitions("error_poll_later") + assert len(partitions) == 2 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.6) + + partitions, schema = cur.adbc_execute_partitions("error_poll_later") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.8) + + partitions, schema = cur.adbc_execute_partitions("error_poll_later") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(1.0) + + partitions, _ = cur.adbc_execute_partitions("error_poll_later") + assert partitions == [] + + +def test_incremental_immediately(test_dbapi) -> None: + with test_dbapi.cursor() as cur: + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "true", + } + ) + partitions, schema = cur.adbc_execute_partitions("finish_immediately") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(1.0) + + partitions, schema = cur.adbc_execute_partitions("finish_immediately") + assert partitions == [] + + # reuse for a new query + partitions, schema = cur.adbc_execute_partitions("finish_immediately") + assert len(partitions) == 1 + partitions, schema = cur.adbc_execute_partitions("finish_immediately") + assert partitions == [] + + +def test_incremental_query(test_dbapi) -> None: + with test_dbapi.cursor() as cur: + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "true", + } + ) + partitions, schema = cur.adbc_execute_partitions("SELECT 1") + assert len(partitions) == 1 + assert schema is None + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.2) + + message = ( + "[Flight SQL] Cannot disable incremental execution " + "while a query is in progress" + ) + with pytest.raises( + test_dbapi.ProgrammingError, + match=re.escape(message), + ) as excinfo: + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "false", + } + ) + assert ( + excinfo.value.status_code + == adbc_driver_manager.AdbcStatusCode.INVALID_STATE + ) + + partitions, schema = cur.adbc_execute_partitions("SELECT 1") + assert len(partitions) == 1 + assert schema is None + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.4) + + partitions, schema = cur.adbc_execute_partitions("SELECT 1") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.6) + + partitions, schema = cur.adbc_execute_partitions("SELECT 1") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(0.8) + + partitions, schema = cur.adbc_execute_partitions("SELECT 1") + assert len(partitions) == 1 + assert schema == SCHEMA + assert cur.adbc_statement.get_option_float( + StatementOptions.PROGRESS.value + ) == pytest.approx(1.0) + + partitions, schema = cur.adbc_execute_partitions("SELECT 1") + assert len(partitions) == 0 + assert schema == SCHEMA + assert ( + cur.adbc_statement.get_option_float(StatementOptions.PROGRESS.value) == 0.0 + ) + + cur.adbc_statement.set_options( + **{ + StatementOptions.INCREMENTAL.value: "false", + } + ) diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi index 2a818839a1..0a19f92ed5 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -99,7 +99,9 @@ class AdbcStatement(_AdbcHandle): def bind_stream(self, *args, **kwargs) -> Any: ... def cancel(self) -> None: ... def close(self) -> None: ... - def execute_partitions(self, *args, **kwargs) -> Any: ... + def execute_partitions( + self, + ) -> Tuple[List[bytes], Optional[ArrowSchemaHandle], int]: ... def execute_query(self, *args, **kwargs) -> Any: ... def execute_schema(self) -> "ArrowSchemaHandle": ... def execute_update(self, *args, **kwargs) -> Any: ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 309cd76489..0592c8525e 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -26,7 +26,7 @@ import os import typing import sys import warnings -from typing import List, Tuple +from typing import List, Optional, Tuple cimport cpython import cython @@ -1195,7 +1195,7 @@ cdef class AdbcStatement(_AdbcHandle): check_error(status, &c_error) return (stream, rows_affected) - def execute_partitions(self) -> Tuple[List[bytes], ArrowSchemaHandle, int]: + def execute_partitions(self) -> Tuple[List[bytes], Optional[ArrowSchemaHandle], int]: """ Execute the query and get the partitions of the result set. @@ -1205,8 +1205,9 @@ cdef class AdbcStatement(_AdbcHandle): ------- list of byte The partitions of the distributed result set. - ArrowSchemaHandle - The schema of the result set. + ArrowSchemaHandle or None + The schema of the result set. May be None if incremental + execution is enabled and the server does not return a schema. int The number of rows if known, else -1. """ @@ -1232,7 +1233,9 @@ cdef class AdbcStatement(_AdbcHandle): partitions.append(PyBytes_FromStringAndSize(data, length)) c_partitions.release(&c_partitions) - return (partitions, schema, rows_affected) + if schema.schema.release == NULL: + return partitions, None, rows_affected + return partitions, schema, rows_affected def execute_schema(self) -> ArrowSchemaHandle: """ diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index c296a25749..af34e04bab 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -919,7 +919,9 @@ def adbc_ingest( return self._stmt.execute_update() def adbc_execute_partitions( - self, operation, parameters=None + self, + operation, + parameters=None, ) -> Tuple[List[bytes], pyarrow.Schema]: """ Execute a query and get the partitions of a distributed result set. @@ -929,16 +931,21 @@ def adbc_execute_partitions( partitions : list of byte A list of partition descriptors, which can be read with read_partition. - schema : pyarrow.Schema - The schema of the result set. + schema : pyarrow.Schema or None + The schema of the result set. May be None if incremental query + execution is enabled and the server has not returned a schema. Notes ----- This is an extension and not part of the DBAPI standard. """ self._prepare_execute(operation, parameters) - partitions, schema, self._rowcount = self._stmt.execute_partitions() - return partitions, pyarrow.Schema._import_from_c(schema.address) + partitions, schema_handle, self._rowcount = self._stmt.execute_partitions() + if schema_handle and schema_handle.address: + schema = pyarrow.Schema._import_from_c(schema_handle.address) + else: + schema = None + return partitions, schema def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema: """