Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(python/adbc_driver_flightsql): test incremental execution #1575

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/driver/status.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ Update Queries
:header-rows: 1

* - Driver
- Incremental Queries
- Partitioned Data
- Parameterized Queries
- Prepared Statements
Expand All @@ -161,29 +162,34 @@ Update Queries
- Y
- Y
- Y
- Y

* - Flight SQL (Java)
- N
- Y
- Y
- Y
- Y
- Y

* - JDBC
- N/A
- N/A
- Y
- Y
- Y
- Y

* - PostgreSQL
- N/A
- N/A
- Y
- Y
- Y
- Y

* - SQLite
- N/A
- N/A
- Y
- Y
Expand Down
103 changes: 102 additions & 1 deletion go/adbc/driver/flightsql/cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"os"
"strconv"
"strings"
"sync"

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
Expand All @@ -45,6 +46,9 @@ import (

// ExampleServer is the Flight SQL test server. It embeds flightsql.BaseServer
// and adds the bookkeeping used by the PollFlightInfo handlers.
type ExampleServer struct {
flightsql.BaseServer

// mu is held by the PollFlightInfo handlers while they read or write
// pollingStatus.
mu sync.Mutex
// pollingStatus maps a query key (the prepared statement handle as a string)
// to a step counter: PollFlightInfoPreparedStatement sets it to 5 and each
// PollFlightInfo call decrements it.
pollingStatus map[string]int
}

func StatusWithDetail(code codes.Code, message string, details ...proto.Message) error {
Expand Down Expand Up @@ -120,6 +124,103 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flight
}, nil
}

// PollFlightInfo advances the simulated incremental query identified by desc
// by one step and reports its state.
//
// The descriptor command is decoded as a wrapperspb.StringValue whose value is
// the key that PollFlightInfoPreparedStatement registered in pollingStatus.
// Each call consumes one step. With `remaining` steps left, the response
// carries totalSteps-remaining endpoints, a progress of 1 - remaining/totalSteps,
// and a schema once fewer than three steps remain. When no steps remain the
// response omits the descriptor and the key is removed from pollingStatus.
//
// Errors:
//   - the unmarshal error when the command is not a valid StringValue;
//   - codes.NotFound when no query is registered under the decoded key
//     (including a nil descriptor or an already-completed query);
//   - codes.Unavailable for the key "error_poll_later" on the step that leaves
//     three remaining; that step is still consumed, so a retry makes progress.
func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
	// Must match the initial counter set by PollFlightInfoPreparedStatement.
	const totalSteps = 5

	srv.mu.Lock()
	defer srv.mu.Unlock()

	var val wrapperspb.StringValue
	// GetCmd is nil-safe, so a missing descriptor decodes to the empty key and
	// is rejected below instead of panicking on a nil dereference.
	if err := proto.Unmarshal(desc.GetCmd(), &val); err != nil {
		return nil, err
	}

	// Refuse to fabricate state for keys that were never registered; blindly
	// decrementing a missing entry would report progress above 1.0.
	remaining, known := srv.pollingStatus[val.Value]
	if !known {
		return nil, StatusWithDetail(codes.NotFound, "no query in progress for descriptor (PollFlightInfo)")
	}

	remaining--
	if remaining > 0 {
		srv.pollingStatus[val.Value] = remaining
	} else {
		// The query is complete; drop the entry so the map does not grow
		// without bound and stale descriptors are rejected.
		delete(srv.pollingStatus, val.Value)
	}

	ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
	if err != nil {
		return nil, err
	}

	// Every endpoint produced so far is returned again on each poll.
	endpoints := make([]*flight.FlightEndpoint, totalSteps-remaining)
	for i := range endpoints {
		endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
	}

	// The schema only becomes available part-way through the query.
	var schema []byte
	if remaining < 3 {
		schema = flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc)
	}
	// A nil descriptor in the response tells the client there is nothing left
	// to poll.
	if remaining == 0 {
		desc = nil
	}

	// Injected failure: raised after the step has been recorded so that the
	// client's retry continues from the next step.
	if val.Value == "error_poll_later" && remaining == 3 {
		return nil, StatusWithDetail(codes.Unavailable, "expected error (PollFlightInfo)")
	}

	return &flight.PollInfo{
		Info: &flight.FlightInfo{
			Schema:           schema,
			Endpoint:         endpoints,
			FlightDescriptor: desc,
			TotalRecords:     -1,
			TotalBytes:       -1,
		},
		FlightDescriptor: desc,
		Progress:         proto.Float64(1.0 - (float64(remaining) / totalSteps)),
	}, nil
}

