diff --git a/pyo3-arrow/src/chunked.rs b/pyo3-arrow/src/chunked.rs index f076599..9a6c2bf 100644 --- a/pyo3-arrow/src/chunked.rs +++ b/pyo3-arrow/src/chunked.rs @@ -33,7 +33,7 @@ impl PyChunkedArray { assert!( chunks .iter() - .all(|chunk| chunk.data_type() == field.data_type()), + .all(|chunk| chunk.data_type().equals_datatype(field.data_type())), "All chunks must have same data type" ); Self { chunks, field } diff --git a/pyo3-arrow/src/lib.rs b/pyo3-arrow/src/lib.rs index 2395072..57f4721 100644 --- a/pyo3-arrow/src/lib.rs +++ b/pyo3-arrow/src/lib.rs @@ -13,6 +13,7 @@ mod record_batch; mod record_batch_reader; mod schema; mod table; +mod utils; pub use array::PyArray; pub use array_reader::PyArrayReader; diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs index 5a6009e..4795b71 100644 --- a/pyo3-arrow/src/table.rs +++ b/pyo3-arrow/src/table.rs @@ -21,6 +21,7 @@ use crate::input::{ AnyArray, AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices, }; use crate::schema::display_schema; +use crate::utils::schema_equals; use crate::{PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PySchema}; /// A Python-facing Arrow table. @@ -35,9 +36,10 @@ pub struct PyTable { impl PyTable { pub fn new(batches: Vec, schema: SchemaRef) -> Self { - // TODO: allow batches to have different schema metadata? assert!( - batches.iter().all(|rb| rb.schema_ref() == &schema), + batches + .iter() + .all(|rb| schema_equals(rb.schema_ref(), &schema)), "All batches must have same schema" ); Self { schema, batches } diff --git a/pyo3-arrow/src/utils.rs b/pyo3-arrow/src/utils.rs new file mode 100644 index 0000000..a826e69 --- /dev/null +++ b/pyo3-arrow/src/utils.rs @@ -0,0 +1,17 @@ +use arrow_schema::Schema; + +/// Check whether two schemas are equal +/// +/// This allows schemas to have different top-level metadata, as well as different nested field +/// names and keys. +pub(crate) fn schema_equals(left: &Schema, right: &Schema) -> bool { + left.fields + .iter() + .zip(right.fields.iter()) + .all(|(left_field, right_field)| { + left_field.name() == right_field.name() + && left_field + .data_type() + .equals_datatype(right_field.data_type()) + }) +}