Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnyRecordBatch: enum over data interface and stream interface. #25

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ impl PyArray {
self.array.len()
}

/// Construct this object from existing Arrow data
/// Construct this object from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow data interface
/// (`__arrow_c_array__`).
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down
6 changes: 5 additions & 1 deletion pyo3-arrow/src/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl PyChunkedArray {
(self.chunks, self.field)
}

/// Convert this to a Python `arro3.core.ChunkedArray`.
/// Export this to a Python `arro3.core.ChunkedArray`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
Expand Down Expand Up @@ -102,6 +102,10 @@ impl PyChunkedArray {
self.__array__(py)
}

/// Construct this from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`). All batches will be materialized in memory.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
Expand Down
19 changes: 19 additions & 0 deletions pyo3-arrow/src/ffi/from_python/input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::input::AnyRecordBatch;
use crate::{PyRecordBatch, PyRecordBatchReader};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::{PyAny, PyResult};

impl<'a> FromPyObject<'a> for AnyRecordBatch {
    /// Extract an [AnyRecordBatch] from any Python object implementing one of
    /// the Arrow PyCapsule protocols.
    ///
    /// The single-batch data interface (`__arrow_c_array__`) takes precedence
    /// over the stream interface (`__arrow_c_stream__`).
    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
        // Guard-clause style: try each protocol in priority order and return
        // as soon as one matches.
        if ob.hasattr("__arrow_c_array__")? {
            return Ok(Self::RecordBatch(PyRecordBatch::extract_bound(ob)?));
        }
        if ob.hasattr("__arrow_c_stream__")? {
            return Ok(Self::Stream(PyRecordBatchReader::extract_bound(ob)?));
        }
        // Neither protocol is exposed by this object.
        Err(PyValueError::new_err(
            "Expected object with __arrow_c_array__ or __arrow_c_stream__ method",
        ))
    }
}
1 change: 1 addition & 0 deletions pyo3-arrow/src/ffi/from_python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod array;
pub mod chunked;
pub mod ffi_stream;
pub mod field;
pub mod input;
pub mod record_batch;
pub mod record_batch_reader;
pub mod schema;
Expand Down
6 changes: 5 additions & 1 deletion pyo3-arrow/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl PyField {
Self(field)
}

/// Convert this to a Python `arro3.core.Field`.
/// Export this to a Python `arro3.core.Field`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Field"))?.call_method1(
Expand Down Expand Up @@ -74,6 +74,10 @@ impl PyField {
self.0 == other.0
}

/// Construct this from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow schema interface
/// (`__arrow_c_schema__`).
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
Expand Down
8 changes: 8 additions & 0 deletions pyo3-arrow/src/input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use crate::{PyRecordBatch, PyRecordBatchReader};

/// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
/// Arrow object as input.
pub enum AnyRecordBatch {
    /// A single, fully-materialized record batch, imported via the Arrow data
    /// interface (`__arrow_c_array__`).
    RecordBatch(PyRecordBatch),
    /// A stream of record batches, imported via the Arrow stream interface
    /// (`__arrow_c_stream__`).
    Stream(PyRecordBatchReader),
}
1 change: 1 addition & 0 deletions pyo3-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod chunked;
pub mod error;
mod ffi;
mod field;
pub mod input;
mod interop;
mod record_batch;
mod record_batch_reader;
Expand Down
7 changes: 5 additions & 2 deletions pyo3-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl PyRecordBatch {
Self(batch)
}

/// Convert this to a Python `arro3.core.RecordBatch`.
/// Export this to a Python `arro3.core.RecordBatch`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
Expand Down Expand Up @@ -88,7 +88,10 @@ impl PyRecordBatch {
self.0 == other.0
}

/// Construct this object from existing Arrow data
/// Construct this from an existing Arrow RecordBatch.
///
/// It can be called on anything that exports the Arrow data interface
/// (`__arrow_c_array__`) and returns a StructArray.
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down
31 changes: 29 additions & 2 deletions pyo3-arrow/src/record_batch_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use pyo3::types::{PyCapsule, PyTuple, PyType};

