Skip to content

Commit

Permalink
feat(python/adbc_driver_manager): handle KeyboardInterrupt
Browse files Browse the repository at this point in the history
Fixes #1484.
  • Loading branch information
lidavidm committed Feb 2, 2024
1 parent 5e21134 commit 3c436d7
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/native-unix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ jobs:
- name: Test Python Driver Flight SQL
shell: bash -l {0}
run: |
docker compose up -d flightsql-test
export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
- name: Build Python Driver PostgreSQL
shell: bash -l {0}
Expand Down
48 changes: 48 additions & 0 deletions python/adbc_driver_flightsql/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import os
import re
import signal
import threading
import time

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
Expand Down Expand Up @@ -45,6 +49,50 @@ def test_query_cancel(test_dbapi):
cur.fetchone()


def test_query_cancel_async(test_dbapi):
    """Cancelling a running query from another thread aborts the fetch."""
    expected_message = re.escape("CANCELLED: [FlightSQL] context canceled")

    with test_dbapi.cursor() as cur:
        cur.execute("forever")

        def _cancel_after_delay():
            # Presumably the delay lets the main thread block in fetchone()
            # before the cancellation is issued.
            time.sleep(2)
            cur.adbc_cancel()

        canceller = threading.Thread(target=_cancel_after_delay, daemon=True)
        canceller.start()

        with pytest.raises(test_dbapi.OperationalError, match=expected_message):
            cur.fetchone()


# Check that a SIGINT delivered to this process while the main thread is
# fetching results cancels the query: fetchone() is expected to raise
# OperationalError with the Flight SQL "context canceled" message.
# NOTE(review): indentation was lost in this view, so it is not certain which
# of the statements below belong to the `for _ in range(3)` loop - confirm
# against the original file.
def test_query_cancel_sigint(test_dbapi):
with test_dbapi.cursor() as cur:
for _ in range(3):
cur.execute("forever")

# XXX: this handles the case of DoGet taking forever, but we also will
# want to test GetFlightInfo taking forever

# Runs on a background thread: wait 2 seconds, then send SIGINT to our own
# process.
def _cancel():
time.sleep(2)
os.kill(os.getpid(), signal.SIGINT)

t = threading.Thread(target=_cancel, daemon=True)
t.start()
with pytest.raises(
test_dbapi.OperationalError,
match=re.escape("CANCELLED: [FlightSQL] context canceled"),
):
cur.fetchone()

# The cursor should still be usable
# ("error_do_get" is expected to surface as ProgrammingError on fetch.)
cur.execute("error_do_get")
with pytest.raises(test_dbapi.ProgrammingError):
cur.fetch_arrow_table()


