diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 142d5f4be7..7b99003f0d 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -477,6 +477,8 @@ jobs: - name: Test Python Driver Flight SQL shell: bash -l {0} run: | + docker compose up -d flightsql-test + export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414 env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - name: Build Python Driver PostgreSQL shell: bash -l {0} diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py index ed44b6a3fa..3192475941 100644 --- a/python/adbc_driver_flightsql/tests/test_errors.py +++ b/python/adbc_driver_flightsql/tests/test_errors.py @@ -15,7 +15,11 @@ # specific language governing permissions and limitations # under the License. +import os import re +import signal +import threading +import time import google.protobuf.any_pb2 as any_pb2 import google.protobuf.wrappers_pb2 as wrappers_pb2 @@ -45,6 +49,41 @@ def test_query_cancel(test_dbapi): cur.fetchone() +def test_query_cancel_async(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("forever") + + def _cancel(): + time.sleep(2) + cur.adbc_cancel() + + t = threading.Thread(target=_cancel, daemon=True) + t.start() + + with pytest.raises( + test_dbapi.OperationalError, + match=re.escape("CANCELLED: [FlightSQL] context canceled"), + ): + cur.fetchone() + + +def test_query_cancel_sigint(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("forever") + + # XXX: this handles the case of DoGet taking forever, but we also will + # want to test GetFlightInfo taking forever + + def _cancel(): + time.sleep(2) + os.kill(os.getpid(), signal.SIGINT) + + t = threading.Thread(target=_cancel, daemon=True) + t.start() + with pytest.raises(KeyboardInterrupt): + cur.fetchone() + + def test_query_error_fetch(test_dbapi): with test_dbapi.cursor() as cur: cur.execute("error_do_get") diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 1e86144c12..1113137195 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -55,6 +55,7 @@ import adbc_driver_manager from . import _lib, _reader +from .util import _blocking_call if typing.TYPE_CHECKING: import pandas @@ -677,9 +678,12 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None: parameters, which will each be bound in turn). """ self._prepare_execute(operation, parameters) - handle, self._rowcount = self._stmt.execute_query() + + handle, self._rowcount = _blocking_call( + self._stmt.execute_query, [], {}, self._stmt.cancel + ) self._results = _RowIterator( - _reader.AdbcRecordBatchReader._import_from_c(handle.address) + self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address) ) def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None: @@ -991,7 +995,7 @@ def adbc_read_partition(self, partition: bytes) -> None: handle = self._conn._conn.read_partition(partition) self._rowcount = -1 self._results = _RowIterator( - pyarrow.RecordBatchReader._import_from_c(handle.address) + self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address) ) @property @@ -1095,7 +1099,8 @@ def fetch_record_batch(self) -> pyarrow.RecordBatchReader: class _RowIterator(_Closeable): """Track state needed to iterate over the result set.""" - def __init__(self, reader: pyarrow.RecordBatchReader) -> None: + def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None: + self._stmt = stmt self._reader = reader self._current_batch = None self._next_row = 0 @@ -1118,7 +1123,9 @@ def fetchone(self) -> Optional[tuple]: if self._current_batch is None or self._next_row >= len(self._current_batch): try: while True: - self._current_batch = self._reader.read_next_batch() + self._current_batch = _blocking_call( + self._reader.read_next_batch, [], {}, self._stmt.cancel + ) if self._current_batch.num_rows > 0: break self._next_row = 0 diff --git a/python/adbc_driver_manager/adbc_driver_manager/util.py b/python/adbc_driver_manager/adbc_driver_manager/util.py new file mode 100644 index 0000000000..57298bdee9 --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/util.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import threading + + +def _blocking_call(func, args, kwargs, cancel): + """Run functions that are expected to block off of the main thread.""" + if threading.current_thread() is not threading.main_thread(): + return func(*args, **kwargs) + + ret = None + + def _background_task(): + nonlocal ret + ret = func(*args, **kwargs) + + bg = threading.Thread(target=_background_task) + bg.start() + + try: + bg.join() + except KeyboardInterrupt: + try: + cancel() + finally: + bg.join() + raise + + return ret