From 4217ad635da839e9dc619ae752a942db707f61ae Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 23 Jan 2024 16:47:47 -0500
Subject: [PATCH] feat(python/adbc_driver_manager): handle KeyboardInterrupt
Fixes #1484.
---
.github/workflows/native-unix.yml | 2 +
.../tests/test_errors.py | 39 ++++++++++++++++
.../adbc_driver_manager/dbapi.py | 17 ++++---
.../adbc_driver_manager/util.py | 44 +++++++++++++++++++
4 files changed, 97 insertions(+), 5 deletions(-)
create mode 100644 python/adbc_driver_manager/adbc_driver_manager/util.py
diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 142d5f4be7..7b99003f0d 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -477,6 +477,8 @@ jobs:
- name: Test Python Driver Flight SQL
shell: bash -l {0}
run: |
+ docker compose up -d flightsql-test
+ export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
- name: Build Python Driver PostgreSQL
shell: bash -l {0}
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py
index ed44b6a3fa..3192475941 100644
--- a/python/adbc_driver_flightsql/tests/test_errors.py
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.
+import os
import re
+import signal
+import threading
+import time
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
@@ -45,6 +49,41 @@ def test_query_cancel(test_dbapi):
cur.fetchone()
+def test_query_cancel_async(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("forever")
+
+ def _cancel():
+ time.sleep(2)
+ cur.adbc_cancel()
+
+ t = threading.Thread(target=_cancel, daemon=True)
+ t.start()
+
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("CANCELLED: [FlightSQL] context canceled"),
+ ):
+ cur.fetchone()
+
+
+def test_query_cancel_sigint(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("forever")
+
+ # XXX: this handles the case of DoGet taking forever, but we also will
+ # want to test GetFlightInfo taking forever
+
+ def _cancel():
+ time.sleep(2)
+ os.kill(os.getpid(), signal.SIGINT)
+
+ t = threading.Thread(target=_cancel, daemon=True)
+ t.start()
+ with pytest.raises(KeyboardInterrupt):
+ cur.fetchone()
+
+
def test_query_error_fetch(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get")
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 1e86144c12..1113137195 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -55,6 +55,7 @@
import adbc_driver_manager
from . import _lib, _reader
+from .util import _blocking_call
if typing.TYPE_CHECKING:
import pandas
@@ -677,9 +678,12 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
parameters, which will each be bound in turn).
"""
self._prepare_execute(operation, parameters)
- handle, self._rowcount = self._stmt.execute_query()
+
+ handle, self._rowcount = _blocking_call(
+ self._stmt.execute_query, [], {}, self._stmt.cancel
+ )
self._results = _RowIterator(
- _reader.AdbcRecordBatchReader._import_from_c(handle.address)
+ self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address)
)
def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
@@ -991,7 +995,7 @@ def adbc_read_partition(self, partition: bytes) -> None:
handle = self._conn._conn.read_partition(partition)
self._rowcount = -1
self._results = _RowIterator(
- pyarrow.RecordBatchReader._import_from_c(handle.address)
+ self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address)
)
@property
@@ -1095,7 +1099,8 @@ def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
class _RowIterator(_Closeable):
"""Track state needed to iterate over the result set."""
- def __init__(self, reader: pyarrow.RecordBatchReader) -> None:
+ def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
+ self._stmt = stmt
self._reader = reader
self._current_batch = None
self._next_row = 0
@@ -1118,7 +1123,9 @@ def fetchone(self) -> Optional[tuple]:
if self._current_batch is None or self._next_row >= len(self._current_batch):
try:
while True:
- self._current_batch = self._reader.read_next_batch()
+ self._current_batch = _blocking_call(
+ self._reader.read_next_batch, [], {}, self._stmt.cancel
+ )
if self._current_batch.num_rows > 0:
break
self._next_row = 0
diff --git a/python/adbc_driver_manager/adbc_driver_manager/util.py b/python/adbc_driver_manager/adbc_driver_manager/util.py
new file mode 100644
index 0000000000..57298bdee9
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/util.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import threading
+
+
+def _blocking_call(func, args, kwargs, cancel):
+ """Run functions that are expected to block off of the main thread."""
+ if threading.current_thread() is not threading.main_thread():
+ return func(*args, **kwargs)
+
+ ret = None
+
+ def _background_task():
+ nonlocal ret
+ ret = func(*args, **kwargs)
+
+ bg = threading.Thread(target=_background_task)
+ bg.start()
+
+ try:
+ bg.join()
+ except KeyboardInterrupt:
+ try:
+ cancel()
+ finally:
+ bg.join()
+ raise
+
+ return ret