def test_query_error_fetch(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get")
Expand Down
119 changes: 119 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "_blocking_impl.h"

#include <fcntl.h>
#include <poll.h>
#include <unistd.h>
#include <cerrno>
#include <csignal>
#include <cstring>
#include <mutex>
#include <string>
#include <system_error>
#include <thread>

namespace pyadbc_driver_manager {

// This is somewhat derived from io_util.cc in arrow, but that implementation
// isn't easily used outside of Arrow's monolith.
namespace {
// Ensures the one-time setup in InitBlockingCallback runs only once.
static std::once_flag kSpawnThread;
// Background thread running InterruptThread; it is detached after creation.
static std::thread kCancelThread;

// Guards cancel_callback and cancel_callback_data.
static std::mutex cancel_mutex;
// Callback invoked by InterruptThread when a byte arrives on the pipe;
// nullptr when no blocking call is registered.
static void (*cancel_callback)(void*) = nullptr;
// Opaque argument passed to cancel_callback.
static void* cancel_callback_data = nullptr;
// Self-pipe: pipe[0] is read by InterruptThread, pipe[1] is written by
// SigintHandler.  Note that this name hides ::pipe() inside this namespace.
static int pipe[2];
// SIGINT disposition that was active before SetBlockingCallback installed
// ours; restored by ClearBlockingCallback.
struct sigaction old_sigint;
// SIGINT disposition pointing at SigintHandler (filled in by
// InitBlockingCallback).
struct sigaction our_sigint;

std::string MakePipe() {
int rc = 0;
#if defined(__linux__) && defined(__GLIBC__)
rc = pipe2(pipe, O_CLOEXEC | O_NONBLOCK);
#else
return "Unsupported platform";
#endif

if (rc != 0) {
return std::strerror(errno);
}
return "";
}

/// Body of the background thread: wait for bytes on the self-pipe and, for
/// each one, invoke (then clear) the registered cancel callback.
///
/// The callback is invoked while cancel_mutex is held, so it only ever runs
/// between SetBlockingCallback and ClearBlockingCallback.  Callers of those
/// functions must therefore not hold any lock the callback needs.
///
/// The thread never busy-waits: if the pipe is non-blocking and empty it
/// sleeps in poll(), and it exits if the pipe is closed or fails in a way
/// that retrying cannot fix.
void InterruptThread() {
  while (true) {
    char buf = 0;
    const ssize_t bytes_read = read(pipe[0], &buf, 1);
    if (bytes_read > 0) {
      std::lock_guard<std::mutex> lock(cancel_mutex);
      if (cancel_callback != nullptr) {
        cancel_callback(cancel_callback_data);
      }
      // One-shot: a callback is triggered at most once per registration.
      cancel_callback = nullptr;
      cancel_callback_data = nullptr;
      continue;
    }

    if (bytes_read == 0) {
      // EOF: the write end was closed, no signal can ever arrive again.
      return;
    }

    if (errno == EINTR) continue;

    if (errno == EAGAIN || errno == EWOULDBLOCK) {
      // Non-blocking read end with nothing to read: block until readable
      // instead of spinning on read().
      struct pollfd pfd;
      pfd.fd = pipe[0];
      pfd.events = POLLIN;
      pfd.revents = 0;
      if (poll(&pfd, 1, /*timeout=*/-1) < 0 && errno != EINTR) {
        return;
      }
      continue;
    }

    // XXX: we failed reading from the pipe; warn?
    // Retrying an unrecoverable error would spin forever, so stop instead.
    return;
  }
}

/// SIGINT handler: wake the interrupt thread by writing one byte to the
/// self-pipe.  Only async-signal-safe operations are used here.
void SigintHandler(int) {
  // write() may set errno (e.g. when the pipe is full); preserve the value
  // seen by whatever code this signal interrupted.
  const int saved_errno = errno;
  const ssize_t written = write(pipe[1], "X", 1);
  (void)written;  // Best effort: nothing useful can be done on failure here.
  errno = saved_errno;
}

} // namespace

/// Perform the one-time setup for interruptible blocking calls: create the
/// self-pipe, prepare the SIGINT disposition, and spawn the (detached)
/// interrupt thread.
///
/// Safe to call repeatedly; the setup itself runs only once.  The outcome
/// of that single attempt is remembered, so every call reports the same
/// result instead of claiming success after an earlier failure.
///
/// \return an empty string on success, else a description of the error.
std::string InitBlockingCallback() {
  // Written only inside call_once; read-only afterwards.
  static std::string init_error;
  std::call_once(kSpawnThread, []() {
    init_error = MakePipe();
    if (!init_error.empty()) {
      return;
    }

    our_sigint.sa_handler = &SigintHandler;
    our_sigint.sa_flags = 0;
    sigemptyset(&our_sigint.sa_mask);

    try {
      kCancelThread = std::thread(InterruptThread);
      kCancelThread.detach();
    } catch (const std::system_error& e) {
      // Report thread-creation failure as an error string rather than
      // letting a C++ exception escape to the caller.
      init_error = e.what();
      if (init_error.empty()) {
        init_error = "Could not spawn interrupt thread";
      }
    }
    // TODO: set name of thread
  });
  return init_error;
}

/// Register the callback to run when SIGINT arrives, then install our SIGINT
/// handler, remembering the previous disposition in old_sigint.
///
/// \param callback function invoked from the interrupt thread
/// \param data opaque pointer handed to callback
void SetBlockingCallback(void (*callback)(void*), void* data) {
  const std::lock_guard<std::mutex> guard(cancel_mutex);
  cancel_callback_data = data;
  cancel_callback = callback;

  if (sigaction(SIGINT, &our_sigint, &old_sigint) != 0) {
    // XXX: warn?
  }
}

/// Unregister the current callback and put back the SIGINT disposition that
/// SetBlockingCallback saved in old_sigint.
void ClearBlockingCallback() {
  const std::lock_guard<std::mutex> guard(cancel_mutex);
  cancel_callback_data = nullptr;
  cancel_callback = nullptr;

  if (sigaction(SIGINT, &old_sigint, nullptr) != 0) {
    // XXX: warn?
  }
}

} // namespace pyadbc_driver_manager
26 changes: 26 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <string>

namespace pyadbc_driver_manager {

/// \brief Perform one-time setup (self-pipe and background interrupt
///   thread) for interruptible blocking calls.
/// \return an empty string on success, else an error message.
std::string InitBlockingCallback();

/// \brief Register a callback to be invoked when SIGINT is received and
///   install the SIGINT handler, saving the previous one.
/// \param[in] callback The function to invoke from the interrupt thread.
/// \param[in] data Opaque pointer passed through to the callback.
void SetBlockingCallback(void (*callback)(void*), void* data);

/// \brief Unregister the callback and restore the SIGINT handler saved by
///   SetBlockingCallback.
void ClearBlockingCallback();

}  // namespace pyadbc_driver_manager
11 changes: 11 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# NOTE: generated with mypy's stubgen, then hand-edited to fix things

import typing_extensions
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union

from typing import overload
Expand Down Expand Up @@ -201,3 +202,13 @@ def _test_error(
vendor_code: Optional[int],
sqlstate: Optional[str],
) -> Error: ...

# Type variables describing the wrapped callable: _P for its parameters,
# _T for its return type.
_P = typing_extensions.ParamSpec("_P")
_T = typing.TypeVar("_T")

# Stub for the helper that runs `func(*args, **kwargs)` and returns its
# result; `cancel` is a zero-argument callable taking no parameters.
# NOTE(review): type checkers generally only accept `_P.args` / `_P.kwargs`
# as annotations on `*args` / `**kwargs`; here they annotate ordinary
# parameters - confirm this passes the project's type checker.
def _blocking_call(
func: typing.Callable[_P, _T],
args: _P.args,
kwargs: _P.kwargs,
cancel: typing.Callable[[], None],
) -> _T: ...
46 changes: 46 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""Low-level ADBC API."""

import enum
import functools
import threading
import typing
from typing import List, Tuple
Expand All @@ -33,6 +34,7 @@ from cpython.pycapsule cimport (
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -1481,3 +1483,47 @@ cdef class AdbcStatement(_AdbcHandle):
cdef const CAdbcError* PyAdbcErrorFromArrayStream(
CArrowArrayStream* stream, CAdbcStatusCode* status):
return AdbcErrorFromArrayStream(stream, status)


# C++ helpers from _blocking_impl.h, all callable without the GIL.
cdef extern from "_blocking_impl.h" nogil:
# Signature of the callback registered with CSetBlockingCallback.
ctypedef void (*BlockingCallback)(void*) noexcept nogil
# Returns a C++ string; callers treat an empty string as "no error".
c_string CInitBlockingCallback"pyadbc_driver_manager::InitBlockingCallback"()
# Registers a callback plus an opaque data pointer.
void CSetBlockingCallback"pyadbc_driver_manager::SetBlockingCallback"(BlockingCallback, void* data)
# Unregisters the callback set by CSetBlockingCallback.
void CClearBlockingCallback"pyadbc_driver_manager::ClearBlockingCallback"()


@functools.cache
def _init_blocking_call():
    """
    Run the native one-time initialization for blocking calls.

    Raises
    ------
    RuntimeError
        If the native side reports a (non-empty) error message.
    """
    message = bytes(CInitBlockingCallback()).decode("utf-8")
    if not message:
        return
    raise RuntimeError(message)


def _blocking_call(func, args, kwargs, cancel):
    """
    Run functions that are expected to block with a native SIGINT handler.

    When called from the main thread, a native SIGINT handler is installed
    for the duration of the call; if SIGINT arrives, ``cancel`` is invoked
    from a background thread.  The previous signal handler is restored
    afterwards.  From any other thread, ``func`` is simply called directly.

    Parameters
    ----------
    func : callable
        The (potentially blocking) function to call.
    args : sequence
        Positional arguments to pass to ``func``.
    kwargs : dict
        Keyword arguments to pass to ``func``.
    cancel : callable
        A zero-argument function to call if SIGINT is received while
        ``func`` is running.

    Returns
    -------
    The return value of ``func(*args, **kwargs)``.
    """
    # Borrowed pointer; 'cancel' stays alive for the duration of this call.
    cdef void* c_cancel = <void*>cancel

    if threading.current_thread() is not threading.main_thread():
        return func(*args, **kwargs)

    _init_blocking_call()

    # Set the callback for the background thread and save the signal handler.
    # The GIL is released around the native calls because they take a mutex
    # that the background thread holds while it acquires the GIL to run
    # 'cancel'; holding the GIL while waiting for that mutex could deadlock.
    with nogil:
        CSetBlockingCallback(&_handle_blocking_call, c_cancel)

    try:
        return func(*args, **kwargs)
    finally:
        # Restore the signal handler
        with nogil:
            CClearBlockingCallback()


cdef void _handle_blocking_call(void* c_cancel) noexcept nogil:
    # Native-side callback: takes the GIL, then calls the Python callable
    # whose pointer was registered via CSetBlockingCallback.
    # TODO: if this throws, we could save and restore the traceback later
    # TODO: we could record that this was hit and raise a KeyboardInterrupt above
    with gil:
        (<object> c_cancel)()
17 changes: 12 additions & 5 deletions python/adbc_driver_manager/adbc_driver_manager/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import adbc_driver_manager

from . import _lib, _reader
from ._lib import _blocking_call

if typing.TYPE_CHECKING:
import pandas
Expand Down Expand Up @@ -677,9 +678,12 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
parameters, which will each be bound in turn).
"""
self._prepare_execute(operation, parameters)
handle, self._rowcount = self._stmt.execute_query()

handle, self._rowcount = _blocking_call(
self._stmt.execute_query, [], {}, self._stmt.cancel
)
self._results = _RowIterator(
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address)
)

