Skip to content

Commit

Permalink
Simpler constructors (#39)
Browse files · Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Jul 1, 2024
1 parent f3a1108 commit 43fbf7c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 37 deletions.
16 changes: 13 additions & 3 deletions pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ffi::CString;
use std::sync::Arc;

use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::{Array, ArrayRef};
use arrow_array::{make_array, Array, ArrayRef};
use arrow_schema::{Field, FieldRef};
use pyo3::intern;
use pyo3::prelude::*;
Expand All @@ -28,6 +28,17 @@ impl PyArray {
Self { array, field }
}

/// Create a new PyArray from a concrete [Array] implementation, inferring its data type
/// automatically.
pub fn from_array<A: Array>(array: A) -> Self {
    // Erase the concrete array type into a dynamic ArrayRef, then delegate.
    Self::from_array_ref(make_array(array.into_data()))
}

/// Create a new PyArray from an [ArrayRef], wrapping it in an anonymous, nullable
/// [Field] derived from the array's own data type.
pub fn from_array_ref(array: ArrayRef) -> Self {
    let data_type = array.data_type().clone();
    Self::new(array, Arc::new(Field::new("", data_type, true)))
}

/// Access the underlying [ArrayRef].
pub fn array(&self) -> &ArrayRef {
&self.array
Expand Down Expand Up @@ -74,8 +85,7 @@ impl PyArray {

impl From<ArrayRef> for PyArray {
fn from(value: ArrayRef) -> Self {
let field = Field::new("", value.data_type().clone(), true);
Self::new(value, Arc::new(field))
Self::from_array_ref(value)
}
}

Expand Down
52 changes: 33 additions & 19 deletions pyo3-arrow/src/chunked.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::ffi::CString;
use std::sync::Arc;

use arrow_array::{Array, ArrayRef};
use arrow_array::{make_array, Array, ArrayRef};
use arrow_schema::{ArrowError, Field, FieldRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::intern;
Expand Down Expand Up @@ -30,6 +30,35 @@ impl PyChunkedArray {
Self { chunks, field }
}

/// Create a new PyChunkedArray from a slice of [Array]s, inferring their shared data
/// type automatically. Errors if the chunks' data types do not all match.
pub fn from_arrays<A: Array>(chunks: &[A]) -> PyArrowResult<Self> {
    // Erase each concrete array into a dynamic ArrayRef before delegating.
    let mut arrays = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        arrays.push(make_array(chunk.to_data()));
    }
    Self::from_array_refs(arrays)
}

/// Create a new PyChunkedArray from a vec of [ArrayRef]s, inferring their data type
/// automatically.
///
/// Errors if `chunks` is empty (no data type to infer) or if the chunks do not all
/// share the same data type.
pub fn from_array_refs(chunks: Vec<ArrayRef>) -> PyArrowResult<Self> {
    // The data type comes from the first chunk; an empty vec has nothing to infer from.
    let data_type = match chunks.first() {
        Some(first) => first.data_type().clone(),
        None => {
            return Err(ArrowError::SchemaError(
                "Cannot infer data type from empty Vec<ArrayRef>".to_string(),
            )
            .into())
        }
    };

    // Every remaining chunk must agree with the first chunk's data type.
    if chunks.iter().any(|chunk| chunk.data_type() != &data_type) {
        return Err(ArrowError::SchemaError("Mismatched data types".to_string()).into());
    }

    Ok(Self::new(chunks, Arc::new(Field::new("", data_type, true))))
}

/// Access the underlying chunks of this chunked array as a slice.
pub fn chunks(&self) -> &[ArrayRef] {
    self.chunks.as_slice()
}
Expand All @@ -43,7 +72,7 @@ impl PyChunkedArray {
}

/// Export this to a Python `arro3.core.ChunkedArray`.
pub fn to_arro3(&self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_arro3(&self, py: Python) -> PyResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
.getattr(intern!(py, "ChunkedArray"))?
Expand All @@ -62,7 +91,7 @@ impl PyChunkedArray {
/// Export to a pyarrow.ChunkedArray
///
/// Requires pyarrow >=14
pub fn to_pyarrow(self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
let pyarrow_mod = py.import_bound(intern!(py, "pyarrow"))?;
let pyarrow_obj = pyarrow_mod
.getattr(intern!(py, "chunked_array"))?
Expand All @@ -75,22 +104,7 @@ impl TryFrom<Vec<ArrayRef>> for PyChunkedArray {
type Error = PyArrowError;

fn try_from(value: Vec<ArrayRef>) -> Result<Self, Self::Error> {
if value.is_empty() {
return Err(ArrowError::SchemaError(
"Cannot infer data type from empty Vec<ArrayRef>".to_string(),
)
.into());
}

if !value
.windows(2)
.all(|w| w[0].data_type() == w[1].data_type())
{
return Err(ArrowError::SchemaError("Mismatched data types".to_string()).into());
}

let field = Field::new("", value.first().unwrap().data_type().clone(), true);
Ok(Self::new(value, Arc::new(field)))
Self::from_array_refs(value)
}
}

Expand Down
4 changes: 4 additions & 0 deletions pyo3-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ impl PyRecordBatch {
Self(batch)
}

pub fn into_inner(self) -> RecordBatch {
self.0
}

/// Export this to a Python `arro3.core.RecordBatch`.
pub fn to_arro3(&self, py: Python) -> PyArrowResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
Expand Down
26 changes: 16 additions & 10 deletions pyo3-arrow/src/record_batch_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use pyo3::types::{PyCapsule, PyTuple, PyType};
use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::PyTable;
use crate::{PySchema, PyTable};

/// A Python-facing Arrow record batch reader.
///
Expand All @@ -24,15 +24,10 @@ impl PyRecordBatchReader {
Self(Some(reader))
}

/// Returns `true` if this reader has already been consumed.
pub fn closed(&self) -> bool {
self.0.is_none()
}

/// Consume this reader and convert into a [RecordBatchReader].
///
/// The reader can only be consumed once. Calling `into_reader`
pub fn into_reader(mut self) -> PyArrowResult<Box<dyn RecordBatchReader + Send>> {
pub fn into_reader(mut self) -> PyResult<Box<dyn RecordBatchReader + Send>> {
let stream = self
.0
.take()
Expand All @@ -57,7 +52,7 @@ impl PyRecordBatchReader {
/// Access the [SchemaRef] of this RecordBatchReader.
///
/// If the stream has already been consumed, this method will error.
pub fn schema_ref(&self) -> PyArrowResult<SchemaRef> {
pub fn schema_ref(&self) -> PyResult<SchemaRef> {
let stream = self
.0
.as_ref()
Expand All @@ -66,7 +61,7 @@ impl PyRecordBatchReader {
}

/// Export this to a Python `arro3.core.RecordBatchReader`.
pub fn to_arro3(&mut self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_arro3(&mut self, py: Python) -> PyResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod
.getattr(intern!(py, "RecordBatchReader"))?
Expand All @@ -85,7 +80,7 @@ impl PyRecordBatchReader {
/// Export to a pyarrow.RecordBatchReader
///
/// Requires pyarrow >=15
pub fn to_pyarrow(self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
let pyarrow_mod = py.import_bound(intern!(py, "pyarrow"))?;
let record_batch_reader_class = pyarrow_mod.getattr(intern!(py, "RecordBatchReader"))?;
let pyarrow_obj = record_batch_reader_class.call_method1(
Expand Down Expand Up @@ -127,6 +122,11 @@ impl PyRecordBatchReader {
PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
}

/// Returns `true` if this reader has already been consumed.
///
/// The inner reader is stored as an `Option` and is `take`n by consuming
/// methods (e.g. `into_reader`); once it is `None`, further access errors.
pub fn closed(&self) -> bool {
    self.0.is_none()
}

/// Construct this from an existing Arrow object.
///
/// It can be called on anything that exports the Arrow stream interface
Expand All @@ -148,4 +148,10 @@ impl PyRecordBatchReader {

Ok(Self(Some(Box::new(stream_reader))))
}

/// Access the schema of this reader as a Python `arro3.core.Schema` object.
///
/// Errors if the underlying stream has already been consumed (see `schema_ref`).
#[getter]
fn schema(&self, py: Python) -> PyResult<PyObject> {
    PySchema::new(self.schema_ref()?.clone()).to_arro3(py)
}
}
4 changes: 2 additions & 2 deletions pyo3-arrow/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl PySchema {
}

/// Export this to a Python `arro3.core.Schema`.
pub fn to_arro3(&self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_arro3(&self, py: Python) -> PyResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
intern!(py, "from_arrow_pycapsule"),
Expand All @@ -41,7 +41,7 @@ impl PySchema {
/// Export to a pyarrow.Schema
///
/// Requires pyarrow >=14
pub fn to_pyarrow(self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
let pyarrow_mod = py.import_bound(intern!(py, "pyarrow"))?;
let pyarrow_obj = pyarrow_mod
.getattr(intern!(py, "schema"))?
Expand Down
18 changes: 15 additions & 3 deletions pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple, PyType};

use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::PySchema;

/// A Python-facing Arrow table.
///
Expand All @@ -38,7 +38,7 @@ impl PyTable {
}

/// Export this to a Python `arro3.core.Table`.
pub fn to_arro3(&self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_arro3(&self, py: Python) -> PyResult<PyObject> {
let arro3_mod = py.import_bound(intern!(py, "arro3.core"))?;
let core_obj = arro3_mod.getattr(intern!(py, "Table"))?.call_method1(
intern!(py, "from_arrow_pycapsule"),
Expand All @@ -55,7 +55,7 @@ impl PyTable {
/// Export to a pyarrow.Table
///
/// Requires pyarrow >=14
pub fn to_pyarrow(self, py: Python) -> PyArrowResult<PyObject> {
pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
let pyarrow_mod = py.import_bound(intern!(py, "pyarrow"))?;
let pyarrow_obj = pyarrow_mod
.getattr(intern!(py, "table"))?
Expand Down Expand Up @@ -135,4 +135,16 @@ impl PyTable {

Ok(Self::new(schema, batches))
}

/// Number of columns in this table, as reported by its schema.
#[getter]
fn num_columns(&self) -> usize {
    let fields = self.schema.fields();
    fields.len()
}

/// Access the schema of this table as a Python `arro3.core.Schema` object.
#[getter]
fn schema(&self, py: Python) -> PyResult<PyObject> {
    let py_schema = PySchema::new(self.schema.clone());
    py_schema.to_arro3(py)
}
}

0 comments on commit 43fbf7c

Please sign in to comment.