From 95ce68a4812a71dd9242f501a09234bd59f15d20 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 29 Jul 2024 18:58:47 -0400 Subject: [PATCH 1/2] Export DataType constructors --- arro3-core/python/arro3/core/_core.pyi | 259 ++++++++++++++- pyo3-arrow/src/datatypes.rs | 428 +++++++++++++------------ 2 files changed, 487 insertions(+), 200 deletions(-) diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index 6eb78b2..a31e6b1 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -1,4 +1,4 @@ -from typing import Any, Sequence +from typing import Any, Literal, Sequence import numpy as np from numpy.typing import NDArray @@ -125,6 +125,263 @@ class DataType: def from_arrow_pycapsule(cls, capsule) -> DataType: """Construct this object from a bare Arrow PyCapsule""" def bit_width(self) -> int | None: ... + #### Constructors + @classmethod + def null(cls) -> DataType: + """Create instance of null type.""" + @classmethod + def bool(cls) -> DataType: + """Create instance of boolean type.""" + @classmethod + def int8(cls) -> DataType: + """Create instance of signed int8 type.""" + @classmethod + def int16(cls) -> DataType: + """Create instance of signed int16 type.""" + @classmethod + def int32(cls) -> DataType: + """Create instance of signed int32 type.""" + @classmethod + def int64(cls) -> DataType: + """Create instance of signed int64 type.""" + @classmethod + def uint8(cls) -> DataType: + """Create instance of unsigned int8 type.""" + @classmethod + def uint16(cls) -> DataType: + """Create instance of unsigned int16 type.""" + @classmethod + def uint32(cls) -> DataType: + """Create instance of unsigned int32 type.""" + @classmethod + def uint64(cls) -> DataType: + """Create instance of unsigned int64 type.""" + @classmethod + def float16(cls) -> DataType: + """Create half-precision floating point type.""" + @classmethod + def float32(cls) -> DataType: + """Create single-precision floating point type.""" + @classmethod + def float64(cls) -> DataType: + """Create double-precision floating point type.""" + @classmethod + def time32(cls, unit: Literal["s", "ms"]) -> DataType: + """Create instance of 32-bit time (time of day) type with unit resolution. + + Args: + unit: one of `'s'` [second], or `'ms'` [millisecond] + + Returns: + _description_ + """ + @classmethod + def time64(cls, unit: Literal["us", "ns"]) -> DataType: + """Create instance of 64-bit time (time of day) type with unit resolution. + + Args: + unit: One of `'us'` [microsecond], or `'ns'` [nanosecond]. + + Returns: + _description_ + """ + @classmethod + def timestamp( + cls, unit: Literal["s", "ms", "us", "ns"], *, tz: str | None = None + ) -> DataType: + """Create instance of timestamp type with resolution and optional time zone. + + Args: + unit: one of `'s'` [second], `'ms'` [millisecond], `'us'` [microsecond], or `'ns'` [nanosecond] + tz: Time zone name. None indicates time zone naive. Defaults to None. + + Returns: + _description_ + """ + @classmethod + def date32(cls) -> DataType: + """Create instance of 32-bit date (days since UNIX epoch 1970-01-01).""" + @classmethod + def date64(cls) -> DataType: + """Create instance of 64-bit date (milliseconds since UNIX epoch 1970-01-01).""" + @classmethod + def duration(cls, unit: Literal["s", "ms", "us", "ns"]) -> DataType: + """Create instance of a duration type with unit resolution. + + Args: + unit: one of `'s'` [second], `'ms'` [millisecond], `'us'` [microsecond], or `'ns'` [nanosecond] + + Returns: + _description_ + """ + @classmethod + def month_day_nano_interval(cls) -> DataType: + """ + Create instance of an interval type representing months, days and nanoseconds + between two dates. + """ + @classmethod + def binary(cls, length: int | None = None) -> DataType: + """Create variable-length or fixed size binary type. + + Args: + length: If length is `None` then return a variable length binary type. If length is provided, then return a fixed size binary type of width `length`. Defaults to None. + + Returns: + _description_ + """ + @classmethod + def string(cls) -> DataType: + """Create UTF8 variable-length string type.""" + @classmethod + def utf8(cls) -> DataType: + """Alias for string().""" + @classmethod + def large_binary(cls) -> DataType: + """Create large variable-length binary type.""" + @classmethod + def large_string(cls) -> DataType: + """Create large UTF8 variable-length string type.""" + @classmethod + def large_utf8(cls) -> DataType: + """Alias for large_string().""" + @classmethod + def binary_view(cls) -> DataType: + """Create a variable-length binary view type.""" + @classmethod + def string_view(cls) -> DataType: + """Create UTF8 variable-length string view type.""" + @classmethod + def decimal128(cls, precision: int, scale: int) -> DataType: + """Create decimal type with precision and scale and 128-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled integer. The + precision is the number of significant digits that the decimal type can + represent; the scale is the number of digits after the decimal point (note the + scale can be negative). + + As an example, `decimal128(7, 3)` can exactly represent the numbers 1234.567 and + -1234.567 (encoded internally as the 128-bit integers 1234567 and -1234567, + respectively), but neither 12345.67 nor 123.4567. + + `decimal128(5, -3)` can exactly represent the number 12345000 (encoded + internally as the 128-bit integer 12345), but neither 123450000 nor 1234500. + + If you need a precision higher than 38 significant digits, consider using + `decimal256`. + + Args: + precision: Must be between 1 and 38 scale: _description_ + """ + @classmethod + def decimal256(cls, precision: int, scale: int) -> DataType: + """Create decimal type with precision and scale and 256-bit width.""" + @classmethod + def list(cls, value_type: ArrowSchemaExportable, list_size: int | None) -> DataType: + """Create ListType instance from child data type or field. + + Args: + value_type: _description_ + list_size: If length is `None` then return a variable length list type. If length is provided then return a fixed size list type. + + Returns: + _description_ + """ + @classmethod + def large_list(cls, value_type: ArrowSchemaExportable) -> DataType: + """Create LargeListType instance from child data type or field. + + This data type may not be supported by all Arrow implementations. Unless you + need to represent data larger than 2**31 elements, you should prefer `list()`. + + Args: + value_type: _description_ + + Returns: + _description_ + """ + @classmethod + def list_view(cls, value_type: ArrowSchemaExportable) -> DataType: + """ + Create ListViewType instance from child data type or field. + + This data type may not be supported by all Arrow implementations because it is + an alternative to the ListType. + + """ + @classmethod + def large_list_view(cls, value_type: ArrowSchemaExportable) -> DataType: + """Create LargeListViewType instance from child data type or field. + + This data type may not be supported by all Arrow implementations because it is + an alternative to the ListType. + + Args: + value_type: _description_ + + Returns: + _description_ + """ + + @classmethod + def map( + cls, + key_type: ArrowSchemaExportable, + item_type: ArrowSchemaExportable, + keys_sorted: bool, + ) -> DataType: + """Create MapType instance from key and item data types or fields. + + Args: + key_type: _description_ + item_type: _description_ + keys_sorted: _description_ + + Returns: + _description_ + """ + + @classmethod + def struct(cls, fields: Sequence[ArrowSchemaExportable]) -> DataType: + """Create StructType instance from fields. + + A struct is a nested type parameterized by an ordered sequence of types (which + can all be distinct), called its fields. + + Args: + fields: Each field must have a UTF8-encoded name, and these field names are part of the type metadata. + + Returns: + _description_ + """ + + @classmethod + def dictionary( + cls, index_type: ArrowSchemaExportable, value_type: ArrowSchemaExportable + ) -> DataType: + """Dictionary (categorical, or simply encoded) type. + + Args: + index_type: _description_ + value_type: _description_ + + Returns: + _description_ + """ + + @classmethod + def run_end_encoded( + cls, run_end_type: ArrowSchemaExportable, value_type: ArrowSchemaExportable + ) -> DataType: + """Create RunEndEncodedType from run-end and value types. + + Args: + run_end_type: The integer type of the run_ends array. Must be `'int16'`, `'int32'`, or `'int64'`. + value_type: The type of the values array. + + Returns: + _description_ + """ class Field: def __init__( diff --git a/pyo3-arrow/src/datatypes.rs b/pyo3-arrow/src/datatypes.rs index bab97b2..73ac550 100644 --- a/pyo3-arrow/src/datatypes.rs +++ b/pyo3-arrow/src/datatypes.rs @@ -1,7 +1,8 @@ use std::fmt::Display; +use std::sync::Arc; use arrow::datatypes::DataType; -use arrow_schema::TimeUnit; +use arrow_schema::{Field, IntervalUnit, TimeUnit}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::intern; use pyo3::prelude::*; @@ -11,12 +12,12 @@ use crate::error::PyArrowResult; use crate::ffi::from_python::utils::import_schema_pycapsule; use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema; use crate::ffi::to_schema_pycapsule; +use crate::PyField; -#[allow(dead_code)] -pub struct PyTimeUnit(arrow_schema::TimeUnit); +struct PyTimeUnit(arrow_schema::TimeUnit); impl<'a> FromPyObject<'a> for PyTimeUnit { - fn extract(ob: &'a PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { let s: String = ob.extract()?; match s.to_lowercase().as_str() { "s" => Ok(Self(TimeUnit::Second)), @@ -150,199 +151,228 @@ impl PyDataType { self.0.primitive_width() } - // TODO: decide whether to make this public - - // #[classmethod] - // fn null(_: &Bound) -> Self { - // Self(DataType::Null) - // } - - // #[classmethod] - // fn bool(_: &Bound) -> Self { - // Self(DataType::Boolean) - // } - - // #[classmethod] - // fn int8(_: &Bound) -> Self { - // Self(DataType::Int8) - // } - - // #[classmethod] - // fn int16(_: &Bound) -> Self { - // Self(DataType::Int16) - // } - - // #[classmethod] - // fn int32(_: &Bound) -> Self { - // Self(DataType::Int32) - // } - - // #[classmethod] - // fn int64(_: &Bound) -> Self { - // Self(DataType::Int64) - // } - - // #[classmethod] - // fn uint8(_: &Bound) -> Self { - // Self(DataType::UInt8) - // } - - // #[classmethod] - // fn uint16(_: &Bound) -> Self { - // Self(DataType::UInt16) - // } - - // #[classmethod] - // fn uint32(_: &Bound) -> Self { - // Self(DataType::UInt32) - // } - - // #[classmethod] - // fn uint64(_: &Bound) -> Self { - // Self(DataType::UInt64) - // } - - // #[classmethod] - // fn float16(_: &Bound) -> Self { - // Self(DataType::Float16) - // } - - // #[classmethod] - // fn float32(_: &Bound) -> Self { - // Self(DataType::Float32) - // } - - // #[classmethod] - // fn float64(_: &Bound) -> Self { - // Self(DataType::Float64) - // } - - // #[classmethod] - // fn time32(_: &Bound, unit: PyTimeUnit) -> PyArrowResult { - // if unit.0 == TimeUnit::Microsecond || unit.0 == TimeUnit::Nanosecond { - // return Err(PyValueError::new_err("Unexpected timeunit for time32").into()); - // } - - // Ok(Self(DataType::Time32(unit.0))) - // } - - // #[classmethod] - // fn time64(_: &Bound, unit: PyTimeUnit) -> PyArrowResult { - // if unit.0 == TimeUnit::Second || unit.0 == TimeUnit::Millisecond { - // return Err(PyValueError::new_err("Unexpected timeunit for time64").into()); - // } - - // Ok(Self(DataType::Time64(unit.0))) - // } - - // #[classmethod] - // fn timestamp(_: &Bound, unit: PyTimeUnit, tz: Option) -> Self { - // Self(DataType::Timestamp(unit.0, tz.map(|s| s.into()))) - // } - - // #[classmethod] - // fn date32(_: &Bound) -> Self { - // Self(DataType::Date32) - // } - - // #[classmethod] - // fn date64(_: &Bound) -> Self { - // Self(DataType::Date64) - // } - - // #[classmethod] - // fn duration(_: &Bound, unit: PyTimeUnit) -> Self { - // Self(DataType::Duration(unit.0)) - // } - - // #[classmethod] - // fn month_day_nano_interval(_: &Bound) -> Self { - // Self(DataType::Interval(IntervalUnit::MonthDayNano)) - // } - - // #[classmethod] - // fn binary(_: &Bound) -> Self { - // Self(DataType::Binary) - // } - - // #[classmethod] - // fn string(_: &Bound) -> Self { - // Self(DataType::Utf8) - // } - - // #[classmethod] - // fn utf8(_: &Bound) -> Self { - // Self(DataType::Utf8) - // } - - // #[classmethod] - // fn large_binary(_: &Bound) -> Self { - // Self(DataType::LargeBinary) - // } - - // #[classmethod] - // fn large_string(_: &Bound) -> Self { - // Self(DataType::LargeUtf8) - // } - - // #[classmethod] - // fn large_utf8(_: &Bound) -> Self { - // Self(DataType::LargeUtf8) - // } - - // #[classmethod] - // fn binary_view(_: &Bound) -> Self { - // Self(DataType::BinaryView) - // } - - // #[classmethod] - // fn string_view(_: &Bound) -> Self { - // Self(DataType::Utf8View) - // } - - // #[classmethod] - // fn decimal128(_: &Bound, precision: u8, scale: i8) -> Self { - // Self(DataType::Decimal128(precision, scale)) - // } - - // #[classmethod] - // fn decimal256(_: &Bound, precision: u8, scale: i8) -> Self { - // Self(DataType::Decimal256(precision, scale)) - // } - - // #[classmethod] - // fn list(_: &Bound, value_type: PyField, list_size: Option) -> Self { - // if let Some(list_size) = list_size { - // Self(DataType::FixedSizeList(value_type.into(), list_size)) - // } else { - // Self(DataType::List(value_type.into())) - // } - // } - - // #[classmethod] - // fn large_list(_: &Bound, value_type: PyField) -> Self { - // Self(DataType::LargeList(value_type.into())) - // } - - // #[classmethod] - // fn list_view(_: &Bound, value_type: PyField) -> Self { - // Self(DataType::ListView(value_type.into())) - // } - - // #[classmethod] - // fn large_list_view(_: &Bound, value_type: PyField) -> Self { - // Self(DataType::LargeListView(value_type.into())) - // } - - // TODO: fix this. - // #[classmethod] - // fn map(_: &PyType, key_type: PyField, item_type: PyField, keys_sorted: bool) -> Self { - // let field = Field::new( - // "entries", - // DataType::Struct(vec![Arc::new(key_type.0), Arc::new(item_type.0)].into()), - // true, - // ); - // // ::new_struct("entries", , true); - // Self(DataType::Map(field.into(), keys_sorted)) - // } + #[classmethod] + fn null(_: &Bound) -> Self { + Self(DataType::Null) + } + + #[classmethod] + fn bool(_: &Bound) -> Self { + Self(DataType::Boolean) + } + + #[classmethod] + fn int8(_: &Bound) -> Self { + Self(DataType::Int8) + } + + #[classmethod] + fn int16(_: &Bound) -> Self { + Self(DataType::Int16) + } + + #[classmethod] + fn int32(_: &Bound) -> Self { + Self(DataType::Int32) + } + + #[classmethod] + fn int64(_: &Bound) -> Self { + Self(DataType::Int64) + } + + #[classmethod] + fn uint8(_: &Bound) -> Self { + Self(DataType::UInt8) + } + + #[classmethod] + fn uint16(_: &Bound) -> Self { + Self(DataType::UInt16) + } + + #[classmethod] + fn uint32(_: &Bound) -> Self { + Self(DataType::UInt32) + } + + #[classmethod] + fn uint64(_: &Bound) -> Self { + Self(DataType::UInt64) + } + + #[classmethod] + fn float16(_: &Bound) -> Self { + Self(DataType::Float16) + } + + #[classmethod] + fn float32(_: &Bound) -> Self { + Self(DataType::Float32) + } + + #[classmethod] + fn float64(_: &Bound) -> Self { + Self(DataType::Float64) + } + + #[classmethod] + fn time32(_: &Bound, unit: PyTimeUnit) -> PyArrowResult { + if unit.0 == TimeUnit::Microsecond || unit.0 == TimeUnit::Nanosecond { + return Err(PyValueError::new_err("Unexpected timeunit for time32").into()); + } + + Ok(Self(DataType::Time32(unit.0))) + } + + #[classmethod] + fn time64(_: &Bound, unit: PyTimeUnit) -> PyArrowResult { + if unit.0 == TimeUnit::Second || unit.0 == TimeUnit::Millisecond { + return Err(PyValueError::new_err("Unexpected timeunit for time64").into()); + } + + Ok(Self(DataType::Time64(unit.0))) + } + + #[classmethod] + #[pyo3(signature = (unit, *, tz=None))] + fn timestamp(_: &Bound, unit: PyTimeUnit, tz: Option) -> Self { + Self(DataType::Timestamp(unit.0, tz.map(|s| s.into()))) + } + + #[classmethod] + fn date32(_: &Bound) -> Self { + Self(DataType::Date32) + } + + #[classmethod] + fn date64(_: &Bound) -> Self { + Self(DataType::Date64) + } + + #[classmethod] + fn duration(_: &Bound, unit: PyTimeUnit) -> Self { + Self(DataType::Duration(unit.0)) + } + + #[classmethod] + fn month_day_nano_interval(_: &Bound) -> Self { + Self(DataType::Interval(IntervalUnit::MonthDayNano)) + } + + #[classmethod] + fn binary(_: &Bound, length: Option) -> Self { + if let Some(length) = length { + Self(DataType::FixedSizeBinary(length)) + } else { + Self(DataType::Binary) + } + } + + #[classmethod] + fn string(_: &Bound) -> Self { + Self(DataType::Utf8) + } + + #[classmethod] + fn utf8(_: &Bound) -> Self { + Self(DataType::Utf8) + } + + #[classmethod] + fn large_binary(_: &Bound) -> Self { + Self(DataType::LargeBinary) + } + + #[classmethod] + fn large_string(_: &Bound) -> Self { + Self(DataType::LargeUtf8) + } + + #[classmethod] + fn large_utf8(_: &Bound) -> Self { + Self(DataType::LargeUtf8) + } + + #[classmethod] + fn binary_view(_: &Bound) -> Self { + Self(DataType::BinaryView) + } + + #[classmethod] + fn string_view(_: &Bound) -> Self { + Self(DataType::Utf8View) + } + + #[classmethod] + fn decimal128(_: &Bound, precision: u8, scale: i8) -> Self { + Self(DataType::Decimal128(precision, scale)) + } + + #[classmethod] + fn decimal256(_: &Bound, precision: u8, scale: i8) -> Self { + Self(DataType::Decimal256(precision, scale)) + } + + #[classmethod] + fn list(_: &Bound, value_type: PyField, list_size: Option) -> Self { + if let Some(list_size) = list_size { + Self(DataType::FixedSizeList(value_type.into(), list_size)) + } else { + Self(DataType::List(value_type.into())) + } + } + + #[classmethod] + fn large_list(_: &Bound, value_type: PyField) -> Self { + Self(DataType::LargeList(value_type.into())) + } + + #[classmethod] + fn list_view(_: &Bound, value_type: PyField) -> Self { + Self(DataType::ListView(value_type.into())) + } + + #[classmethod] + fn large_list_view(_: &Bound, value_type: PyField) -> Self { + Self(DataType::LargeListView(value_type.into())) + } + + #[classmethod] + fn map(_: &Bound, key_type: PyField, item_type: PyField, keys_sorted: bool) -> Self { + // Note: copied from source of `Field::new_map` + // https://github.com/apache/arrow-rs/blob/bf9ce475df82d362631099d491d3454d64d50217/arrow-schema/src/field.rs#L251-L258 + let data_type = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(vec![key_type.into_inner(), item_type.into_inner()].into()), + false, // The inner map field is always non-nullable (arrow-rs #1697), + )), + keys_sorted, + ); + Self(data_type) + } + + #[classmethod] + fn r#struct(_: &Bound, fields: Vec) -> Self { + Self(DataType::Struct( + fields.into_iter().map(|field| field.into_inner()).collect(), + )) + } + + #[classmethod] + fn dictionary(_: &Bound, index_type: PyField, value_type: PyField) -> Self { + Self(DataType::Dictionary( + Box::new(index_type.into_inner().data_type().clone()), + Box::new(value_type.into_inner().data_type().clone()), + )) + } + + #[classmethod] + fn run_end_encoded(_: &Bound, run_end_type: PyField, value_type: PyField) -> Self { + Self(DataType::RunEndEncoded( + run_end_type.into_inner(), + value_type.into_inner(), + )) + } } From 4ba24bef409a9ec5fbf06278cb89f9688ecf87ba Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 29 Jul 2024 20:09:20 -0400 Subject: [PATCH 2/2] Added DataType methods --- DEVELOP.md | 7 +- arro3-core/python/arro3/core/_core.pyi | 145 ++++++++++++ docs/api/core/datatype.md | 8 + mkdocs.yml | 1 + pyo3-arrow/src/datatypes.rs | 295 ++++++++++++++++++++++++- 5 files changed, 447 insertions(+), 9 deletions(-) create mode 100644 docs/api/core/datatype.md diff --git a/DEVELOP.md b/DEVELOP.md index c681f62..d01772a 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -5,9 +5,8 @@ rm -rf .venv poetry install # Note: need to install core first because others depend on core -poetry run maturin build -m arro3-core/Cargo.toml -o dist -poetry run maturin build -m arro3-compute/Cargo.toml -o dist -poetry run maturin build -m arro3-io/Cargo.toml -o dist -poetry run pip install dist/* +poetry run maturin develop -m arro3-core/Cargo.toml +poetry run maturin develop -m arro3-compute/Cargo.toml +poetry run maturin develop -m arro3-io/Cargo.toml poetry run mkdocs serve ``` diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index a31e6b1..14107e8 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -124,8 +124,43 @@ class DataType: @classmethod def from_arrow_pycapsule(cls, capsule) -> DataType: """Construct this object from a bare Arrow PyCapsule""" + @property def bit_width(self) -> int | None: ... + def equals( + self, other: ArrowSchemaExportable, *, check_metadata: bool = False + ) -> bool: + """Return true if type is equivalent to passed value. + + Args: + other: _description_ + check_metadata: Whether nested Field metadata equality should be checked as well. Defaults to False. + + Returns: + _description_ + """ + @property + def list_size(self) -> int | None: + """The size of the list in the case of fixed size lists. + + This will return `None` if the data type is not a fixed size list. + + Examples: + + ```py + from arro3.core import DataType + DataType.list(DataType.int32(), 2).list_size + # 2 + ``` + + Returns: + _description_ + """ + @property + def num_fields(self) -> int: + """The number of child fields.""" + ################# #### Constructors + ################# @classmethod def null(cls) -> DataType: """Create instance of null type.""" @@ -383,6 +418,116 @@ class DataType: _description_ """ + ################## + #### Type Checking + ################## + @staticmethod + def is_boolean(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_integer(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_signed_integer(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_unsigned_integer(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int8(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int16(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_int64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint8(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint16(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_uint64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_floating(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_float16(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_float32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_float64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_decimal(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_decimal128(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_decimal256(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_list(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_list(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_fixed_size_list(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_list_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_list_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_struct(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_union(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_nested(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_run_end_encoded(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_temporal(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_timestamp(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_date(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_date32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_date64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_time(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_time32(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_time64(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_duration(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_interval(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_null(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_binary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_unicode(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_string(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_binary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_unicode(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_large_string(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_binary_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_string_view(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_fixed_size_binary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_map(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_dictionary(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_primitive(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_numeric(t: ArrowSchemaExportable) -> bool: ... + @staticmethod + def is_dictionary_key_type(t: ArrowSchemaExportable) -> bool: ... + class Field: def __init__( self, diff --git a/docs/api/core/datatype.md b/docs/api/core/datatype.md new file mode 100644 index 0000000..672baed --- /dev/null +++ b/docs/api/core/datatype.md @@ -0,0 +1,8 @@ +# arro3.core.DataType + +::: arro3.core.DataType + options: + filters: + - "!^_" + - "^__arrow" + show_if_no_docstring: true diff --git a/mkdocs.yml b/mkdocs.yml index 91fa4a7..074cbd7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ nav: - api/core/record-batch.md - api/core/schema.md - api/core/table.md + - api/core/datatype.md - api/core/types.md - api/compute.md - api/io.md diff --git a/pyo3-arrow/src/datatypes.rs b/pyo3-arrow/src/datatypes.rs index 73ac550..a5799b4 100644 --- a/pyo3-arrow/src/datatypes.rs +++ b/pyo3-arrow/src/datatypes.rs @@ -118,8 +118,8 @@ impl PyDataType { to_schema_pycapsule(py, &self.0) } - pub fn __eq__(&self, other: &PyDataType) -> bool { - self.0 == other.0 + pub fn __eq__(&self, other: PyDataType) -> bool { + self.equals(other, false) } pub fn __repr__(&self) -> String { @@ -147,10 +147,75 @@ impl PyDataType { Ok(Self::new(data_type)) } + #[getter] pub fn bit_width(&self) -> Option { self.0.primitive_width() } + #[pyo3(signature=(other, *, check_metadata=false))] + fn equals(&self, other: PyDataType, check_metadata: bool) -> bool { + let other = other.into_inner(); + if check_metadata { + self.0 == other + } else { + self.0.equals_datatype(&other) + } + } + + #[getter] + fn list_size(&self) -> Option { + match &self.0 { + DataType::FixedSizeList(_, list_size) => Some(*list_size), + _ => None, + } + } + + #[getter] + fn num_fields(&self) -> usize { + match &self.0 { + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => 0, + DataType::List(_) + | DataType::ListView(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::LargeListView(_) => 1, + DataType::Struct(fields) => fields.len(), + DataType::Union(fields, _) => fields.len(), + // Is this accurate? + DataType::Dictionary(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => 2, + } + } + + ///////////////////// Constructors + #[classmethod] fn null(_: &Bound) -> Self { Self(DataType::Null) @@ -361,10 +426,10 @@ impl PyDataType { } #[classmethod] - fn dictionary(_: &Bound, index_type: PyField, value_type: PyField) -> Self { + fn dictionary(_: &Bound, index_type: PyDataType, value_type: PyDataType) -> Self { Self(DataType::Dictionary( - Box::new(index_type.into_inner().data_type().clone()), - Box::new(value_type.into_inner().data_type().clone()), + Box::new(index_type.into_inner()), + Box::new(value_type.into_inner()), )) } @@ -375,4 +440,224 @@ impl PyDataType { value_type.into_inner(), )) } + + ///////////////////// Type checking + + #[staticmethod] + fn is_boolean(t: PyDataType) -> bool { + t.0 == DataType::Boolean + } + + #[staticmethod] + fn is_integer(t: PyDataType) -> bool { + t.0.is_integer() + } + + #[staticmethod] + fn is_signed_integer(t: PyDataType) -> bool { + t.0.is_signed_integer() + } + + #[staticmethod] + fn is_unsigned_integer(t: PyDataType) -> bool { + t.0.is_unsigned_integer() + } + + #[staticmethod] + fn is_int8(t: PyDataType) -> bool { + t.0 == DataType::Int8 + } + #[staticmethod] + fn is_int16(t: PyDataType) -> bool { + t.0 == DataType::Int16 + } + #[staticmethod] + fn is_int32(t: PyDataType) -> bool { + t.0 == DataType::Int32 + } + #[staticmethod] + fn is_int64(t: PyDataType) -> bool { + t.0 == DataType::Int64 + } + #[staticmethod] + fn is_uint8(t: PyDataType) -> bool { + t.0 == DataType::UInt8 + } + #[staticmethod] + fn is_uint16(t: PyDataType) -> bool { + t.0 == DataType::UInt16 + } + #[staticmethod] + fn is_uint32(t: PyDataType) -> bool { + t.0 == DataType::UInt32 + } + #[staticmethod] + fn is_uint64(t: PyDataType) -> bool { + t.0 == DataType::UInt64 + } + #[staticmethod] + fn is_floating(t: PyDataType) -> bool { + t.0.is_floating() + } + #[staticmethod] + fn is_float16(t: PyDataType) -> bool { + t.0 == DataType::Float16 + } + #[staticmethod] + fn is_float32(t: PyDataType) -> bool { + t.0 == DataType::Float32 + } + #[staticmethod] + fn is_float64(t: PyDataType) -> bool { + t.0 == DataType::Float64 + } + #[staticmethod] + fn is_decimal(t: PyDataType) -> bool { + matches!(t.0, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + } + #[staticmethod] + fn is_decimal128(t: PyDataType) -> bool { + matches!(t.0, DataType::Decimal128(_, _)) + } + #[staticmethod] + fn is_decimal256(t: PyDataType) -> bool { + matches!(t.0, DataType::Decimal256(_, _)) + } + + #[staticmethod] + fn is_list(t: PyDataType) -> bool { + matches!(t.0, DataType::List(_)) + } + #[staticmethod] + fn is_large_list(t: PyDataType) -> bool { + matches!(t.0, DataType::LargeList(_)) + } + #[staticmethod] + fn is_fixed_size_list(t: PyDataType) -> bool { + matches!(t.0, DataType::FixedSizeList(_, _)) + } + #[staticmethod] + fn is_list_view(t: PyDataType) -> bool { + matches!(t.0, DataType::ListView(_)) + } + #[staticmethod] + fn is_large_list_view(t: PyDataType) -> bool { + matches!(t.0, DataType::LargeListView(_)) + } + #[staticmethod] + fn is_struct(t: PyDataType) -> bool { + matches!(t.0, DataType::Struct(_)) + } + #[staticmethod] + fn is_union(t: PyDataType) -> bool { + matches!(t.0, DataType::Union(_, _)) + } + #[staticmethod] + fn is_nested(t: PyDataType) -> bool { + t.0.is_nested() + } + #[staticmethod] + fn is_run_end_encoded(t: PyDataType) -> bool { + t.0.is_run_ends_type() + } + #[staticmethod] + fn is_temporal(t: PyDataType) -> bool { + t.0.is_temporal() + } + #[staticmethod] + fn is_timestamp(t: PyDataType) -> bool { + matches!(t.0, DataType::Timestamp(_, _)) + } + #[staticmethod] + fn is_date(t: PyDataType) -> bool { + matches!(t.0, DataType::Date32 | DataType::Date64) + } + #[staticmethod] + fn is_date32(t: PyDataType) -> bool { + t.0 == DataType::Date32 + } + #[staticmethod] + fn is_date64(t: PyDataType) -> bool { + t.0 == DataType::Date64 + } + #[staticmethod] + fn is_time(t: PyDataType) -> bool { + matches!(t.0, DataType::Time32(_) | DataType::Time64(_)) + } + #[staticmethod] + fn is_time32(t: PyDataType) -> bool { + matches!(t.0, DataType::Time32(_)) + } + #[staticmethod] + fn is_time64(t: PyDataType) -> bool { + matches!(t.0, DataType::Time64(_)) + } + #[staticmethod] + fn is_duration(t: PyDataType) -> bool { + matches!(t.0, DataType::Duration(_)) + } + #[staticmethod] + fn is_interval(t: PyDataType) -> bool { + matches!(t.0, DataType::Interval(_)) + } + #[staticmethod] + fn is_null(t: PyDataType) -> bool { + t.0 == DataType::Null + } + #[staticmethod] + fn is_binary(t: PyDataType) -> bool { + t.0 == DataType::Binary + } + #[staticmethod] + fn is_unicode(t: PyDataType) -> bool { + t.0 == DataType::Utf8 + } + #[staticmethod] + fn is_string(t: PyDataType) -> bool { + t.0 == DataType::Utf8 + } + #[staticmethod] + fn is_large_binary(t: PyDataType) -> bool { + t.0 == DataType::LargeBinary + } + #[staticmethod] + fn is_large_unicode(t: PyDataType) -> bool { + t.0 == DataType::LargeUtf8 + } + #[staticmethod] + fn is_large_string(t: PyDataType) -> bool { + t.0 == DataType::LargeUtf8 + } + #[staticmethod] + fn is_binary_view(t: PyDataType) -> bool { + t.0 == DataType::BinaryView + } + #[staticmethod] + fn is_string_view(t: PyDataType) -> bool { + t.0 == DataType::Utf8View + } + #[staticmethod] + fn is_fixed_size_binary(t: PyDataType) -> bool { + matches!(t.0, DataType::FixedSizeBinary(_)) + } + #[staticmethod] + fn is_map(t: PyDataType) -> bool { + matches!(t.0, DataType::Map(_, _)) + } + #[staticmethod] + fn is_dictionary(t: PyDataType) -> bool { + matches!(t.0, DataType::Dictionary(_, _)) + } + #[staticmethod] + fn is_primitive(t: PyDataType) -> bool { + t.0.is_primitive() + } + #[staticmethod] + fn is_numeric(t: PyDataType) -> bool { + t.0.is_numeric() + } + #[staticmethod] + fn is_dictionary_key_type(t: PyDataType) -> bool { + t.0.is_dictionary_key_type() + } }