def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
Expand Down Expand Up @@ -991,7 +995,7 @@ def adbc_read_partition(self, partition: bytes) -> None:
handle = self._conn._conn.read_partition(partition)
self._rowcount = -1
self._results = _RowIterator(
pyarrow.RecordBatchReader._import_from_c(handle.address)
self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address)
)

@property
Expand Down Expand Up @@ -1095,7 +1099,8 @@ def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
class _RowIterator(_Closeable):
"""Track state needed to iterate over the result set."""

def __init__(self, reader: pyarrow.RecordBatchReader) -> None:
def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
self._stmt = stmt
self._reader = reader
self._current_batch = None
self._next_row = 0
Expand All @@ -1118,7 +1123,9 @@ def fetchone(self) -> Optional[tuple]:
if self._current_batch is None or self._next_row >= len(self._current_batch):
try:
while True:
self._current_batch = self._reader.read_next_batch()
self._current_batch = _blocking_call(
self._reader.read_next_batch, [], {}, self._stmt.cancel
)
if self._current_batch.num_rows > 0:
break
self._next_row = 0
Expand Down
1 change: 1 addition & 0 deletions python/adbc_driver_manager/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.9"
dynamic = ["version"]
dependencies = ["typing-extensions"]

[project.optional-dependencies]
dbapi = ["pandas", "pyarrow>=8.0.0"]
Expand Down
1 change: 1 addition & 0 deletions python/adbc_driver_manager/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def get_version(pkg_path):
include_dirs=[str(source_root.joinpath("adbc_driver_manager").resolve())],
language="c++",
sources=[
"adbc_driver_manager/_blocking_impl.cc",
"adbc_driver_manager/_lib.pyx",
"adbc_driver_manager/adbc_driver_manager.cc",
],
Expand Down

0 comments on commit 3c436d7

Please sign in to comment.