Allow writing RecordBatch through write_* functions #65

Merged · 1 commit · Jul 25, 2024
14 changes: 8 additions & 6 deletions arro3-io/python/arro3/io/_rust.pyi
@@ -40,7 +40,7 @@ def read_csv(
     comment: str | None = None,
 ) -> RecordBatchReader: ...
 def write_csv(
-    data: ArrowStreamExportable,
+    data: ArrowStreamExportable | ArrowArrayExportable,
     file: IO[bytes] | Path | str,
     *,
     header: bool | None = None,
@@ -69,13 +69,13 @@ def read_json(
     batch_size: int | None = None,
 ) -> RecordBatchReader: ...
 def write_json(
-    data: ArrowStreamExportable,
+    data: ArrowStreamExportable | ArrowArrayExportable,
     file: IO[bytes] | Path | str,
     *,
     explicit_nulls: bool | None = None,
 ) -> None: ...
 def write_ndjson(
-    data: ArrowStreamExportable,
+    data: ArrowStreamExportable | ArrowArrayExportable,
     file: IO[bytes] | Path | str,
     *,
     explicit_nulls: bool | None = None,
@@ -85,14 +85,16 @@ def write_ndjson(
 
 def read_ipc(file: IO[bytes] | Path | str) -> RecordBatchReader: ...
 def read_ipc_stream(file: IO[bytes] | Path | str) -> RecordBatchReader: ...
-def write_ipc(data: ArrowStreamExportable, file: IO[bytes] | Path | str) -> None: ...
+def write_ipc(
+    data: ArrowStreamExportable | ArrowArrayExportable, file: IO[bytes] | Path | str
+) -> None: ...
 def write_ipc_stream(
-    data: ArrowStreamExportable, file: IO[bytes] | Path | str
+    data: ArrowStreamExportable | ArrowArrayExportable, file: IO[bytes] | Path | str
 ) -> None: ...
 
 #### Parquet
 
 def read_parquet(file: Path | str) -> RecordBatchReader: ...
 def write_parquet(
-    data: ArrowStreamExportable, file: IO[bytes] | Path | str
+    data: ArrowStreamExportable | ArrowArrayExportable, file: IO[bytes] | Path | str
 ) -> None: ...
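The user-visible effect of these stub changes: each `write_*` function now accepts any object exporting a single Arrow batch (`ArrowArrayExportable`, i.e. `__arrow_c_array__`) in addition to stream exporters (`ArrowStreamExportable`, i.e. `__arrow_c_stream__`). A minimal sketch of the new call pattern, assuming a pyarrow recent enough to implement the Arrow PyCapsule interface (14+) and illustrative file paths:

```python
import pyarrow as pa

from arro3.io import write_ipc, write_parquet

batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# New in this PR: a bare RecordBatch (ArrowArrayExportable, exported
# through __arrow_c_array__) is accepted directly.
write_parquet(batch, "single_batch.parquet")

# Unchanged: any stream exporter (ArrowStreamExportable, exported
# through __arrow_c_stream__) still works, e.g. a RecordBatchReader.
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
write_ipc(reader, "stream.arrow")
```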
3 changes: 2 additions & 1 deletion arro3-io/src/csv.rs
@@ -4,6 +4,7 @@ use arrow_csv::reader::Format;
 use arrow_csv::{ReaderBuilder, WriterBuilder};
 use pyo3::prelude::*;
 use pyo3_arrow::error::PyArrowResult;
+use pyo3_arrow::input::AnyRecordBatch;
 use pyo3_arrow::{PyRecordBatchReader, PySchema};
 
 use crate::utils::{FileReader, FileWriter};
@@ -133,7 +134,7 @@
 ))]
 #[allow(clippy::too_many_arguments)]
 pub fn write_csv(
-    data: PyRecordBatchReader,
+    data: AnyRecordBatch,
     file: FileWriter,
     header: Option<bool>,
     delimiter: Option<char>,
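On the Python side this means the CSV writer takes the batch directly while still honoring its keyword options. A small sketch under the same pyarrow assumption as above:

```python
import pyarrow as pa

from arro3.io import write_csv

batch = pa.RecordBatch.from_pydict({"name": ["a", "b"], "n": [1, 2]})

# AnyRecordBatch on the Rust side accepts the bare batch; writer
# options such as header pass through unchanged.
write_csv(batch, "out.csv", header=True)
```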
9 changes: 3 additions & 6 deletions arro3-io/src/ipc.rs
@@ -3,13 +3,13 @@ use std::io::{BufReader, BufWriter};
 use arrow_ipc::reader::{FileReaderBuilder, StreamReader};
 use pyo3::prelude::*;
 use pyo3_arrow::error::PyArrowResult;
+use pyo3_arrow::input::AnyRecordBatch;
 use pyo3_arrow::PyRecordBatchReader;
 
 use crate::utils::{FileReader, FileWriter};
 
 /// Read an Arrow IPC file to an Arrow RecordBatchReader
 #[pyfunction]
-#[allow(clippy::too_many_arguments)]
 pub fn read_ipc(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
     let builder = FileReaderBuilder::new();
     let buf_file = BufReader::new(file);
@@ -19,16 +19,14 @@ pub fn read_ipc(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
 
 /// Read an Arrow IPC Stream file to an Arrow RecordBatchReader
 #[pyfunction]
-#[allow(clippy::too_many_arguments)]
 pub fn read_ipc_stream(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
     let reader = StreamReader::try_new(file, None)?;
     Ok(PyRecordBatchReader::new(Box::new(reader)).to_arro3(py)?)
 }
 
 /// Write an Arrow Table or stream to an IPC File
 #[pyfunction]
-#[allow(clippy::too_many_arguments)]
-pub fn write_ipc(data: PyRecordBatchReader, file: FileWriter) -> PyArrowResult<()> {
+pub fn write_ipc(data: AnyRecordBatch, file: FileWriter) -> PyArrowResult<()> {
     let buf_writer = BufWriter::new(file);
     let reader = data.into_reader()?;
     let mut writer = arrow_ipc::writer::FileWriter::try_new(buf_writer, &reader.schema())?;
@@ -40,8 +38,7 @@ pub fn write_ipc(data: PyRecordBatchReader, file: FileWriter) -> PyArrowResult<()> {
 
 /// Write an Arrow Table or stream to an IPC Stream
 #[pyfunction]
-#[allow(clippy::too_many_arguments)]
-pub fn write_ipc_stream(data: PyRecordBatchReader, file: FileWriter) -> PyArrowResult<()> {
+pub fn write_ipc_stream(data: AnyRecordBatch, file: FileWriter) -> PyArrowResult<()> {
     let buf_writer = BufWriter::new(file);
     let reader = data.into_reader()?;
     let mut writer = arrow_ipc::writer::StreamWriter::try_new(buf_writer, &reader.schema())?;
7 changes: 3 additions & 4 deletions arro3-io/src/json.rs
@@ -4,6 +4,7 @@ use arrow::json::writer::{JsonArray, LineDelimited};
 use arrow::json::{ReaderBuilder, WriterBuilder};
 use pyo3::prelude::*;
 use pyo3_arrow::error::PyArrowResult;
+use pyo3_arrow::input::AnyRecordBatch;
 use pyo3_arrow::{PyRecordBatchReader, PySchema};
 
 use crate::utils::{FileReader, FileWriter};
@@ -15,7 +16,6 @@ use crate::utils::{FileReader, FileWriter};
     *,
     max_records=None,
 ))]
-#[allow(clippy::too_many_arguments)]
 pub fn infer_json_schema(
     py: Python,
     file: FileReader,
@@ -34,7 +34,6 @@ pub fn infer_json_schema(
     *,
     batch_size=None,
 ))]
-#[allow(clippy::too_many_arguments)]
 pub fn read_json(
     py: Python,
     file: FileReader,
@@ -62,7 +61,7 @@ pub fn read_json(
 ))]
 #[allow(clippy::too_many_arguments)]
 pub fn write_json(
-    data: PyRecordBatchReader,
+    data: AnyRecordBatch,
     file: FileWriter,
     explicit_nulls: Option<bool>,
 ) -> PyArrowResult<()> {
@@ -89,7 +88,7 @@ pub fn write_json(
 ))]
 #[allow(clippy::too_many_arguments)]
 pub fn write_ndjson(
-    data: PyRecordBatchReader,
+    data: AnyRecordBatch,
     file: FileWriter,
     explicit_nulls: Option<bool>,
 ) -> PyArrowResult<()> {
5 changes: 2 additions & 3 deletions arro3-io/src/parquet.rs
@@ -3,13 +3,13 @@ use parquet::arrow::ArrowWriter;
 use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
 use pyo3_arrow::error::PyArrowResult;
+use pyo3_arrow::input::AnyRecordBatch;
 use pyo3_arrow::PyRecordBatchReader;
 
 use crate::utils::{FileReader, FileWriter};
 
 /// Read a Parquet file to an Arrow RecordBatchReader
 #[pyfunction]
-#[allow(clippy::too_many_arguments)]
 pub fn read_parquet(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
     match file {
         FileReader::File(f) => {
@@ -26,8 +26,7 @@ pub fn read_parquet(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
 
 /// Write an Arrow Table or stream to a Parquet file
 #[pyfunction]
-#[allow(clippy::too_many_arguments)]
-pub fn write_parquet(data: PyRecordBatchReader, file: FileWriter) -> PyArrowResult<()> {
+pub fn write_parquet(data: AnyRecordBatch, file: FileWriter) -> PyArrowResult<()> {
     let reader = data.into_reader()?;
     let mut writer = ArrowWriter::try_new(file, reader.schema(), None).unwrap();
     for batch in reader {
42 changes: 42 additions & 0 deletions pyo3-arrow/src/input.rs
@@ -6,9 +6,12 @@
 use std::collections::HashMap;
 use std::string::FromUtf8Error;
 
+use arrow_array::{RecordBatchIterator, RecordBatchReader};
+use arrow_schema::{FieldRef, SchemaRef};
 use pyo3::prelude::*;
 
+use crate::array_reader::PyArrayReader;
+use crate::ffi::{ArrayIterator, ArrayReader};
 use crate::{PyArray, PyRecordBatch, PyRecordBatchReader};
 
 /// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
@@ -18,13 +21,52 @@ pub enum AnyRecordBatch {
     Stream(PyRecordBatchReader),
 }
 
+impl AnyRecordBatch {
+    pub fn into_reader(self) -> PyResult<Box<dyn RecordBatchReader + Send>> {
+        match self {
+            Self::RecordBatch(batch) => {
+                let batch = batch.into_inner();
+                let schema = batch.schema();
+                Ok(Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)))
+            }
+            Self::Stream(stream) => stream.into_reader(),
+        }
+    }
+
+    pub fn schema(&self) -> PyResult<SchemaRef> {
+        match self {
+            Self::RecordBatch(batch) => Ok(batch.as_ref().schema()),
+            Self::Stream(stream) => stream.schema_ref(),
+        }
+    }
+}
+
 /// An enum over [PyArray] and [PyArrayReader], used when a function accepts either
 /// Arrow object as input.
 pub enum AnyArray {
     Array(PyArray),
     Stream(PyArrayReader),
 }
 
+impl AnyArray {
+    pub fn into_reader(self) -> PyResult<Box<dyn ArrayReader + Send>> {
+        match self {
+            Self::Array(array) => {
+                let (array, field) = array.into_inner();
+                Ok(Box::new(ArrayIterator::new(vec![Ok(array)], field)))
+            }
+            Self::Stream(stream) => stream.into_reader(),
+        }
+    }
+
+    pub fn field(&self) -> PyResult<FieldRef> {
+        match self {
+            Self::Array(array) => Ok(array.field().clone()),
+            Self::Stream(stream) => stream.field_ref(),
+        }
+    }
+}
+
 #[derive(FromPyObject)]
 pub enum MetadataInput {
     String(HashMap<String, String>),