Skip to content

Commit

Permalink
feat(python/adbc_driver_manager): handle KeyboardInterrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Jan 23, 2024
1 parent ccc989a commit 4217ad6
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/native-unix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ jobs:
- name: Test Python Driver Flight SQL
shell: bash -l {0}
run: |
docker compose up -d flightsql-test
export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
- name: Build Python Driver PostgreSQL
shell: bash -l {0}
Expand Down
39 changes: 39 additions & 0 deletions python/adbc_driver_flightsql/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import os
import re
import signal
import threading
import time

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
Expand Down Expand Up @@ -45,6 +49,41 @@ def test_query_cancel(test_dbapi):
cur.fetchone()


def test_query_cancel_async(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("forever")

def _cancel():
time.sleep(2)
cur.adbc_cancel()

t = threading.Thread(target=_cancel, daemon=True)
t.start()

with pytest.raises(
test_dbapi.OperationalError,
match=re.escape("CANCELLED: [FlightSQL] context canceled"),
):
cur.fetchone()


def test_query_cancel_sigint(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("forever")

# XXX: this handles the case of DoGet taking forever, but we also will
# want to test GetFlightInfo taking forever

def _cancel():
time.sleep(2)
os.kill(os.getpid(), signal.SIGINT)

t = threading.Thread(target=_cancel, daemon=True)
t.start()
with pytest.raises(KeyboardInterrupt):
cur.fetchone()


def test_query_error_fetch(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get")
Expand Down
17 changes: 12 additions & 5 deletions python/adbc_driver_manager/adbc_driver_manager/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import adbc_driver_manager

from . import _lib, _reader
from .util import _blocking_call

if typing.TYPE_CHECKING:
import pandas
Expand Down Expand Up @@ -677,9 +678,12 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
parameters, which will each be bound in turn).
"""
self._prepare_execute(operation, parameters)
handle, self._rowcount = self._stmt.execute_query()

handle, self._rowcount = _blocking_call(
self._stmt.execute_query, [], {}, self._stmt.cancel
)
self._results = _RowIterator(
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address)
)

def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
Expand Down Expand Up @@ -991,7 +995,7 @@ def adbc_read_partition(self, partition: bytes) -> None:
handle = self._conn._conn.read_partition(partition)
self._rowcount = -1
self._results = _RowIterator(
pyarrow.RecordBatchReader._import_from_c(handle.address)
self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address)
)

@property
Expand Down Expand Up @@ -1095,7 +1099,8 @@ def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
class _RowIterator(_Closeable):
"""Track state needed to iterate over the result set."""

def __init__(self, reader: pyarrow.RecordBatchReader) -> None:
def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
self._stmt = stmt
self._reader = reader
self._current_batch = None
self._next_row = 0
Expand All @@ -1118,7 +1123,9 @@ def fetchone(self) -> Optional[tuple]:
if self._current_batch is None or self._next_row >= len(self._current_batch):
try:
while True:
self._current_batch = self._reader.read_next_batch()
self._current_batch = _blocking_call(
self._reader.read_next_batch, [], {}, self._stmt.cancel
)
if self._current_batch.num_rows > 0:
break
self._next_row = 0
Expand Down
44 changes: 44 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import threading


def _blocking_call(func, args, kwargs, cancel):
"""Run functions that are expected to block off of the main thread."""
if threading.current_thread() is not threading.main_thread():
return func(*args, **kwargs)

ret = None

def _background_task():
nonlocal ret
ret = func(*args, **kwargs)

bg = threading.Thread(target=_background_task)
bg.start()

try:
bg.join()
except KeyboardInterrupt:
try:
cancel()
finally:
bg.join()
raise

return ret

0 comments on commit 4217ad6

Please sign in to comment.