From 640e984969f1e03ff17d5cc4d06f84cf914a1fab Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 23 Jan 2024 16:47:47 -0500
Subject: [PATCH] feat(python/adbc_driver_manager): handle KeyboardInterrupt
Fixes #1484.
---
.github/workflows/native-unix.yml | 13 +
ci/scripts/python_test.sh | 4 +-
docker-compose.yml | 6 +
.../tests/test_errors.py | 20 ++
python/adbc_driver_manager/MANIFEST.in | 2 +
.../adbc_driver_manager/_blocking_impl.cc | 269 ++++++++++++++++++
.../adbc_driver_manager/_blocking_impl.h | 38 +++
.../adbc_driver_manager/_lib.pyi | 11 +
.../adbc_driver_manager/_lib.pyx | 98 +++++++
.../adbc_driver_manager/dbapi.py | 17 +-
python/adbc_driver_manager/pyproject.toml | 1 +
python/adbc_driver_manager/setup.py | 3 +
.../tests/test_blocking.py | 140 +++++++++
13 files changed, 615 insertions(+), 7 deletions(-)
create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
create mode 100644 python/adbc_driver_manager/tests/test_blocking.py
diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 142d5f4be7..c4ab9f44f8 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -477,7 +477,20 @@ jobs:
- name: Test Python Driver Flight SQL
shell: bash -l {0}
run: |
+ # Can't use Docker on macOS
+ pushd $(pwd)/go/adbc
+ go build -o testserver ./driver/flightsql/cmd/testserver
+ popd
+ $(pwd)/go/adbc/testserver -host 0.0.0.0 -port 41414 &
+ while ! curl --http2-prior-knowledge -H "content-type: application/grpc" -v localhost:41414 -XPOST;
+ do
+ echo "Waiting for test server..."
+ jobs
+ sleep 5
+ done
+ export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
+ rc=$?; kill %1; exit "$rc"  # propagate the test status; otherwise kill's exit code masks failures
- name: Build Python Driver PostgreSQL
shell: bash -l {0}
run: |
diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh
index f8d7091791..6f95b5898e 100755
--- a/ci/scripts/python_test.sh
+++ b/ci/scripts/python_test.sh
@@ -58,8 +58,8 @@ test_subproject() {
fi
echo "=== Testing ${subproject} ==="
- echo env ${options[@]} python -m pytest -vv "${source_dir}/python/${subproject}/tests"
- env ${options[@]} python -m pytest -vv "${source_dir}/python/${subproject}/tests"
+ echo env ${options[@]} python -m pytest -vvs --full-trace "${source_dir}/python/${subproject}/tests"
+ env ${options[@]} python -m pytest -vvs --full-trace "${source_dir}/python/${subproject}/tests"
echo
}
diff --git a/docker-compose.yml b/docker-compose.yml
index 89394e598f..789d5d450d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -150,6 +150,12 @@ services:
dockerfile: ci/docker/flightsql-test.dockerfile
args:
GO: ${GO}
+ healthcheck:
+ test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc", "localhost:41414"]  # curl fails without a URL
+ interval: 5s
+ timeout: 30s
+ retries: 3
+ start_period: 5m
ports:
- "41414:41414"
volumes:
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py
index ed44b6a3fa..ee2b62d3ee 100644
--- a/python/adbc_driver_flightsql/tests/test_errors.py
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -16,6 +16,8 @@
# under the License.
import re
+import threading
+import time
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
@@ -45,6 +47,24 @@ def test_query_cancel(test_dbapi):
cur.fetchone()
+def test_query_cancel_async(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("forever")
+
+ def _cancel():
+ time.sleep(2)
+ cur.adbc_cancel()
+
+ t = threading.Thread(target=_cancel, daemon=True)
+ t.start()
+
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("CANCELLED: [FlightSQL] context canceled"),
+ ):
+ cur.fetchone()
+
+
def test_query_error_fetch(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get")
diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in
index 306c31144f..298ff3a9ca 100644
--- a/python/adbc_driver_manager/MANIFEST.in
+++ b/python/adbc_driver_manager/MANIFEST.in
@@ -22,6 +22,8 @@ include NOTICE.txt
include adbc_driver_manager/adbc.h
include adbc_driver_manager/adbc_driver_manager.cc
include adbc_driver_manager/adbc_driver_manager.h
+include adbc_driver_manager/_blocking_impl.cc
+include adbc_driver_manager/_blocking_impl.h
include adbc_driver_manager/_lib.pxd
include adbc_driver_manager/_lib.pyi
include adbc_driver_manager/_reader.pyi
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
new file mode 100644
index 0000000000..766b3964cd
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "_blocking_impl.h"
+
+#if defined(_WIN32)
+#define NOMINMAX
+#define WIN32_LEAN_AND_MEAN
+#include
+#include
+#include
+#include
+#else
+#include
+#include
+#include
+#endif
+
+#include
+#include
+#include
+#include
+#include
+
+namespace pyadbc_driver_manager {
+
+// This is somewhat derived from io_util.cc in arrow, but that implementation
+// isn't easily used outside of Arrow's monolith.
+namespace {
+static std::once_flag kInitOnce;
+// We may encounter errors below that we can't do anything about. Use this to
+// print out an error, once.
+static std::once_flag kWarnOnce;
+// This thread reads from a pipe forever. Whenever it reads something, it
+// calls the callback below.
+static std::thread kCancelThread;
+
+static std::mutex cancel_mutex;
+// This callback is registered by the Python side; basically it will call
+// cancel() on an ADBC object.
+static void (*cancel_callback)(void*) = nullptr;
+// Callback state (a pointer to the ADBC PyObject).
+static void* cancel_callback_data = nullptr;
+// A self-pipe (only the write side is nonblocking; the read side blocks).
+static int pipe[2];
+#if defined(_WIN32)
+void (*old_sigint)(int);
+#else
+// The old signal handler (most likely Python's).
+struct sigaction old_sigint;
+// Our signal handler (below).
+struct sigaction our_sigint;
+#endif
+
+std::string MakePipe() {
+ int rc = 0;
+#if defined(__linux__) && defined(__GLIBC__)
+ rc = pipe2(pipe, O_CLOEXEC);
+#elif defined(_WIN32)
+ rc = _pipe(pipe, 4096, _O_BINARY);
+#else
+ rc = ::pipe(pipe);
+#endif
+
+ if (rc != 0) {
+ return std::strerror(errno);
+ }
+
+#if (!defined(__linux__) || !defined(__GLIBC__)) && !defined(_WIN32)
+ {
+ int flags = fcntl(pipe[0], F_GETFD, 0);
+ if (flags < 0) {
+ return std::strerror(errno);
+ }
+ rc = fcntl(pipe[0], F_SETFD, flags | FD_CLOEXEC);
+ if (rc < 0) {
+ return std::strerror(errno);
+ }
+
+ flags = fcntl(pipe[1], F_GETFD, 0);
+ if (flags < 0) {
+ return std::strerror(errno);
+ }
+ rc = fcntl(pipe[1], F_SETFD, flags | FD_CLOEXEC);
+ if (rc < 0) {
+ return std::strerror(errno);
+ }
+ }
+#endif
+
+ // Make the write side nonblocking (the read side should stay blocking!)
+#if defined(_WIN32)
+ const auto handle = reinterpret_cast(_get_osfhandle(pipe[1]));
+ DWORD mode = PIPE_NOWAIT;
+ if (!SetNamedPipeHandleState(handle, &mode, nullptr, nullptr)) {
+ DWORD last_error = GetLastError();
+ LPVOID message;
+
+ FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
+ FORMAT_MESSAGE_IGNORE_INSERTS,
+ /*lpSource=*/nullptr, last_error,
+ MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+ reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr);
+
+ std::string buffer = "(";
+ buffer += std::to_string(last_error);
+ buffer += ") ";
+ buffer += reinterpret_cast(message);
+ LocalFree(message);
+ return buffer;
+ }
+#else
+ {
+ int flags = fcntl(pipe[1], F_GETFL, 0);
+ if (flags < 0) {
+ return std::strerror(errno);
+ }
+ rc = fcntl(pipe[1], F_SETFL, flags | O_NONBLOCK);
+ if (rc < 0) {
+ return std::strerror(errno);
+ }
+ }
+#endif
+
+ return "";
+}
+
+void InterruptThread() {
+#if defined(__APPLE__)
+ pthread_setname_np("AdbcInterrupt");
+#endif
+
+ while (true) {
+ char buf = 0;
+ // Anytime something is written to the pipe, attempt to call the callback
+ auto bytes_read = read(pipe[0], &buf, 1);
+ if (bytes_read < 0) {
+ if (errno == EINTR) continue;
+
+ // XXX: we failed reading from the pipe
+ std::string message = std::strerror(errno);
+ std::call_once(kWarnOnce, [&]() {
+ std::cerr << "adbc_driver_manager (native code): error handling interrupt: "
+ << message << std::endl;
+ });
+ } else if (bytes_read > 0) {
+ // Save the callback locally instead of calling it under the lock, since
+ // otherwise we may deadlock with the Python side trying to call us
+ void (*local_callback)(void*) = nullptr;
+ void* local_callback_data = nullptr;
+
+ {
+ std::lock_guard lock(cancel_mutex);
+ if (cancel_callback != nullptr) {
+ local_callback = cancel_callback;
+ local_callback_data = cancel_callback_data;
+ }
+ cancel_callback = nullptr;
+ cancel_callback_data = nullptr;
+ }
+
+ if (local_callback != nullptr) {
+ local_callback(local_callback_data);
+ }
+ }
+ }
+}
+
+// We can't do much about failures here, so ignore the result. If the pipe is
+// full, that's fine; it just means the thread has fallen behind in processing
+// earlier interrupts.
+void SigintHandler(int) {
+#if defined(_WIN32)
+ (void)_write(pipe[1], "X", 1);
+#else
+ (void)write(pipe[1], "X", 1);
+#endif
+}
+
+} // namespace
+
+std::string InitBlockingCallback() {
+ std::string error;
+ std::call_once(kInitOnce, [&]() {
+ error = MakePipe();
+ if (!error.empty()) {
+ return;
+ }
+
+#if !defined(_WIN32)
+ our_sigint.sa_handler = &SigintHandler;
+ our_sigint.sa_flags = 0;
+ sigemptyset(&our_sigint.sa_mask);
+#endif
+
+ kCancelThread = std::thread(InterruptThread);
+#if defined(__linux__)
+ pthread_setname_np(kCancelThread.native_handle(), "AdbcInterrupt");
+#endif
+ kCancelThread.detach();
+ });
+ return error;
+}
+
+std::string SetBlockingCallback(void (*callback)(void*), void* data) {
+ std::lock_guard lock(cancel_mutex);
+ cancel_callback = callback;
+ cancel_callback_data = data;
+
+#if defined(_WIN32)
+ if (old_sigint == nullptr) {
+ old_sigint = signal(SIGINT, &SigintHandler);
+ if (old_sigint == SIG_ERR) {
+ old_sigint = nullptr;
+ return std::strerror(errno);
+ }
+ }
+#else
+ // Don't set the handler again if we're somehow called twice
+ if (old_sigint.sa_handler == nullptr && old_sigint.sa_sigaction == nullptr) {
+ int rc = sigaction(SIGINT, &our_sigint, &old_sigint);
+ if (rc != 0) {
+ return std::strerror(errno);
+ }
+ }
+#endif
+ return "";
+}
+
+std::string ClearBlockingCallback() {
+ std::lock_guard lock(cancel_mutex);
+ cancel_callback = nullptr;
+ cancel_callback_data = nullptr;
+
+#if defined(_WIN32)
+ if (old_sigint != nullptr) {
+ auto rc = signal(SIGINT, old_sigint);
+ old_sigint = nullptr;
+ if (rc == SIG_ERR) {
+ return std::strerror(errno);
+ }
+ }
+#else
+ if (old_sigint.sa_handler != nullptr || old_sigint.sa_sigaction != nullptr) {
+ int rc = sigaction(SIGINT, &old_sigint, nullptr);
+ std::memset(&old_sigint, 0, sizeof(old_sigint));
+ if (rc != 0) {
+ return std::strerror(errno);
+ }
+ }
+#endif
+ return "";
+}
+
+} // namespace pyadbc_driver_manager
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
new file mode 100644
index 0000000000..ac76252f3e
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Allow KeyboardInterrupt to function with ADBC in Python.
+///
+/// Call SetBlockingCallback to register a callback. This will temporarily
+/// suppress the Python SIGINT handler. When SIGINT is received, this module
+/// will handle it by calling the callback.
+
+#include
+
+namespace pyadbc_driver_manager {
+
+/// \brief Set up internal state (self-pipe and background thread) to handle SIGINT.
+/// \return An error message (or empty string).
+std::string InitBlockingCallback();
+/// \brief Set the callback for when SIGINT is received.
+/// \return An error message (or empty string).
+std::string SetBlockingCallback(void (*callback)(void*), void* data);
+/// \brief Clear the callback for when SIGINT is received.
+/// \return An error message (or empty string).
+std::string ClearBlockingCallback();
+
+} // namespace pyadbc_driver_manager
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 7afada9ecc..2a818839a1 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -17,6 +17,7 @@
# NOTE: generated with mypy's stubgen, then hand-edited to fix things
+import typing_extensions
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
from typing import overload
@@ -201,3 +202,13 @@ def _test_error(
vendor_code: Optional[int],
sqlstate: Optional[str],
) -> Error: ...
+
+_P = typing_extensions.ParamSpec("_P")
+_T = typing.TypeVar("_T")
+
+def _blocking_call(
+ func: typing.Callable[_P, _T],
+ args: tuple,
+ kwargs: dict,
+ cancel: typing.Callable[[], None],
+) -> _T: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 91139100bb..79222d6082 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -20,8 +20,12 @@
"""Low-level ADBC API."""
import enum
+import functools
import threading
+import os
import typing
+import sys
+import warnings
from typing import List, Tuple
cimport cpython
@@ -33,6 +37,7 @@ from cpython.pycapsule cimport (
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
+from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
if typing.TYPE_CHECKING:
@@ -1481,3 +1486,96 @@ cdef class AdbcStatement(_AdbcHandle):
cdef const CAdbcError* PyAdbcErrorFromArrayStream(
CArrowArrayStream* stream, CAdbcStatusCode* status):
return AdbcErrorFromArrayStream(stream, status)
+
+
+cdef extern from "_blocking_impl.h" nogil:
+ ctypedef void (*BlockingCallback)(void*) noexcept nogil
+ c_string CInitBlockingCallback"pyadbc_driver_manager::InitBlockingCallback"()
+ c_string CSetBlockingCallback"pyadbc_driver_manager::SetBlockingCallback"(BlockingCallback, void* data)
+ c_string CClearBlockingCallback"pyadbc_driver_manager::ClearBlockingCallback"()
+
+
+@functools.cache
+def _init_blocking_call():
+ error = bytes(CInitBlockingCallback()).decode("utf-8")
+ if error:
+ warnings.warn(
+ f"Failed to initialize KeyboardInterrupt support: {error}",
+ RuntimeWarning,
+ )
+
+
+_blocking_lock = threading.Lock()
+_blocking_exc = None
+
+
+def _blocking_call_impl(func, args, kwargs, cancel):
+ """
+ Run functions that are expected to block with a native SIGINT handler.
+
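+ Calls ``func(*args, **kwargs)``; on the main thread, ``cancel`` is handed to
+ the native SIGINT handler as its callback data while ``func`` runs.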
+ """
+ global _blocking_exc
+
+ if threading.current_thread() is not threading.main_thread():
+ return func(*args, **kwargs)
+
+ _init_blocking_call()
+
+ with _blocking_lock:
+ if _blocking_exc:
+ _blocking_exc = None
+
+ # Set the callback for the background thread and save the signal handler
+ # TODO: ideally this would be no-op if already set
+ error = bytes(
+ CSetBlockingCallback(&_handle_blocking_call, cancel)
+ ).decode("utf-8")
+ if error:
+ warnings.warn(
+ f"Failed to set SIGINT handler: {error}",
+ RuntimeWarning,
+ )
+
+ try:
+ return func(*args, **kwargs)
+ except BaseException as e:
+ with _blocking_lock:
+ if _blocking_exc:
+ exc = _blocking_exc
+ _blocking_exc = None
+ raise e from exc[1].with_traceback(exc[2])
+ raise e
+ finally:
+ # Restore the signal handler
+ error = bytes(CClearBlockingCallback()).decode("utf-8")
+ if error:
+ warnings.warn(
+ f"Failed to restore SIGINT handler: {error}",
+ RuntimeWarning,
+ )
+ with _blocking_lock:
+ if _blocking_exc:
+ exc = _blocking_exc
+ _blocking_exc = None
+ raise exc[1].with_traceback(exc[2]) from KeyboardInterrupt
+
+
+if os.name != "nt":
+ _blocking_call = _blocking_call_impl
+else:
+ def _blocking_call(func, args, kwargs, cancel):
+ return func(*args, **kwargs)
+
+
+
+cdef void _handle_blocking_call(void* c_cancel) noexcept nogil:
+ with gil:
+ try:
+ cancel =