use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::PyTable;

/// A Python-facing Arrow record batch reader.
///
Expand All @@ -17,6 +18,14 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
pub struct PyRecordBatchReader(pub(crate) Option<Box<dyn RecordBatchReader + Send>>);

impl PyRecordBatchReader {
/// Returns `true` if this reader has already been consumed.
///
/// The inner stream is stored as an `Option` and is `take`n by the consuming
/// methods (e.g. `into_table`), after which this returns `true` and further
/// consumption attempts will error.
pub fn closed(&self) -> bool {
    self.0.is_none()
}

/// Consume this reader and convert into a [RecordBatchReader].
///
/// The reader can only be consumed once. Calling `into_reader` on a reader
/// that has already been consumed returns an error.
pub fn into_reader(mut self) -> PyArrowResult<Box<dyn RecordBatchReader + Send>> {
let stream = self
.0
Expand All @@ -25,7 +34,21 @@ impl PyRecordBatchReader {
Ok(stream)
}

/// Convert this to a Python `arro3.core.RecordBatchReader`.
/// Consume this reader and create a [PyTable] object.
///
/// All batches from the underlying stream are materialized in memory at
/// once; the reader can only be consumed a single time.
///
/// # Errors
///
/// Returns an error if the reader has already been consumed, or if reading
/// any batch from the stream fails.
pub fn into_table(mut self) -> PyArrowResult<PyTable> {
    let stream = self
        .0
        .take()
        // `ok_or_else` defers constructing the error to the failure path.
        // Message fixed: this is a read/consume operation, not a write.
        .ok_or_else(|| PyIOError::new_err("Cannot read from closed stream."))?;
    // Capture the schema before the reader is consumed by iteration.
    let schema = stream.schema();
    // Fallible collect short-circuits on the first failed batch.
    let batches = stream.collect::<Result<Vec<_>, _>>()?;
    Ok(PyTable::new(schema, batches))
}

/// Export this to a Python `arro3.core.RecordBatchReader`.
pub fn to_python(&mut self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
Expand Down Expand Up @@ -63,12 +86,16 @@ impl PyRecordBatchReader {
PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
}

/// Construct this from an existing Arrow object.
///
/// Accepts any Python object that exports the Arrow stream interface
/// (`__arrow_c_stream__`), such as a `Table` or `RecordBatchReader`.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
    // Delegate to this type's FromPyObject implementation.
    input.extract::<Self>()
}

/// Construct this object from a bare Arrow PyCapsule
/// Construct this object from a bare Arrow PyCapsule.
#[classmethod]
pub fn from_arrow_pycapsule(
_cls: &Bound<PyType>,
Expand Down
7 changes: 5 additions & 2 deletions pyo3-arrow/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl PySchema {
Self(schema)
}

/// Convert this to a Python `arro3.core.Schema`.
/// Export this to a Python `arro3.core.Schema`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
Expand Down Expand Up @@ -67,7 +67,10 @@ impl PySchema {
Ok(schema_capsule)
}

/// Construct this object from existing Arrow data
/// Construct this object from an existing Arrow object
///
/// It can be called on anything that exports the Arrow schema interface
/// (`__arrow_c_schema__`).
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down
9 changes: 7 additions & 2 deletions pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl PyTable {
(self.batches, self.schema)
}

/// Convert this to a Python `arro3.core.Table`.
/// Export this to a Python `arro3.core.Table`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Table"))?.call_method1(
Expand Down Expand Up @@ -82,7 +82,12 @@ impl PyTable {
self.batches.iter().fold(0, |acc, x| acc + x.num_rows())
}

/// Construct this object from existing Arrow data
/// Construct this object from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`) and yields a StructArray for each item. This Table will materialize
/// all items from the iterator in memory at once. Use RecordBatchReader if you don't wish to
/// materialize all batches in memory at once.
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down