// PollFlightInfoPreparedStatement answers the first poll for a prepared
// statement. The statement handle selects the behavior:
//
//   - "error_poll": fails with codes.InvalidArgument carrying two StringValue
//     details ("detail1", "detail2").
//   - "finish_immediately": returns a completed PollInfo (progress 1.0, one
//     endpoint, schema present, no descriptor to poll again).
//   - anything else: registers the handle in pollingStatus with a counter of 5
//     and returns progress 0.0, no endpoints, no schema, and a CMD descriptor
//     whose command is the handle wrapped in a wrapperspb.StringValue.
func (srv *ExampleServer) PollFlightInfoPreparedStatement(ctx context.Context, query flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	handle := query.GetPreparedStatementHandle()
	key := string(handle)

	if key == "error_poll" {
		return nil, StatusWithDetail(
			codes.InvalidArgument,
			"expected error (PollFlightInfo)",
			wrapperspb.String("detail1"),
			wrapperspb.String("detail2"),
		)
	}

	if key == "finish_immediately" {
		ticket, err := flightsql.CreateStatementQueryTicket(handle)
		if err != nil {
			return nil, err
		}
		fields := []arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}
		info := &flight.FlightInfo{
			Schema:           flight.SerializeSchema(arrow.NewSchema(fields, nil), srv.Alloc),
			Endpoint:         []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
			FlightDescriptor: desc,
			TotalRecords:     -1,
			TotalBytes:       -1,
		}
		// No retry descriptor: the query is already complete.
		return &flight.PollInfo{
			Info:             info,
			FlightDescriptor: nil,
			Progress:         proto.Float64(1.0),
		}, nil
	}

	// Encode the handle as the command clients send back to PollFlightInfo.
	cmd, err := proto.Marshal(wrapperspb.String(key))
	if err != nil {
		return nil, err
	}

	srv.pollingStatus[key] = 5

	pending := &flight.FlightInfo{
		Schema:           nil,
		Endpoint:         []*flight.FlightEndpoint{},
		FlightDescriptor: desc,
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	next := &flight.FlightDescriptor{
		Type: flight.DescriptorCMD,
		Cmd:  cmd,
	}
	return &flight.PollInfo{
		Info:             pending,
		FlightDescriptor: next,
		Progress:         proto.Float64(0.0),
	}, nil
}

