From 92220c2b9fa341ccc428e630fc353768c5627ca5 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 29 Feb 2024 16:41:22 -0500
Subject: [PATCH] test(python/adbc_driver_flightsql): test incremental
execution
Fixes #1570.
---
docs/source/driver/status.rst | 6 +
.../driver/flightsql/cmd/testserver/main.go | 103 ++++++++-
.../tests/test_incremental.py | 204 ++++++++++++++++++
.../adbc_driver_manager/_lib.pyi | 4 +-
.../adbc_driver_manager/_lib.pyx | 13 +-
.../adbc_driver_manager/dbapi.py | 17 +-
6 files changed, 335 insertions(+), 12 deletions(-)
create mode 100644 python/adbc_driver_flightsql/tests/test_incremental.py
diff --git a/docs/source/driver/status.rst b/docs/source/driver/status.rst
index d295bc3f7c..7337dd4e82 100644
--- a/docs/source/driver/status.rst
+++ b/docs/source/driver/status.rst
@@ -149,6 +149,7 @@ Update Queries
:header-rows: 1
* - Driver
+ - Incremental Queries
- Partitioned Data
- Parameterized Queries
- Prepared Statements
@@ -161,8 +162,10 @@ Update Queries
- Y
- Y
- Y
+ - Y
* - Flight SQL (Java)
+ - N
- Y
- Y
- Y
@@ -170,6 +173,7 @@ Update Queries
- Y
* - JDBC
+ - N/A
- N/A
- Y
- Y
@@ -177,6 +181,7 @@ Update Queries
- Y
* - PostgreSQL
+ - N/A
- N/A
- Y
- Y
@@ -184,6 +189,7 @@ Update Queries
- Y
* - SQLite
+ - N/A
- N/A
- Y
- Y
diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 987c582404..22a928adb7 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -30,6 +30,7 @@ import (
"os"
"strconv"
"strings"
+ "sync"
"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
@@ -45,6 +46,9 @@ import (
type ExampleServer struct {
flightsql.BaseServer
+
+ mu sync.Mutex
+ pollingStatus map[string]int
}
func StatusWithDetail(code codes.Code, message string, details ...proto.Message) error {
@@ -120,6 +124,103 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flight
}, nil
}
+func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+
+ var val wrapperspb.StringValue
+ var err error
+ if err = proto.Unmarshal(desc.Cmd, &val); err != nil {
+ return nil, err
+ }
+
+ srv.pollingStatus[val.Value]--
+ progress := srv.pollingStatus[val.Value]
+
+ ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
+ if err != nil {
+ return nil, err
+ }
+
+ endpoints := make([]*flight.FlightEndpoint, 5-progress)
+ for i := range endpoints {
+ endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
+ }
+
+ var schema []byte
+ if progress < 3 {
+ schema = flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc)
+ }
+ if progress == 0 {
+ desc = nil
+ }
+
+ if val.Value == "error_poll_later" && progress == 3 {
+ return nil, StatusWithDetail(codes.Unavailable, "expected error (PollFlightInfo)")
+ }
+
+ return &flight.PollInfo{
+ Info: &flight.FlightInfo{
+ Schema: schema,
+ Endpoint: endpoints,
+ FlightDescriptor: desc,
+ TotalRecords: -1,
+ TotalBytes: -1,
+ },
+ FlightDescriptor: desc,
+ Progress: proto.Float64(1.0 - (float64(progress) / 5.0)),
+ }, nil
+}
+
+func (srv *ExampleServer) PollFlightInfoPreparedStatement(ctx context.Context, query flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+
+ switch string(query.GetPreparedStatementHandle()) {
+ case "error_poll":
+ detail1 := wrapperspb.String("detail1")
+ detail2 := wrapperspb.String("detail2")
+ return nil, StatusWithDetail(codes.InvalidArgument, "expected error (PollFlightInfo)", detail1, detail2)
+ case "finish_immediately":
+ ticket, err := flightsql.CreateStatementQueryTicket(query.GetPreparedStatementHandle())
+ if err != nil {
+ return nil, err
+ }
+ return &flight.PollInfo{
+ Info: &flight.FlightInfo{
+ Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc),
+ Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
+ FlightDescriptor: desc,
+ TotalRecords: -1,
+ TotalBytes: -1,
+ },
+ FlightDescriptor: nil,
+ Progress: proto.Float64(1.0),
+ }, nil
+ }
+
+ descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: string(query.GetPreparedStatementHandle())})
+ if err != nil {
+ return nil, err
+ }
+
+ srv.pollingStatus[string(query.GetPreparedStatementHandle())] = 5
+ return &flight.PollInfo{
+ Info: &flight.FlightInfo{
+ Schema: nil,
+ Endpoint: []*flight.FlightEndpoint{},
+ FlightDescriptor: desc,
+ TotalRecords: -1,
+ TotalBytes: -1,
+ },
+ FlightDescriptor: &flight.FlightDescriptor{
+ Type: flight.DescriptorCMD,
+ Cmd: descriptor,
+ },
+ Progress: proto.Float64(0.0),
+ }, nil
+}
+
func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle())
switch string(cmd.GetPreparedStatementHandle()) {
@@ -226,7 +327,7 @@ func main() {
flag.Parse()
- srv := &ExampleServer{}
+ srv := &ExampleServer{pollingStatus: make(map[string]int)}
srv.Alloc = memory.DefaultAllocator
server := flight.NewServerWithMiddleware(nil)
diff --git a/python/adbc_driver_flightsql/tests/test_incremental.py b/python/adbc_driver_flightsql/tests/test_incremental.py
new file mode 100644
index 0000000000..8f47fd3923
--- /dev/null
+++ b/python/adbc_driver_flightsql/tests/test_incremental.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+
+import google.protobuf.any_pb2 as any_pb2
+import google.protobuf.wrappers_pb2 as wrappers_pb2
+import pyarrow
+import pytest
+
+import adbc_driver_manager
+from adbc_driver_manager import StatementOptions
+
+SCHEMA = pyarrow.schema([("ints", "int32")])
+
+
+def test_incremental_error(test_dbapi) -> None:
+ with test_dbapi.cursor() as cur:
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "true",
+ }
+ )
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape("[FlightSQL] expected error (PollFlightInfo)"),
+ ) as exc_info:
+ cur.adbc_execute_partitions("error_poll")
+
+ found = set()
+ for _, detail in exc_info.value.details:
+ anyproto = any_pb2.Any()
+ anyproto.ParseFromString(detail)
+ string = wrappers_pb2.StringValue()
+ anyproto.Unpack(string)
+ found.add(string.value)
+ assert found == {"detail1", "detail2"}
+
+ # After an error, we can execute a different query.
+ partitions, schema = cur.adbc_execute_partitions("finish_immediately")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(1.0)
+
+
+def test_incremental_error_poll(test_dbapi) -> None:
+ with test_dbapi.cursor() as cur:
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "true",
+ }
+ )
+ partitions, schema = cur.adbc_execute_partitions("error_poll_later")
+ assert len(partitions) == 1
+ assert schema is None
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.2)
+
+ # An error can be retried.
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("[FlightSQL] expected error (PollFlightInfo)"),
+ ) as excinfo:
+ partitions, schema = cur.adbc_execute_partitions("error_poll_later")
+ assert excinfo.value.status_code == adbc_driver_manager.AdbcStatusCode.IO
+
+ partitions, schema = cur.adbc_execute_partitions("error_poll_later")
+ assert len(partitions) == 2
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.6)
+
+ partitions, schema = cur.adbc_execute_partitions("error_poll_later")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.8)
+
+ partitions, schema = cur.adbc_execute_partitions("error_poll_later")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(1.0)
+
+ partitions, _ = cur.adbc_execute_partitions("error_poll_later")
+ assert partitions == []
+
+
+def test_incremental_immediately(test_dbapi) -> None:
+ with test_dbapi.cursor() as cur:
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "true",
+ }
+ )
+ partitions, schema = cur.adbc_execute_partitions("finish_immediately")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(1.0)
+
+ partitions, schema = cur.adbc_execute_partitions("finish_immediately")
+ assert partitions == []
+
+ # reuse for a new query
+ partitions, schema = cur.adbc_execute_partitions("finish_immediately")
+ assert len(partitions) == 1
+ partitions, schema = cur.adbc_execute_partitions("finish_immediately")
+ assert partitions == []
+
+
+def test_incremental_query(test_dbapi) -> None:
+ with test_dbapi.cursor() as cur:
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "true",
+ }
+ )
+ partitions, schema = cur.adbc_execute_partitions("SELECT 1")
+ assert len(partitions) == 1
+ assert schema is None
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.2)
+
+ message = (
+ "[Flight SQL] Cannot disable incremental execution "
+ "while a query is in progress"
+ )
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape(message),
+ ) as excinfo:
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "false",
+ }
+ )
+ assert (
+ excinfo.value.status_code
+ == adbc_driver_manager.AdbcStatusCode.INVALID_STATE
+ )
+
+ partitions, schema = cur.adbc_execute_partitions("SELECT 1")
+ assert len(partitions) == 1
+ assert schema is None
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.4)
+
+ partitions, schema = cur.adbc_execute_partitions("SELECT 1")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.6)
+
+ partitions, schema = cur.adbc_execute_partitions("SELECT 1")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(0.8)
+
+ partitions, schema = cur.adbc_execute_partitions("SELECT 1")
+ assert len(partitions) == 1
+ assert schema == SCHEMA
+ assert cur.adbc_statement.get_option_float(
+ StatementOptions.PROGRESS.value
+ ) == pytest.approx(1.0)
+
+ partitions, schema = cur.adbc_execute_partitions("SELECT 1")
+ assert len(partitions) == 0
+ assert schema == SCHEMA
+ assert (
+ cur.adbc_statement.get_option_float(StatementOptions.PROGRESS.value) == 0.0
+ )
+
+ cur.adbc_statement.set_options(
+ **{
+ StatementOptions.INCREMENTAL.value: "false",
+ }
+ )
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 2a818839a1..0a19f92ed5 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -99,7 +99,9 @@ class AdbcStatement(_AdbcHandle):
def bind_stream(self, *args, **kwargs) -> Any: ...
def cancel(self) -> None: ...
def close(self) -> None: ...
- def execute_partitions(self, *args, **kwargs) -> Any: ...
+ def execute_partitions(
+ self,
+ ) -> Tuple[List[bytes], Optional[ArrowSchemaHandle], int]: ...
def execute_query(self, *args, **kwargs) -> Any: ...
def execute_schema(self) -> "ArrowSchemaHandle": ...
def execute_update(self, *args, **kwargs) -> Any: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 309cd76489..0592c8525e 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -26,7 +26,7 @@ import os
import typing
import sys
import warnings
-from typing import List, Tuple
+from typing import List, Optional, Tuple
cimport cpython
import cython
@@ -1195,7 +1195,7 @@ cdef class AdbcStatement(_AdbcHandle):
check_error(status, &c_error)
return (stream, rows_affected)
- def execute_partitions(self) -> Tuple[List[bytes], ArrowSchemaHandle, int]:
+ def execute_partitions(self) -> Tuple[List[bytes], Optional[ArrowSchemaHandle], int]:
"""
Execute the query and get the partitions of the result set.
@@ -1205,8 +1205,9 @@ cdef class AdbcStatement(_AdbcHandle):
-------
list of byte
The partitions of the distributed result set.
- ArrowSchemaHandle
- The schema of the result set.
+ ArrowSchemaHandle or None
+ The schema of the result set. May be None if incremental
+ execution is enabled and the server does not return a schema.
int
The number of rows if known, else -1.
"""
@@ -1232,7 +1233,9 @@ cdef class AdbcStatement(_AdbcHandle):
partitions.append(PyBytes_FromStringAndSize(data, length))
c_partitions.release(&c_partitions)
- return (partitions, schema, rows_affected)
+ if schema.schema.release == NULL:
+ return partitions, None, rows_affected
+ return partitions, schema, rows_affected
def execute_schema(self) -> ArrowSchemaHandle:
"""
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index c296a25749..af34e04bab 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -919,7 +919,9 @@ def adbc_ingest(
return self._stmt.execute_update()
def adbc_execute_partitions(
- self, operation, parameters=None
+ self,
+ operation,
+ parameters=None,
) -> Tuple[List[bytes], pyarrow.Schema]:
"""
Execute a query and get the partitions of a distributed result set.
@@ -929,16 +931,21 @@ def adbc_execute_partitions(
partitions : list of byte
A list of partition descriptors, which can be read with
read_partition.
- schema : pyarrow.Schema
- The schema of the result set.
+ schema : pyarrow.Schema or None
+ The schema of the result set. May be None if incremental query
+ execution is enabled and the server has not returned a schema.
Notes
-----
This is an extension and not part of the DBAPI standard.
"""
self._prepare_execute(operation, parameters)
- partitions, schema, self._rowcount = self._stmt.execute_partitions()
- return partitions, pyarrow.Schema._import_from_c(schema.address)
+ partitions, schema_handle, self._rowcount = self._stmt.execute_partitions()
+ if schema_handle and schema_handle.address:
+ schema = pyarrow.Schema._import_from_c(schema_handle.address)
+ else:
+ schema = None
+ return partitions, schema
def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema:
"""