Skip to content

Commit

Permalink
AnyRecordBatch: enum over data interface and stream interface. (#25)
Browse files Browse the repository at this point in the history
* Any array input

* Rename to AnyRecordBatch
  • Loading branch information
kylebarron authored Jun 26, 2024
1 parent ff8e9a8 commit 182e7df
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 11 deletions.
5 changes: 4 additions & 1 deletion pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ impl PyArray {
self.array.len()
}

/// Construct this object from existing Arrow data
/// Construct this object from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow data interface
/// (`__arrow_c_array__`).
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down
6 changes: 5 additions & 1 deletion pyo3-arrow/src/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl PyChunkedArray {
(self.chunks, self.field)
}

/// Convert this to a Python `arro3.core.ChunkedArray`.
/// Export this to a Python `arro3.core.ChunkedArray`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
Expand Down Expand Up @@ -102,6 +102,10 @@ impl PyChunkedArray {
self.__array__(py)
}

/// Construct this from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`). All batches will be materialized in memory.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
Expand Down
19 changes: 19 additions & 0 deletions pyo3-arrow/src/ffi/from_python/input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::input::AnyRecordBatch;
use crate::{PyRecordBatch, PyRecordBatchReader};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::{PyAny, PyResult};

impl<'a> FromPyObject<'a> for AnyRecordBatch {
fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
if ob.hasattr("__arrow_c_array__")? {
Ok(Self::RecordBatch(PyRecordBatch::extract_bound(ob)?))
} else if ob.hasattr("__arrow_c_stream__")? {
Ok(Self::Stream(PyRecordBatchReader::extract_bound(ob)?))
} else {
Err(PyValueError::new_err(
"Expected object with __arrow_c_array__ or __arrow_c_stream__ method",
))
}
}
}
1 change: 1 addition & 0 deletions pyo3-arrow/src/ffi/from_python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod array;
pub mod chunked;
pub mod ffi_stream;
pub mod field;
pub mod input;
pub mod record_batch;
pub mod record_batch_reader;
pub mod schema;
Expand Down
6 changes: 5 additions & 1 deletion pyo3-arrow/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl PyField {
Self(field)
}

/// Convert this to a Python `arro3.core.Field`.
/// Export this to a Python `arro3.core.Field`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Field"))?.call_method1(
Expand Down Expand Up @@ -74,6 +74,10 @@ impl PyField {
self.0 == other.0
}

/// Construct this from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow schema interface
/// (`__arrow_c_schema__`).
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
Expand Down
8 changes: 8 additions & 0 deletions pyo3-arrow/src/input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use crate::{PyRecordBatch, PyRecordBatchReader};

/// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
/// Arrow object as input.
pub enum AnyRecordBatch {
RecordBatch(PyRecordBatch),
Stream(PyRecordBatchReader),
}
1 change: 1 addition & 0 deletions pyo3-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod chunked;
pub mod error;
mod ffi;
mod field;
pub mod input;
mod interop;
mod record_batch;
mod record_batch_reader;
Expand Down
7 changes: 5 additions & 2 deletions pyo3-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl PyRecordBatch {
Self(batch)
}

/// Convert this to a Python `arro3.core.RecordBatch`.
/// Export this to a Python `arro3.core.RecordBatch`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
Expand Down Expand Up @@ -88,7 +88,10 @@ impl PyRecordBatch {
self.0 == other.0
}

/// Construct this object from existing Arrow data
/// Construct this from an existing Arrow RecordBatch.
///
/// It can be called on anything that exports the Arrow data interface
/// (`__arrow_c_array__`) and returns a StructArray..
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down
31 changes: 29 additions & 2 deletions pyo3-arrow/src/record_batch_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use pyo3::types::{PyCapsule, PyTuple, PyType};

use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::PyTable;

/// A Python-facing Arrow record batch reader.
///
Expand All @@ -17,6 +18,14 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
pub struct PyRecordBatchReader(pub(crate) Option<Box<dyn RecordBatchReader + Send>>);

impl PyRecordBatchReader {
/// Returns `true` if this reader has already been consumed.
pub fn closed(&self) -> bool {
self.0.is_none()
}

/// Consume this reader and convert into a [RecordBatchReader].
///
/// The reader can only be consumed once. Calling `into_reader`
pub fn into_reader(mut self) -> PyArrowResult<Box<dyn RecordBatchReader + Send>> {
let stream = self
.0
Expand All @@ -25,7 +34,21 @@ impl PyRecordBatchReader {
Ok(stream)
}

/// Convert this to a Python `arro3.core.RecordBatchReader`.
/// Consume this reader and create a [PyTable] object
pub fn into_table(mut self) -> PyArrowResult<PyTable> {
let stream = self
.0
.take()
.ok_or(PyIOError::new_err("Cannot write from closed stream."))?;
let schema = stream.schema();
let mut batches = vec![];
for batch in stream {
batches.push(batch?);
}
Ok(PyTable::new(schema, batches))
}

/// Export this to a Python `arro3.core.RecordBatchReader`.
pub fn to_python(&mut self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
Expand Down Expand Up @@ -63,12 +86,16 @@ impl PyRecordBatchReader {
PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
}

/// Construct this from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`), such as a `Table` or `RecordBatchReader`.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
}

/// Construct this object from a bare Arrow PyCapsule
/// Construct this object from a bare Arrow PyCapsule.
#[classmethod]
pub fn from_arrow_pycapsule(
_cls: &Bound<PyType>,
Expand Down
7 changes: 5 additions & 2 deletions pyo3-arrow/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl PySchema {
Self(schema)
}

/// Convert this to a Python `arro3.core.Schema`.
/// Export this to a Python `arro3.core.Schema`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
Expand Down Expand Up @@ -67,7 +67,10 @@ impl PySchema {
Ok(schema_capsule)
}

/// Construct this object from existing Arrow data
/// Construct this object from an existing Arrow object
///
/// It can be called on anything that exports the Arrow data interface
/// (`__arrow_c_array__`) and returns a struct field.
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down
9 changes: 7 additions & 2 deletions pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl PyTable {
(self.batches, self.schema)
}

/// Convert this to a Python `arro3.core.Table`.
/// Export this to a Python `arro3.core.Table`.
pub fn to_python(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Table"))?.call_method1(
Expand Down Expand Up @@ -82,7 +82,12 @@ impl PyTable {
self.batches.iter().fold(0, |acc, x| acc + x.num_rows())
}

/// Construct this object from existing Arrow data
/// Construct this object from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`) and yields a StructArray for each item. This Table will materialize
/// all items from the iterator in memory at once. Use RecordBatchReader if you don't wish to
/// materialize all batches in memory at once.
///
/// Args:
/// input: Arrow array to use for constructing this object
Expand Down

0 comments on commit 182e7df

Please sign in to comment.