func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle())
switch string(cmd.GetPreparedStatementHandle()) {
Expand Down Expand Up @@ -226,7 +327,7 @@ func main() {

flag.Parse()

srv := &ExampleServer{}
srv := &ExampleServer{pollingStatus: make(map[string]int)}
srv.Alloc = memory.DefaultAllocator

server := flight.NewServerWithMiddleware(nil)
Expand Down
204 changes: 204 additions & 0 deletions python/adbc_driver_flightsql/tests/test_incremental.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
import pyarrow
import pytest

import adbc_driver_manager
from adbc_driver_manager import StatementOptions

SCHEMA = pyarrow.schema([("ints", "int32")])


def test_incremental_error(test_dbapi) -> None:
    """A failing first poll surfaces its error details and leaves the cursor usable."""
    incremental_on = {StatementOptions.INCREMENTAL.value: "true"}
    expected_message = re.escape("[FlightSQL] expected error (PollFlightInfo)")

    def unpack_string(raw):
        # Each detail payload is parsed as an Any and unpacked into a StringValue.
        wrapper = any_pb2.Any()
        wrapper.ParseFromString(raw)
        unpacked = wrappers_pb2.StringValue()
        wrapper.Unpack(unpacked)
        return unpacked.value

    with test_dbapi.cursor() as cur:
        cur.adbc_statement.set_options(**incremental_on)

        with pytest.raises(
            test_dbapi.ProgrammingError, match=expected_message
        ) as exc_info:
            cur.adbc_execute_partitions("error_poll")

        found = {unpack_string(detail) for _, detail in exc_info.value.details}
        assert found == {"detail1", "detail2"}

        # After an error, we can execute a different query.
        partitions, schema = cur.adbc_execute_partitions("finish_immediately")
        assert len(partitions) == 1
        assert schema == SCHEMA
        progress = cur.adbc_statement.get_option_float(StatementOptions.PROGRESS.value)
        assert progress == pytest.approx(1.0)


def test_incremental_error_poll(test_dbapi) -> None:
    """An error raised on a later poll can be retried and the query then completes."""
    query = "error_poll_later"

    with test_dbapi.cursor() as cur:

        def current_progress():
            return cur.adbc_statement.get_option_float(
                StatementOptions.PROGRESS.value
            )

        cur.adbc_statement.set_options(
            **{StatementOptions.INCREMENTAL.value: "true"}
        )

        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 1
        assert schema is None
        assert current_progress() == pytest.approx(0.2)

        # An error can be retried.
        with pytest.raises(
            test_dbapi.OperationalError,
            match=re.escape("[FlightSQL] expected error (PollFlightInfo)"),
        ) as excinfo:
            cur.adbc_execute_partitions(query)
        assert excinfo.value.status_code == adbc_driver_manager.AdbcStatusCode.IO

        # Remaining polls: (number of new partitions, reported progress).
        for expected_count, expected_progress in ((2, 0.6), (1, 0.8), (1, 1.0)):
            partitions, schema = cur.adbc_execute_partitions(query)
            assert len(partitions) == expected_count
            assert schema == SCHEMA
            assert current_progress() == pytest.approx(expected_progress)

        # Once finished, a further call yields no partitions.
        partitions, _ = cur.adbc_execute_partitions(query)
        assert partitions == []


def test_incremental_immediately(test_dbapi) -> None:
    """A query that completes on the first poll, and reuse of the cursor afterwards."""
    query = "finish_immediately"
    incremental_on = {StatementOptions.INCREMENTAL.value: "true"}

    with test_dbapi.cursor() as cur:
        cur.adbc_statement.set_options(**incremental_on)

        # First call returns the single partition with full progress.
        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 1
        assert schema == SCHEMA
        progress = cur.adbc_statement.get_option_float(StatementOptions.PROGRESS.value)
        assert progress == pytest.approx(1.0)

        # Second call signals completion with an empty partition list.
        partitions, schema = cur.adbc_execute_partitions(query)
        assert partitions == []

        # reuse for a new query
        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 1
        partitions, schema = cur.adbc_execute_partitions(query)
        assert partitions == []


def test_incremental_query(test_dbapi) -> None:
    """Step through a full incremental query, checking partitions, schema and progress.

    Also checks that incremental execution cannot be switched off while the
    query is in progress, and can be switched off once it has finished.
    """
    query = "SELECT 1"
    incremental_key = StatementOptions.INCREMENTAL.value

    with test_dbapi.cursor() as cur:

        def current_progress():
            return cur.adbc_statement.get_option_float(
                StatementOptions.PROGRESS.value
            )

        cur.adbc_statement.set_options(**{incremental_key: "true"})

        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 1
        assert schema is None
        assert current_progress() == pytest.approx(0.2)

        message = (
            "[Flight SQL] Cannot disable incremental execution "
            "while a query is in progress"
        )
        with pytest.raises(
            test_dbapi.ProgrammingError,
            match=re.escape(message),
        ) as excinfo:
            cur.adbc_statement.set_options(**{incremental_key: "false"})
        assert (
            excinfo.value.status_code
            == adbc_driver_manager.AdbcStatusCode.INVALID_STATE
        )

        # Remaining polls: (schema available yet?, reported progress).
        steps = ((False, 0.4), (True, 0.6), (True, 0.8), (True, 1.0))
        for has_schema, expected_progress in steps:
            partitions, schema = cur.adbc_execute_partitions(query)
            assert len(partitions) == 1
            if has_schema:
                assert schema == SCHEMA
            else:
                assert schema is None
            assert current_progress() == pytest.approx(expected_progress)

        # The call after completion returns no partitions and resets progress.
        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 0
        assert schema == SCHEMA
        assert current_progress() == 0.0

        cur.adbc_statement.set_options(**{incremental_key: "false"})
4 changes: 3 additions & 1 deletion python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class AdbcStatement(_AdbcHandle):
def bind_stream(self, *args, **kwargs) -> Any: ...
def cancel(self) -> None: ...
def close(self) -> None: ...
def execute_partitions(self, *args, **kwargs) -> Any: ...
def execute_partitions(
self,
) -> Tuple[List[bytes], Optional[ArrowSchemaHandle], int]: ...
def execute_query(self, *args, **kwargs) -> Any: ...
def execute_schema(self) -> "ArrowSchemaHandle": ...
def execute_update(self, *args, **kwargs) -> Any: ...
Expand Down
Loading
Loading