Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python/adbc_driver_manager): experiment with using PyCapsules #702

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 137 additions & 7 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ import typing
from typing import List, Tuple

import cython
cimport cpython
from cpython.bytes cimport PyBytes_FromStringAndSize
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.string cimport memset
from libc.string cimport memset, memcpy
from libcpp.vector cimport vector as c_vector
from libc.stdlib cimport malloc, free
from libc.errno cimport EIO

if typing.TYPE_CHECKING:
from typing import Self
Expand All @@ -40,8 +43,13 @@ cdef extern from "adbc.h" nogil:
pass
cdef struct CArrowArray"ArrowArray":
pass

cdef struct CArrowArrayStream"ArrowArrayStream":
pass
int (*get_schema)(CArrowArrayStream* stream, CArrowSchema* out) nogil noexcept
int (*get_next)(CArrowArrayStream* stream, CArrowArray* out) nogil noexcept
const char* (*get_last_error)(CArrowArrayStream*) nogil noexcept
void (*release)(CArrowArrayStream*) nogil noexcept
void* private_data

# ADBC
ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode"
Expand Down Expand Up @@ -460,6 +468,20 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")


cdef void pycapsule_stream_deleter(object stream_capsule):
cdef:
CArrowArrayStream* stream
# Do not invoke the deleter on a used/moved capsule
stream = <CArrowArrayStream*>cpython.PyCapsule_GetPointer(
stream_capsule, 'arrowarraystream'
)
Comment on lines +475 to +477
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any situation in which stream can be NULL? (In R this happens if somebody tries the equivalent of pickling and unpickling, but I presume that would error at the pickling stage here?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it can be. You can see this in the CPython documentation:

https://docs.python.org/3/c-api/capsule.html#c.PyCapsule_GetPointer

I think just need to immediately return if that is NULL

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is pretty tricky though. PyCapsule_GetPointer will set the global python error but I'm not sure how you'd know to check for that after this is executed; so this could potentially be a pitfall of segfaults or leaked pointers

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to check for error, Cython will do it for you thanks to its PyCapsule_GetPointer declaration here:
https://github.com/cython/cython/blob/d73164b56544def09b65d250d72b227a38944bb1/Cython/Includes/cpython/pycapsule.pxd#L50

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for having a NULL stream pointer in a C ArrowArrayStream capsule, this should probably be disallowed by the spec.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it is not possible anyway ( https://docs.python.org/3/c-api/capsule.html#c.PyCapsule_New ) and would only ever occur on error (perhaps if there was a capsule name mismatch). The fact that this can happen in R is a peculiarity of R's save/load...it seems unlikely in the destructor here but perhaps worth checking to avoid a crash.

if stream.release != NULL:
print("calling the release callback")
stream.release(stream)

free(stream)


cdef class ArrowSchemaHandle:
"""
A wrapper for an allocated ArrowSchema.
Expand All @@ -486,6 +508,26 @@ cdef class ArrowArrayHandle:
return <uintptr_t> &self.array


def _create_stream_capsule():
"""
Create PyCapsule holding a newly allocated (blank) ArrowArrayStream
"""
cdef CArrowArrayStream* stream = <CArrowArrayStream*>malloc(
cython.sizeof(CArrowArrayStream)
)
memset(stream, 0, cython.sizeof(CArrowArrayStream))

return cpython.PyCapsule_New(
stream, 'arrowarraystream', pycapsule_stream_deleter
)


cdef CArrowArrayStream* _get_stream_pointer(stream_capsule):
return <CArrowArrayStream*>cpython.PyCapsule_GetPointer(
stream_capsule, 'arrowarraystream'
)


cdef class ArrowArrayStreamHandle:
"""
A wrapper for an allocated ArrowArrayStream.
Expand Down Expand Up @@ -878,6 +920,7 @@ cdef class AdbcStatement(_AdbcHandle):
cdef:
AdbcConnection connection
CAdbcStatement statement
bint closed

def __init__(self, AdbcConnection connection) -> None:
super().__init__("(no child type)")
Expand All @@ -893,6 +936,7 @@ cdef class AdbcStatement(_AdbcHandle):
check_error(status, &c_error)

connection._open_child()
self.closed = False

def bind(self, data, schema) -> None:
"""
Expand Down Expand Up @@ -960,6 +1004,7 @@ cdef class AdbcStatement(_AdbcHandle):
cdef CAdbcError c_error = empty_error()
cdef CAdbcStatusCode status
self.connection._close_child()
self.closed = True
with self._lock:
if self.statement.private_data == NULL:
return
Expand All @@ -968,28 +1013,31 @@ cdef class AdbcStatement(_AdbcHandle):
status = AdbcStatementRelease(&self.statement, &c_error)
check_error(status, &c_error)

def execute_query(self) -> Tuple[ArrowArrayStreamHandle, int]:
def execute_query(self) -> Tuple["PyCapsule", int]:
"""
Execute the query and get the result set.

Returns
-------
ArrowArrayStreamHandle
PyCapsule holding an ArrowArrayStream
The result set.
int
The number of rows if known, else -1.
"""
cdef CAdbcError c_error = empty_error()
cdef ArrowArrayStreamHandle stream = ArrowArrayStreamHandle()
cdef int64_t rows_affected = 0

stream_capsule = _create_stream_capsule()
cdef CArrowArrayStream* stream = _get_stream_pointer(stream_capsule)

with nogil:
status = AdbcStatementExecuteQuery(
&self.statement,
&stream.stream,
stream,
&rows_affected,
&c_error)
check_error(status, &c_error)
return (stream, rows_affected)
return (stream_capsule, rows_affected)

def execute_partitions(self) -> Tuple[List[bytes], ArrowSchemaHandle, int]:
"""
Expand Down Expand Up @@ -1132,3 +1180,85 @@ cdef class AdbcStatement(_AdbcHandle):
status = AdbcStatementSetSubstraitPlan(
&self.statement, c_plan, length, &c_error)
check_error(status, &c_error)


# Implementation of an ArrowArrayStream that keeps a dependent object valid


cdef struct ArrowArrayStreamWrapper:
cpython.PyObject* parent_statement
CArrowArrayStream* parent_array_stream
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you may want this struct to own the memory here rather than just a pointer (i.e., CArrowArrayStream parent_array_stream).

bint error_set


cdef void wrapper_array_stream_release(CArrowArrayStream* array_stream) nogil noexcept:
cdef ArrowArrayStreamWrapper* data

if array_stream.private_data != NULL:
data = <ArrowArrayStreamWrapper*>array_stream.private_data
data.parent_array_stream.release(data.parent_array_stream)

with gil:
cpython.Py_DECREF(<AdbcStatement>data.parent_statement)

free(array_stream.private_data)

array_stream.release = NULL


cdef const char* wrapper_array_stream_get_last_error(CArrowArrayStream* array_stream) nogil noexcept:
cdef ArrowArrayStreamWrapper* data = <ArrowArrayStreamWrapper*>array_stream.private_data
if data.error_set:
return "AdbcStatement already closed"
return data.parent_array_stream.get_last_error(data.parent_array_stream)


cdef int wrapper_array_stream_get_schema(CArrowArrayStream* array_stream, CArrowSchema* out) nogil noexcept:
cdef ArrowArrayStreamWrapper* data = <ArrowArrayStreamWrapper*>array_stream.private_data
if (<AdbcStatement>data.parent_statement).closed:
data.error_set = True
return EIO
return data.parent_array_stream.get_schema(data.parent_array_stream, out)


cdef int wrapper_array_stream_get_next(CArrowArrayStream* array_stream, CArrowArray* out) nogil noexcept:
cdef ArrowArrayStreamWrapper* data = <ArrowArrayStreamWrapper*>(array_stream.private_data)
if (<AdbcStatement>data.parent_statement).closed:
data.error_set = True
Comment on lines +1226 to +1227
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure that you need error_set: in theory, the array stream that you are wrapping should be performing that check before doing something that might crash (although I get that right now it might not be)?

return EIO
return data.parent_array_stream.get_next(data.parent_array_stream, out)


def export_array_stream(object array_stream_capsule, AdbcStatement parent_statement):
"""
Given an ArrowArrayStream PyCapsule, return a new ArrowArrayStream capsule
wrapping the original stream and statement object.
"""
cdef CArrowArrayStream* array_stream = _get_stream_pointer(array_stream_capsule)

array_stream_capsule_exported = _create_stream_capsule()
cdef CArrowArrayStream* array_stream_exported = _get_stream_pointer(
array_stream_capsule_exported)

# move input array stream
cdef CArrowArrayStream* array_stream_moved = <CArrowArrayStream*>malloc(
cython.sizeof(CArrowArrayStream))
memset(array_stream_moved, 0, cython.sizeof(CArrowArrayStream))
memcpy(array_stream_moved, array_stream, sizeof(CArrowArrayStream))
array_stream.release = NULL

array_stream_exported.private_data = NULL
array_stream_exported.get_last_error = &wrapper_array_stream_get_last_error
array_stream_exported.get_schema = &wrapper_array_stream_get_schema
array_stream_exported.get_next = &wrapper_array_stream_get_next
array_stream_exported.release = &wrapper_array_stream_release

cdef ArrowArrayStreamWrapper* data = <ArrowArrayStreamWrapper*>malloc(
cython.sizeof(ArrowArrayStreamWrapper))
data.parent_array_stream = array_stream_moved
data.parent_statement = <cpython.PyObject*>parent_statement
cpython.Py_INCREF(parent_statement)
data.error_set = False
array_stream_exported.private_data = data

return array_stream_capsule_exported