From 132c64d98cd0757ae7c07bb1bbe66e3453f81b34 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 22 Nov 2024 13:02:41 +0100 Subject: [PATCH] refactor: Implement nested row encoding / decoding (#19874) --- .../src/array/dictionary/typed_iterator.rs | 1 + .../src/array/fixed_size_list/mod.rs | 26 + crates/polars-arrow/src/offset.rs | 2 +- crates/polars-arrow/src/trusted_len.rs | 3 +- .../src/chunked_array/ops/row_encode.rs | 33 +- crates/polars-row/src/decode.rs | 301 ++++- crates/polars-row/src/encode.rs | 1164 +++++++++++------ crates/polars-row/src/fixed.rs | 55 +- crates/polars-row/src/variable.rs | 48 +- py-polars/tests/unit/test_row_encoding.py | 379 +++++- 10 files changed, 1459 insertions(+), 553 deletions(-) diff --git a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs index d7e7637bf28d..42482f1b082a 100644 --- a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs +++ b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs @@ -181,6 +181,7 @@ impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryIterTyped<'a, K, unsafe impl TrustedLen for DictionaryIterTyped<'_, K, V> {} +impl ExactSizeIterator for DictionaryIterTyped<'_, K, V> {} impl DoubleEndedIterator for DictionaryIterTyped<'_, K, V> { #[inline] fn next_back(&mut self) -> Option { diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 32267cc5a4b7..11ff72c1fc64 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -225,6 +225,32 @@ impl FixedSizeListArray { } dims } + + pub fn propagate_nulls(&self) -> Self { + let Some(validity) = self.validity() else { + return self.clone(); + }; + + let propagated_validity = if self.size == 1 { + validity.clone() + } else { + Bitmap::from_trusted_len_iter( + (0..self.size * validity.len()) + .map(|i| unsafe { validity.get_bit_unchecked(i / self.size) }), + ) + }; + + let propagated_validity = match self.values.validity() { + None => propagated_validity, + Some(val) => val & &propagated_validity, + }; + Self::new( + self.dtype().clone(), + self.length, + self.values.with_validity(Some(propagated_validity)), + self.validity.clone(), + ) + } } // must use diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs index 1b8a02bc7fde..694a058edc9e 100644 --- a/crates/polars-arrow/src/offset.rs +++ b/crates/polars-arrow/src/offset.rs @@ -514,7 +514,7 @@ impl OffsetsBuffer { /// Returns `(offset, len)` pairs. #[inline] - pub fn offset_and_length_iter(&self) -> impl Iterator + '_ { + pub fn offset_and_length_iter(&self) -> impl ExactSizeIterator + '_ { self.windows(2).map(|x| { let [l, r] = x else { unreachable!() }; let l = l.to_usize(); diff --git a/crates/polars-arrow/src/trusted_len.rs b/crates/polars-arrow/src/trusted_len.rs index 359edfd1b88c..242e8d2878a5 100644 --- a/crates/polars-arrow/src/trusted_len.rs +++ b/crates/polars-arrow/src/trusted_len.rs @@ -1,6 +1,6 @@ //! Declares [`TrustedLen`]. use std::iter::Scan; -use std::slice::Iter; +use std::slice::{Iter, IterMut}; /// An iterator of known, fixed size. /// @@ -14,6 +14,7 @@ use std::slice::Iter; pub unsafe trait TrustedLen: Iterator {} unsafe impl TrustedLen for Iter<'_, T> {} +unsafe impl TrustedLen for IterMut<'_, T> {} unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Copied where diff --git a/crates/polars-core/src/chunked_array/ops/row_encode.rs b/crates/polars-core/src/chunked_array/ops/row_encode.rs index 31b76357c470..2b683bef534b 100644 --- a/crates/polars-core/src/chunked_array/ops/row_encode.rs +++ b/crates/polars-core/src/chunked_array/ops/row_encode.rs @@ -144,20 +144,8 @@ pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { let arr = _get_rows_encoded_compat_array(by)?; let field = EncodingField::new_unsorted(); - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for arr in arr.values() { - cols.push(arr.clone() as ArrayRef); - fields.push(field) - } - }, - _ => { - cols.push(arr); - fields.push(field) - }, - } + cols.push(arr); + fields.push(field); } Ok(convert_columns(num_rows, &cols, &fields)) } @@ -187,21 +175,8 @@ pub fn _get_rows_encoded( nulls_last: *null_last, no_order: false, }; - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - let arr = arr.propagate_nulls(); - for value_arr in arr.values() { - cols.push(value_arr.clone() as ArrayRef); - fields.push(sort_field); - } - }, - _ => { - cols.push(arr); - fields.push(sort_field); - }, - } + cols.push(arr); + fields.push(sort_field); } Ok(convert_columns(num_rows, &cols, &fields)) } diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index 04c320e33ec3..060733681de1 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -1,5 +1,10 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::buffer::Buffer; use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use self::encode::fixed_size; +use self::fixed::get_null_sentinel; use super::*; use crate::fixed::{decode_bool, decode_primitive}; use crate::variable::{decode_binary, decode_binview}; @@ -38,9 +43,249 @@ pub unsafe fn decode_rows( .collect() } +unsafe fn decode_validity(rows: &mut [&[u8]], field: &EncodingField) -> Option { + // 2 loop system to avoid the overhead of allocating the bitmap if all the elements are valid. + + let null_sentinel = get_null_sentinel(field); + let first_null = (0..rows.len()).find(|&i| { + let v; + (v, rows[i]) = rows[i].split_at_unchecked(1); + v[0] == null_sentinel + }); + + // No nulls just return None + let first_null = first_null?; + + let mut bm = MutableBitmap::new(); + bm.reserve(rows.len()); + bm.extend_constant(first_null, true); + bm.push(false); + bm.extend_from_trusted_len_iter(rows[first_null + 1..].iter_mut().map(|row| { + let v; + (v, *row) = row.split_at_unchecked(1); + v[0] != null_sentinel + })); + Some(bm.freeze()) +} + +// We inline this in an attempt to avoid the dispatch cost. +#[inline(always)] +fn dtype_and_data_to_encoded_item_len( + dtype: &ArrowDataType, + data: &[u8], + field: &EncodingField, +) -> usize { + // Fast path: if the size is fixed, we can just divide. + if let Some(size) = fixed_size(dtype) { + return size; + } + + let (non_empty_sentinel, continuation_token) = if field.descending { + ( + !variable::NON_EMPTY_SENTINEL, + !variable::BLOCK_CONTINUATION_TOKEN, + ) + } else { + ( + variable::NON_EMPTY_SENTINEL, + variable::BLOCK_CONTINUATION_TOKEN, + ) + }; + + use ArrowDataType as D; + match dtype { + D::Binary + | D::LargeBinary + | D::Utf8 + | D::LargeUtf8 + | D::List(_) + | D::LargeList(_) + | D::BinaryView + | D::Utf8View => unsafe { + crate::variable::encoded_item_len(data, non_empty_sentinel, continuation_token) + }, + + D::FixedSizeBinary(_) => todo!(), + D::FixedSizeList(fsl_field, width) => { + let mut data = &data[1..]; + let mut item_len = 1; // validity byte + + for _ in 0..*width { + let len = dtype_and_data_to_encoded_item_len(fsl_field.dtype(), data, field); + data = &data[len..]; + item_len += len; + } + item_len + }, + D::Struct(struct_fields) => { + let mut data = &data[1..]; + let mut item_len = 1; // validity byte + + for struct_field in struct_fields { + let len = dtype_and_data_to_encoded_item_len(struct_field.dtype(), data, field); + data = &data[len..]; + item_len += len; + } + item_len + }, + + D::Union(_, _, _) => todo!(), + D::Map(_, _) => todo!(), + D::Dictionary(_, _, _) => todo!(), + D::Decimal(_, _) => todo!(), + D::Decimal256(_, _) => todo!(), + D::Extension(_, _, _) => todo!(), + D::Unknown => todo!(), + + _ => unreachable!(), + } +} + +fn rows_for_fixed_size_list<'a>( + dtype: &ArrowDataType, + field: &EncodingField, + width: usize, + rows: &mut [&'a [u8]], + nested_rows: &mut Vec<&'a [u8]>, +) { + nested_rows.clear(); + nested_rows.reserve(rows.len() * width); + + // Fast path: if the size is fixed, we can just divide. + if let Some(size) = fixed_size(dtype) { + for row in rows.iter_mut() { + for i in 0..width { + nested_rows.push(&row[(i * size)..][..size]); + } + *row = &row[size * width..]; + } + return; + } + + use ArrowDataType as D; + match dtype { + D::FixedSizeBinary(_) => todo!(), + D::BinaryView + | D::Utf8View + | D::Binary + | D::LargeBinary + | D::Utf8 + | D::LargeUtf8 + | D::List(_) + | D::LargeList(_) => { + let (non_empty_sentinel, continuation_token) = if field.descending { + ( + !variable::NON_EMPTY_SENTINEL, + !variable::BLOCK_CONTINUATION_TOKEN, + ) + } else { + ( + variable::NON_EMPTY_SENTINEL, + variable::BLOCK_CONTINUATION_TOKEN, + ) + }; + + for row in rows.iter_mut() { + for _ in 0..width { + let length = unsafe { + crate::variable::encoded_item_len( + row, + non_empty_sentinel, + continuation_token, + ) + }; + let v; + (v, *row) = row.split_at(length); + nested_rows.push(v); + } + } + }, + _ => { + // @TODO: This is quite slow since we need to dispatch for possibly every nested type + for row in rows.iter_mut() { + for _ in 0..width { + let length = dtype_and_data_to_encoded_item_len(dtype, row, field); + let v; + (v, *row) = row.split_at(length); + nested_rows.push(v); + } + } + }, + } +} + +fn offsets_from_dtype_and_data( + dtype: &ArrowDataType, + field: &EncodingField, + data: &[u8], + offsets: &mut Vec, +) { + offsets.clear(); + + // Fast path: if the size is fixed, we can just divide. + if let Some(size) = fixed_size(dtype) { + assert!(size == 0 || data.len() % size == 0); + offsets.extend((0..data.len() / size).map(|i| i * size)); + return; + } + + use ArrowDataType as D; + match dtype { + D::FixedSizeBinary(_) => todo!(), + D::BinaryView + | D::Utf8View + | D::Binary + | D::LargeBinary + | D::Utf8 + | D::LargeUtf8 + | D::List(_) + | D::LargeList(_) => { + let mut data = data; + let (non_empty_sentinel, continuation_token) = if field.descending { + ( + !variable::NON_EMPTY_SENTINEL, + !variable::BLOCK_CONTINUATION_TOKEN, + ) + } else { + ( + variable::NON_EMPTY_SENTINEL, + variable::BLOCK_CONTINUATION_TOKEN, + ) + }; + let mut offset = 0; + while !data.is_empty() { + let length = unsafe { + crate::variable::encoded_item_len(data, non_empty_sentinel, continuation_token) + }; + offsets.push(offset); + data = &data[length..]; + offset += length; + } + }, + _ => { + // @TODO: This is quite slow since we need to dispatch for possibly every nested type + let mut data = data; + let mut offset = 0; + while !data.is_empty() { + let length = dtype_and_data_to_encoded_item_len(dtype, data, field); + offsets.push(offset); + data = &data[length..]; + offset += length; + } + }, + } +} + unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataType) -> ArrayRef { match dtype { - ArrowDataType::Null => NullArray::new(ArrowDataType::Null, rows.len()).to_boxed(), + ArrowDataType::Null => { + // Temporary: remove when list encoding is better. + for row in rows.iter_mut() { + *row = &row[1..]; + } + + NullArray::new(ArrowDataType::Null, rows.len()).to_boxed() + }, ArrowDataType::Boolean => decode_bool(rows, field).to_boxed(), ArrowDataType::BinaryView | ArrowDataType::LargeBinary => { decode_binview(rows, field).to_boxed() @@ -60,14 +305,62 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataTyp .to_boxed() }, ArrowDataType::Struct(fields) => { + let validity = decode_validity(rows, field); let values = fields .iter() .map(|struct_fld| decode(rows, field, struct_fld.dtype())) .collect(); - StructArray::new(dtype.clone(), rows.len(), values, None).to_boxed() + StructArray::new(dtype.clone(), rows.len(), values, validity).to_boxed() }, - ArrowDataType::List { .. } | ArrowDataType::LargeList { .. } => { - todo!("list decoding is not yet supported in polars' row encoding") + ArrowDataType::FixedSizeList(fsl_field, width) => { + let validity = decode_validity(rows, field); + + // @TODO: we could consider making this into a scratchpad + let mut nested_rows = Vec::new(); + rows_for_fixed_size_list(fsl_field.dtype(), field, *width, rows, &mut nested_rows); + let values = decode(&mut nested_rows, field, fsl_field.dtype()); + + FixedSizeListArray::new(dtype.clone(), rows.len(), values, validity).to_boxed() + }, + ArrowDataType::List(list_field) | ArrowDataType::LargeList(list_field) => { + let arr = decode_binary(rows, field); + + let mut offsets = Vec::with_capacity(rows.len()); + // @TODO: we could consider making this into a scratchpad + let mut nested_offsets = Vec::new(); + offsets_from_dtype_and_data( + list_field.dtype(), + field, + arr.values().as_ref(), + &mut nested_offsets, + ); + // @TODO: This might cause realloc, fix. + nested_offsets.push(arr.values().len()); + let mut nested_rows = nested_offsets + .windows(2) + .map(|vs| &arr.values()[vs[0]..vs[1]]) + .collect::>(); + + let mut i = 0; + for offset in arr.offsets().iter() { + while nested_offsets[i] != offset.as_usize() { + i += 1; + } + + offsets.push(i as i64); + } + assert_eq!(offsets.len(), rows.len() + 1); + + let values = decode(&mut nested_rows, field, list_field.dtype()); + let (_, _, _, validity) = arr.into_inner(); + + ListArray::::new( + dtype.clone(), + unsafe { OffsetsBuffer::new_unchecked(Buffer::from(offsets)) }, + values, + validity, + ) + .to_boxed() }, dt => { with_match_arrow_primitive_type!(dt, |$T| { diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index 57ede510fb11..65bf08b3fa29 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -1,15 +1,14 @@ +use std::mem::MaybeUninit; + use arrow::array::{ - Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, PrimitiveArray, - StructArray, Utf8ViewArray, + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeListArray, + ListArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, }; -use arrow::bitmap::utils::ZipValidity; -use arrow::compute::utils::combine_validities_and; +use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; -use arrow::legacy::prelude::{LargeBinaryArray, LargeListArray}; -use arrow::types::NativeType; -use polars_utils::vec::PushUnchecked; +use arrow::types::{NativeType, Offset}; -use crate::fixed::FixedLengthEncoding; +use crate::fixed::{get_null_sentinel, FixedLengthEncoding}; use crate::row::{EncodingField, RowsEncoded}; use crate::{with_match_arrow_primitive_type, ArrayRef}; @@ -42,247 +41,769 @@ pub fn convert_columns_amortized_no_order( ); } -enum Encoder { - // For list encoding we recursively call encode on the inner until we - // have a leaf we can encode. - // On allocation we already encode the leaves and set those to `rows`. - List { - enc: Vec, - rows: Option, - original: LargeListArray, - field: EncodingField, - }, - Leaf(ArrayRef), +pub fn convert_columns_amortized<'a>( + num_rows: usize, + columns: &[ArrayRef], + fields: impl IntoIterator + Clone, + rows: &mut RowsEncoded, +) { + let mut row_widths = RowWidths::new(num_rows); + let mut encoders = columns + .iter() + .zip(fields.clone()) + .map(|(column, field)| get_encoder(column.as_ref(), field, &mut row_widths)) + .collect::>(); + + // Create an offsets array, we append 0 at the beginning here so it can serve as the final + // offset array. + let mut offsets = Vec::with_capacity(num_rows + 1); + offsets.push(0); + row_widths.extend_with_offsets(&mut offsets); + + // Create a buffer without initializing everything to zero. + let total_num_bytes = row_widths.sum(); + let mut out = Vec::::with_capacity(total_num_bytes); + let buffer = &mut out.spare_capacity_mut()[..total_num_bytes]; + + let mut scratches = EncodeScratches::default(); + for (encoder, field) in encoders.iter_mut().zip(fields) { + unsafe { encode_array(buffer, encoder, field, &mut offsets[1..], &mut scratches) }; + } + // SAFETY: All the bytes in out up to total_num_bytes should now be initialized. + unsafe { + out.set_len(total_num_bytes); + } + + *rows = RowsEncoded { + values: out, + offsets, + }; +} + +/// Container of byte-widths for (partial) rows. +/// +/// The `RowWidths` keeps track of the sum of all widths and allows to efficiently deal with a +/// constant row-width (i.e. with primitive types). +#[derive(Debug, Clone)] +pub(crate) enum RowWidths { + Constant { num_rows: usize, width: usize }, + // @TODO: Maybe turn this into a Box<[usize]> + Variable { widths: Vec, sum: usize }, } -impl Encoder { - fn list_iter(&self) -> impl Iterator> { +impl Default for RowWidths { + fn default() -> Self { + Self::Constant { + num_rows: 0, + width: 0, + } + } +} + +impl RowWidths { + fn new(num_rows: usize) -> Self { + Self::Constant { num_rows, width: 0 } + } + + /// Push a constant width into the widths + fn push_constant(&mut self, constant: usize) { + match self { + Self::Constant { width, .. } => *width += constant, + Self::Variable { widths, sum } => { + widths.iter_mut().for_each(|w| *w += constant); + *sum += constant * widths.len(); + }, + } + } + /// Push an another [`RowWidths`] into the widths + fn push(&mut self, other: &Self) { + debug_assert_eq!(self.num_rows(), other.num_rows()); + + match (std::mem::take(self), other) { + (mut slf, RowWidths::Constant { width, num_rows: _ }) => { + slf.push_constant(*width); + *self = slf; + }, + (RowWidths::Constant { num_rows, width }, RowWidths::Variable { widths, sum }) => { + *self = RowWidths::Variable { + widths: widths.iter().map(|w| *w + width).collect(), + sum: num_rows * width + sum, + }; + }, + ( + RowWidths::Variable { mut widths, sum }, + RowWidths::Variable { + widths: other_widths, + sum: other_sum, + }, + ) => { + widths + .iter_mut() + .zip(other_widths.iter()) + .for_each(|(l, r)| *l += *r); + *self = RowWidths::Variable { + widths, + sum: sum + other_sum, + }; + }, + } + } + + /// Create a [`RowWidths`] with the chunked sum with a certain `chunk_size`. + fn collapse_chunks(&self, chunk_size: usize, output_num_rows: usize) -> RowWidths { + if chunk_size == 0 { + assert_eq!(self.num_rows(), 0); + return RowWidths::new(output_num_rows); + } + + assert_eq!(self.num_rows() % chunk_size, 0); + assert_eq!(self.num_rows() / chunk_size, output_num_rows); + match self { + Self::Constant { num_rows, width } => Self::Constant { + num_rows: num_rows / chunk_size, + width: width * chunk_size, + }, + Self::Variable { widths, sum } => Self::Variable { + widths: widths + .chunks_exact(chunk_size) + .map(|chunk| chunk.iter().copied().sum()) + .collect(), + sum: *sum, + }, + } + } + + fn extend_with_offsets(&self, out: &mut Vec) { + match self { + RowWidths::Constant { num_rows, width } => { + out.extend((0..*num_rows).map(|i| i * width)); + }, + RowWidths::Variable { widths, sum: _ } => { + let mut next = 0; + out.extend(widths.iter().map(|w| { + let current = next; + next += w; + current + })); + }, + } + } + + fn num_rows(&self) -> usize { + match self { + Self::Constant { num_rows, .. } => *num_rows, + Self::Variable { widths, .. } => widths.len(), + } + } + + fn append_iter(&mut self, iter: impl ExactSizeIterator) -> RowWidths { + assert_eq!(self.num_rows(), iter.len()); + match self { - Encoder::Leaf(_) => unreachable!(), - Encoder::List { original, rows, .. } => { - let rows = rows.as_ref().unwrap(); - // This should be 0 due to rows encoding; - assert_eq!(rows.null_count(), 0); - - let offsets = original.offsets().windows(2); - let zipped = ZipValidity::new_with_validity(offsets, original.validity()); - - let binary_offsets = rows.offsets(); - let row_values = rows.values().as_slice(); - - zipped.map(|opt_window| { - opt_window.map(|window| { - unsafe { - // Offsets of the list - let start = *window.get_unchecked(0); - let end = *window.get_unchecked(1); - - // Offsets in the binary values. - let start = *binary_offsets.get_unchecked(start as usize); - let end = *binary_offsets.get_unchecked(end as usize); - - let start = start as usize; - let end = end as usize; - - row_values.get_unchecked(start..end) - } + RowWidths::Constant { num_rows, width } => { + let num_rows = *num_rows; + let width = *width; + + let mut sum = 0; + let (slf, out) = iter + .map(|v| { + sum += v; + (v + width, v) + }) + .collect(); + + *self = Self::Variable { + widths: slf, + sum: num_rows * width + sum, + }; + Self::Variable { widths: out, sum } + }, + RowWidths::Variable { widths, sum } => { + let mut out_sum = 0; + let out = iter + .zip(widths) + .map(|(v, w)| { + out_sum += v; + *w += v; + v }) - }) + .collect(); + + *sum += out_sum; + Self::Variable { + widths: out, + sum: out_sum, + } }, } } - fn dtype(&self) -> &ArrowDataType { + fn get(&self, index: usize) -> usize { + assert!(index < self.num_rows()); match self { - Encoder::List { original, .. } => original.dtype(), - Encoder::Leaf(arr) => arr.dtype(), + Self::Constant { width, .. } => *width, + Self::Variable { widths, .. } => widths[index], } } - fn is_variable(&self) -> bool { + fn sum(&self) -> usize { match self { - Encoder::Leaf(arr) => { - matches!( - arr.dtype(), - ArrowDataType::BinaryView - | ArrowDataType::Dictionary(_, _, _) - | ArrowDataType::LargeBinary + Self::Constant { num_rows, width } => *num_rows * *width, + Self::Variable { sum, .. } => *sum, + } + } +} + +fn list_num_column_bytes( + array: &dyn Array, + field: &EncodingField, + row_widths: &mut RowWidths, +) -> Encoder { + let array = array.as_any().downcast_ref::>().unwrap(); + let array = array.trim_to_normalized_offsets_recursive(); + let values = array.values(); + + let mut list_row_widths = RowWidths::new(values.len()); + let encoder = get_encoder(values.as_ref(), field, &mut list_row_widths); + + // @TODO: make specialized implementation for list_row_widths is RowWidths::Constant + let mut offsets = Vec::with_capacity(list_row_widths.num_rows() + 1); + list_row_widths.extend_with_offsets(&mut offsets); + offsets.push(encoder.widths.sum()); + + let widths = match array.validity() { + None => row_widths.append_iter(array.offsets().offset_and_length_iter().map( + |(offset, length)| { + crate::variable::encoded_len_from_len( + Some(offsets[offset + length] - offsets[offset]), + field, ) }, - Encoder::List { .. } => true, - } + )), + Some(validity) => row_widths.append_iter( + array + .offsets() + .offset_and_length_iter() + .zip(validity.iter()) + .map(|((offset, length), is_valid)| { + crate::variable::encoded_len_from_len( + is_valid.then_some(offsets[offset + length] - offsets[offset]), + field, + ) + }), + ), + }; + + Encoder { + widths, + array: array.boxed(), + state: EncoderState::List(Box::new(encoder)), } } -fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &EncodingField) -> usize { - let mut added = 0; - match arr.dtype() { - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for value_arr in arr.values() { - // A hack to make outer validity work. - // TODO! improve - if arr.null_count() > 0 { - let new_validity = combine_validities_and(arr.validity(), value_arr.validity()); - value_arr.with_validity(new_validity); - added += get_encoders(value_arr.as_ref(), encoders, field); - } else { - added += get_encoders(value_arr.as_ref(), encoders, field); - } +fn biniter_num_column_bytes( + array: &dyn Array, + iter: impl ExactSizeIterator, + validity: Option<&Bitmap>, + field: &EncodingField, + row_widths: &mut RowWidths, +) -> Encoder { + let widths = match validity { + None => row_widths + .append_iter(iter.map(|v| crate::variable::encoded_len_from_len(Some(v), field))), + Some(validity) => row_widths.append_iter(iter.zip(validity.iter()).map(|(v, is_valid)| { + crate::variable::encoded_len_from_len(is_valid.then_some(v), field) + })), + }; + + Encoder { + widths, + array: array.to_boxed(), + state: EncoderState::Stateless, + } +} + +/// Get the encoder for a specific array. +fn get_encoder(array: &dyn Array, field: &EncodingField, row_widths: &mut RowWidths) -> Encoder { + use ArrowDataType as D; + let dtype = array.dtype(); + + // Fast path: column has a fixed size encoding + if let Some(size) = fixed_size(dtype) { + row_widths.push_constant(size); + let state = match dtype { + D::FixedSizeList(_, width) => { + let array = array.as_any().downcast_ref::().unwrap(); + let array = array.propagate_nulls(); + + debug_assert_eq!(array.values().len(), array.len() * width); + let nested_encoder = get_encoder( + array.values().as_ref(), + field, + &mut RowWidths::new(array.values().len()), + ); + EncoderState::FixedSizeList(Box::new(nested_encoder), *width) + }, + D::Struct(_) => { + let struct_array = array.as_any().downcast_ref::().unwrap(); + let struct_array = struct_array.propagate_nulls(); + EncoderState::Struct( + struct_array + .values() + .iter() + .map(|array| { + get_encoder( + array.as_ref(), + field, + &mut RowWidths::new(struct_array.len()), + ) + }) + .collect(), + ) + }, + _ => EncoderState::Stateless, + }; + return Encoder { + widths: RowWidths::Constant { + num_rows: array.len(), + width: size, + }, + array: array.to_boxed(), + state, + }; + } + + match dtype { + D::FixedSizeList(_, width) => { + let array = array.as_any().downcast_ref::().unwrap(); + let array = array.propagate_nulls(); + + debug_assert_eq!(array.values().len(), array.len() * width); + let mut nested_row_widths = RowWidths::new(array.values().len()); + let nested_encoder = + get_encoder(array.values().as_ref(), field, &mut nested_row_widths); + + let mut fsl_row_widths = nested_row_widths.collapse_chunks(*width, array.len()); + fsl_row_widths.push_constant(1); // validity byte + + row_widths.push(&fsl_row_widths); + Encoder { + widths: fsl_row_widths, + array: array.to_boxed(), + state: EncoderState::FixedSizeList(Box::new(nested_encoder), *width), } }, - ArrowDataType::Utf8View => { - let arr = arr.as_any().downcast_ref::().unwrap(); - encoders.push(Encoder::Leaf(arr.to_binview().boxed())); - added += 1 + D::Struct(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + let array = array.propagate_nulls(); + + let mut struct_row_widths = RowWidths::new(array.len()); + let mut nested_encoders = Vec::with_capacity(array.values().len()); + struct_row_widths.push_constant(1); // validity byte + for array in array.values() { + let encoder = get_encoder(array.as_ref(), field, &mut struct_row_widths); + nested_encoders.push(encoder); + } + row_widths.push(&struct_row_widths); + Encoder { + widths: struct_row_widths, + array: array.to_boxed(), + state: EncoderState::Struct(nested_encoders), + } + }, + + D::List(_) => list_num_column_bytes::(array, field, row_widths), + D::LargeList(_) => list_num_column_bytes::(array, field, row_widths), + + D::BinaryView => { + let dc_array = array.as_any().downcast_ref::().unwrap(); + biniter_num_column_bytes( + array, + dc_array.views().iter().map(|v| v.length as usize), + dc_array.validity(), + field, + row_widths, + ) + }, + D::Utf8View => { + let dc_array = array.as_any().downcast_ref::().unwrap(); + biniter_num_column_bytes( + array, + dc_array.views().iter().map(|v| v.length as usize), + dc_array.validity(), + field, + row_widths, + ) + }, + D::Binary => { + let dc_array = array.as_any().downcast_ref::>().unwrap(); + biniter_num_column_bytes( + array, + dc_array + .offsets() + .windows(2) + .map(|vs| (vs[1] - vs[0]) as usize), + dc_array.validity(), + field, + row_widths, + ) }, - ArrowDataType::LargeList(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - let mut inner = vec![]; - get_encoders(arr.values().as_ref(), &mut inner, field); - encoders.push(Encoder::List { - enc: inner, - original: arr.clone(), - rows: None, - field: *field, - }); - added += 1; + D::Utf8 => { + let dc_array = array.as_any().downcast_ref::>().unwrap(); + biniter_num_column_bytes( + array, + dc_array + .offsets() + .windows(2) + .map(|vs| (vs[1] - vs[0]) as usize), + dc_array.validity(), + field, + row_widths, + ) }, - _ => { - encoders.push(Encoder::Leaf(arr.to_boxed())); - added += 1; + D::LargeBinary => { + let dc_array = array.as_any().downcast_ref::>().unwrap(); + biniter_num_column_bytes( + array, + dc_array + .offsets() + .windows(2) + .map(|vs| (vs[1] - vs[0]) as usize), + dc_array.validity(), + field, + row_widths, + ) }, + D::LargeUtf8 => { + let dc_array = array.as_any().downcast_ref::>().unwrap(); + biniter_num_column_bytes( + array, + dc_array + .offsets() + .windows(2) + .map(|vs| (vs[1] - vs[0]) as usize), + dc_array.validity(), + field, + row_widths, + ) + }, + + D::Dictionary(_, _, _) => { + let dc_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let iter = dc_array + .iter_typed::() + .unwrap() + .map(|opt_s| opt_s.map_or(0, |s| s.len())); + // @TODO: Do a better job here. This is just plainly incorrect. + biniter_num_column_bytes(array, iter, dc_array.validity(), field, row_widths) + }, + D::Union(_, _, _) => todo!(), + D::Map(_, _) => todo!(), + D::Decimal(_, _) => todo!(), + D::Decimal256(_, _) => todo!(), + D::Extension(_, _, _) => todo!(), + D::Unknown => todo!(), + + // All non-physical types + D::Timestamp(_, _) + | D::Date32 + | D::Date64 + | D::Time32(_) + | D::Time64(_) + | D::Duration(_) + | D::Interval(_) => unreachable!(), + + // Should be fixed size type + _ => unreachable!(), } - added } -pub fn convert_columns_amortized<'a, I: IntoIterator>( - num_rows: usize, - columns: &'a [ArrayRef], - fields: I, - rows: &mut RowsEncoded, +pub struct Encoder { + widths: RowWidths, + array: Box, + state: EncoderState, +} + +pub enum EncoderState { + Stateless, + List(Box), + Dictionary(Box), + FixedSizeList(Box, usize), + Struct(Vec), +} + +unsafe fn encode_flat_array( + buffer: &mut [MaybeUninit], + array: &dyn Array, + field: &EncodingField, + offsets: &mut [usize], ) { - let fields = fields.into_iter(); - assert_eq!(fields.size_hint().0, columns.len()); - if columns.iter().any(|arr| { - matches!( - arr.dtype(), - ArrowDataType::Struct(_) | ArrowDataType::Utf8View | ArrowDataType::LargeList(_) - ) - }) { - let mut flattened_columns = Vec::with_capacity(columns.len() * 5); - let mut flattened_fields = Vec::with_capacity(columns.len() * 5); - - for (arr, field) in columns.iter().zip(fields) { - let added = get_encoders(arr.as_ref(), &mut flattened_columns, field); - for _ in 0..added { - flattened_fields.push(*field); + use ArrowDataType as D; + match array.dtype() { + D::Null => { + // @NOTE: This is an artifact of the list encoding, this can be removed when we have a + // better list encoding. + for offset in offsets.iter_mut() { + buffer[*offset] = MaybeUninit::new(0); + *offset += 1; } - } - let values_size = allocate_rows_buf( - num_rows, - &mut flattened_columns, - &flattened_fields, - &mut rows.values, - &mut rows.offsets, - ); - for (arr, field) in flattened_columns.iter().zip(flattened_fields.iter()) { - // SAFETY: - // we allocated rows with enough bytes. - unsafe { encode_array(arr, field, rows) } - } - // SAFETY: values are initialized - unsafe { rows.values.set_len(values_size) } - } else { - let mut encoders = columns - .iter() - .map(|arr| Encoder::Leaf(arr.clone())) - .collect::>(); - let fields = fields.cloned().collect::>(); - let values_size = allocate_rows_buf( - num_rows, - &mut encoders, - &fields, - &mut rows.values, - &mut rows.offsets, - ); - for (enc, field) in encoders.iter().zip(fields) { - // SAFETY: - // we allocated rows with enough bytes. - unsafe { encode_array(enc, &field, rows) } - } - // SAFETY: values are initialized - unsafe { rows.values.set_len(values_size) } + }, + D::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + crate::fixed::encode_iter(buffer, array.iter(), field, offsets); + }, + dt if dt.is_numeric() => with_match_arrow_primitive_type!(dt, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + encode_primitive(buffer, array, field, offsets); + }), + + D::Binary => { + let array = array.as_any().downcast_ref::>().unwrap(); + crate::variable::encode_iter(buffer, array.iter(), field, offsets); + }, + D::LargeBinary => { + let array = array.as_any().downcast_ref::>().unwrap(); + crate::variable::encode_iter(buffer, array.iter(), field, offsets); + }, + D::BinaryView => { + let array = array.as_any().downcast_ref::().unwrap(); + crate::variable::encode_iter(buffer, array.iter(), field, offsets); + }, + D::Utf8 => { + let array = array.as_any().downcast_ref::>().unwrap(); + crate::variable::encode_iter( + buffer, + array.iter().map(|v| v.map(|v| v.as_bytes())), + field, + offsets, + ); + }, + D::LargeUtf8 => { + let array = array.as_any().downcast_ref::>().unwrap(); + crate::variable::encode_iter( + buffer, + array.iter().map(|v| v.map(|v| v.as_bytes())), + field, + offsets, + ); + }, + D::Utf8View => { + let array = array.as_any().downcast_ref::().unwrap(); + crate::variable::encode_iter( + buffer, + array.iter().map(|v| v.map(|v| v.as_bytes())), + field, + offsets, + ); + }, + D::Dictionary(_, _, _) => { + let dc_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let iter = dc_array + .iter_typed::() + .unwrap() + .map(|opt_s| opt_s.map(|s| s.as_bytes())); + crate::variable::encode_iter(buffer, iter, field, offsets); + }, + + D::FixedSizeBinary(_) => todo!(), + D::Decimal(_, _) => todo!(), + D::Decimal256(_, _) => todo!(), + + D::Union(_, _, _) => todo!(), + D::Map(_, _) => todo!(), + D::Extension(_, _, _) => todo!(), + D::Unknown => todo!(), + + // All are non-physical types. + D::Timestamp(_, _) + | D::Date32 + | D::Date64 + | D::Time32(_) + | D::Time64(_) + | D::Duration(_) + | D::Interval(_) => unreachable!(), + + _ => unreachable!(), } } -fn encode_primitive( - arr: &PrimitiveArray, - field: &EncodingField, - out: &mut RowsEncoded, -) { - if arr.null_count() == 0 { - unsafe { crate::fixed::encode_slice(arr.values().as_slice(), out, field) }; - } else { - unsafe { - crate::fixed::encode_iter(arr.into_iter().map(|v| v.copied()), out, field); - } +#[derive(Default)] +struct EncodeScratches { + nested_offsets: Vec, + nested_buffer: Vec, +} + +impl EncodeScratches { + fn clear(&mut self) { + self.nested_offsets.clear(); + self.nested_buffer.clear(); } } -/// Encodes an array into `out` -/// -/// # Safety -/// `out` must have enough bytes allocated otherwise it will be out of bounds. -unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsEncoded) { - match encoder { - Encoder::List { .. } => { - let iter = encoder.list_iter(); - crate::variable::encode_iter(iter, out, field) +unsafe fn encode_array( + buffer: &mut [MaybeUninit], + encoder: &Encoder, + field: &EncodingField, + offsets: &mut [usize], + scratches: &mut EncodeScratches, +) { + match &encoder.state { + EncoderState::Stateless => { + encode_flat_array(buffer, encoder.array.as_ref(), field, offsets) }, - Encoder::Leaf(array) => { - match array.dtype() { - ArrowDataType::Boolean => { - let array = array.as_any().downcast_ref::().unwrap(); - crate::fixed::encode_iter(array.into_iter(), out, field); - }, - ArrowDataType::LargeBinary => { - let array = array.as_any().downcast_ref::>().unwrap(); - crate::variable::encode_iter(array.into_iter(), out, field) - }, - ArrowDataType::BinaryView => { - let array = array.as_any().downcast_ref::().unwrap(); - crate::variable::encode_iter(array.into_iter(), out, field) - }, - ArrowDataType::Utf8View => { - panic!("should be binview") - }, - ArrowDataType::Dictionary(_, _, _) => { - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let iter = array - .iter_typed::() - .unwrap() - .map(|opt_s| opt_s.map(|s| s.as_bytes())); - crate::variable::encode_iter(iter, out, field) + EncoderState::List(nested_encoder) => { + // @TODO: make more general. + let array = encoder + .array + .as_any() + .downcast_ref::>() + .unwrap(); + + scratches.clear(); + + let total_num_bytes = nested_encoder.widths.sum(); + scratches.nested_buffer.reserve(total_num_bytes); + scratches + .nested_offsets + .reserve(1 + nested_encoder.widths.num_rows()); + + let nested_buffer = + &mut scratches.nested_buffer.spare_capacity_mut()[..total_num_bytes]; + let nested_offsets = &mut scratches.nested_offsets; + nested_offsets.push(0); + nested_encoder.widths.extend_with_offsets(nested_offsets); + + // Lists have the row encoding of the elements again encoded by the variable encoding. + // This is not ideal ([["a", "b"]] produces 100 bytes), but this is sort of how + // arrow-row works and is good enough for now. + unsafe { + encode_array( + nested_buffer, + nested_encoder, + field, + &mut nested_offsets[1..], + &mut EncodeScratches::default(), + ) + }; + let nested_buffer: &[u8] = unsafe { std::mem::transmute(nested_buffer) }; + + // @TODO: Differentiate between empty values and empty list. + match encoder.array.validity() { + None => { + crate::variable::encode_iter( + buffer, + array + .offsets() + .offset_and_length_iter() + .map(|(offset, length)| { + Some( + &nested_buffer + [nested_offsets[offset]..nested_offsets[offset + length]], + ) + }), + field, + offsets, + ); }, - ArrowDataType::Null => {}, // No output needed. - dt => { - with_match_arrow_primitive_type!(dt, |$T| { - let array = array.as_any().downcast_ref::>().unwrap(); - encode_primitive(array, field, out); - }) + Some(validity) => { + crate::variable::encode_iter( + buffer, + array + .offsets() + .offset_and_length_iter() + .zip(validity.iter()) + .map(|((offset, length), is_valid)| { + is_valid.then(|| { + &nested_buffer + [nested_offsets[offset]..nested_offsets[offset + length]] + }) + }), + field, + offsets, + ); }, - }; + } }, + EncoderState::Dictionary(_) => todo!(), + EncoderState::FixedSizeList(array, width) => { + encode_validity(buffer, encoder.array.validity(), field, offsets); + + if *width == 0 { + return; + } + + let mut child_offsets = Vec::with_capacity(offsets.len() * width); + for (i, offset) in offsets.iter_mut().enumerate() { + for j in 0..*width { + child_offsets.push(*offset); + *offset += array.widths.get((i * width) + j); + } + } + encode_array(buffer, array.as_ref(), field, &mut child_offsets, scratches); + for (i, offset) in offsets.iter_mut().enumerate() { + *offset = child_offsets[(i + 1) * width - 1]; + } + }, + EncoderState::Struct(arrays) => { + encode_validity(buffer, encoder.array.validity(), field, offsets); + + for array in arrays { + encode_array(buffer, array, field, offsets, scratches); + } + }, + } +} + +unsafe fn encode_validity( + buffer: &mut [MaybeUninit], + validity: Option<&Bitmap>, + field: &EncodingField, + row_starts: &mut [usize], +) { + let null_sentinel = get_null_sentinel(field); + match validity { + None => { + for row_start in row_starts.iter_mut() { + buffer[*row_start] = MaybeUninit::new(1); + *row_start += 1; + } + }, + Some(validity) => { + for (row_start, is_valid) in row_starts.iter_mut().zip(validity.iter()) { + let v = if is_valid { + MaybeUninit::new(1) + } else { + MaybeUninit::new(null_sentinel) + }; + buffer[*row_start] = v; + *row_start += 1; + } + }, + } +} + +unsafe fn encode_primitive( + buffer: &mut [MaybeUninit], + arr: &PrimitiveArray, + field: &EncodingField, + offsets: &mut [usize], +) { + if arr.null_count() == 0 { + crate::fixed::encode_slice(buffer, arr.values().as_slice(), field, offsets) + } else { + crate::fixed::encode_iter(buffer, arr.into_iter().map(|v| v.copied()), field, offsets) } } -pub fn encoded_size(dtype: &ArrowDataType) -> usize { +pub fn fixed_size(dtype: &ArrowDataType) -> Option { use ArrowDataType::*; - match dtype { + Some(match dtype { UInt8 => u8::ENCODED_LEN, UInt16 => u16::ENCODED_LEN, UInt32 => u32::ENCODED_LEN, @@ -295,230 +816,23 @@ pub fn encoded_size(dtype: &ArrowDataType) -> usize { Float32 => f32::ENCODED_LEN, Float64 => f64::ENCODED_LEN, Boolean => bool::ENCODED_LEN, - Null => 0, - dt => unimplemented!("{dt:?}"), - } -} - -// Returns the length that the caller must set on the `values` buf once the bytes -// are initialized. -fn allocate_rows_buf( - num_rows: usize, - columns: &mut [Encoder], - fields: &[EncodingField], - values: &mut Vec, - offsets: &mut Vec, -) -> usize { - let has_variable = columns.iter().any(|enc| enc.is_variable()); - - if has_variable { - // row size of the fixed-length columns - // those can be determined without looping over the arrays - let row_size_fixed: usize = columns - .iter() - .map(|enc| { - if enc.is_variable() { - 0 - } else { - encoded_size(enc.dtype()) - } - }) - .sum(); - - offsets.clear(); - offsets.reserve(num_rows + 1); - - // first write lengths to this buffer - let lengths = offsets; - - // for the variable length columns we must iterate to determine the length per row location - let mut processed_count = 0; - for (enc, enc_field) in columns.iter_mut().zip(fields) { - match enc { - Encoder::List { - enc: inner_enc, - rows, - field, - original, - } => { - let field = *field; - let fields = inner_enc.iter().map(|_| field).collect::>(); - // Nested lists don't yet work as that requires the leaves not only allocating, but also - // encoding. To make that work we must add a flag `in_list` that tell the leaves to immediately - // encode the rows instead of only setting the length. - // This needs a bit refactoring, might require allocation and encoding to be in - // the same function. - if let ArrowDataType::LargeList(inner) = original.dtype() { - assert!( - !matches!(inner.dtype, ArrowDataType::LargeList(_)), - "should not be nested" - ) - } - // Create the row encoding for the inner type. - let mut values_rows = RowsEncoded::default(); - - // Allocate and immediately row-encode the inner types recursively. - let values_size = allocate_rows_buf( - original.values().len(), - inner_enc, - &fields, - &mut values_rows.values, - &mut values_rows.offsets, - ); - - // For single nested it does work as we encode here. - unsafe { - for enc in inner_enc { - encode_array(enc, &field, &mut values_rows) - } - values_rows.values.set_len(values_size) - }; - let values_rows = values_rows.into_array(); - *rows = Some(values_rows); - - let iter = enc.list_iter(); - - if processed_count == 0 { - for opt_val in iter { - unsafe { - lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val, &field), - ); - } - } - } else { - for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val, &field) - } - } - processed_count += 1; - }, - Encoder::Leaf(array) => { - match array.dtype() { - ArrowDataType::BinaryView => { - let array = array.as_any().downcast_ref::().unwrap(); - if processed_count == 0 { - for opt_val in array.into_iter() { - unsafe { - lengths.push_unchecked( - row_size_fixed - + crate::variable::encoded_len(opt_val, enc_field), - ); - } - } - } else { - for (opt_val, row_length) in - array.into_iter().zip(lengths.iter_mut()) - { - *row_length += crate::variable::encoded_len(opt_val, enc_field) - } - } - processed_count += 1; - }, - ArrowDataType::LargeBinary => { - let array = array.as_any().downcast_ref::>().unwrap(); - if processed_count == 0 { - for opt_val in array.into_iter() { - unsafe { - lengths.push_unchecked( - row_size_fixed - + crate::variable::encoded_len(opt_val, enc_field), - ); - } - } - } else { - for (opt_val, row_length) in - array.into_iter().zip(lengths.iter_mut()) - { - *row_length += crate::variable::encoded_len(opt_val, enc_field) - } - } - processed_count += 1; - }, - ArrowDataType::Dictionary(_, _, _) => { - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let iter = array - .iter_typed::() - .unwrap() - .map(|opt_s| opt_s.map(|s| s.as_bytes())); - if processed_count == 0 { - for opt_val in iter { - unsafe { - lengths.push_unchecked( - row_size_fixed - + crate::variable::encoded_len(opt_val, enc_field), - ) - } - } - } else { - for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val, enc_field) - } - } - processed_count += 1; - }, - _ => { - // the rest is fixed - }, - } - }, + FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype())?, + Struct(fs) => { + let mut sum = 0; + for f in fs { + sum += fixed_size(f.dtype())?; } - } - // now we use the lengths and the same buffer to determine the offsets - let offsets = lengths; - // we write lagged because the offsets will be written by the encoding column - let mut current_offset = 0_usize; - let mut lagged_offset = 0_usize; - - for length in offsets.iter_mut() { - let to_write = lagged_offset; - lagged_offset = current_offset; - current_offset += *length; - - *length = to_write; - } - // ensure we have len + 1 offsets - offsets.push(lagged_offset); - - // Only reserve. The init will be done later - values.reserve(current_offset); - current_offset - } else { - let row_size: usize = columns.iter().map(|arr| encoded_size(arr.dtype())).sum(); - let n_bytes = num_rows * row_size; - values.clear(); - values.reserve(n_bytes); - - // note that offsets are shifted to the left - // assume 2 fields with a len of 1 - // e.g. in arrow we would have 0, 2, 4, 6 - - // now we write 0, 0, 2, 4 - - // and when we encode field 1, we update the offset - // so that becomes: 0, 1, 3, 5 - - // and when the final field, field 2 is written - // the offsets are correct: - // 0, 2, 4, 6 - offsets.clear(); - offsets.reserve(num_rows + 1); - let mut current_offset = 0; - offsets.push(current_offset); - for _ in 0..num_rows { - offsets.push(current_offset); - current_offset += row_size; - } - n_bytes - } + 1 + sum + }, + Null => 1, + _ => return None, + }) } #[cfg(test)] mod test { use arrow::array::Int32Array; + use arrow::legacy::prelude::LargeListArray; use arrow::offset::Offsets; use super::*; diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index 315eada42ae4..0e909b6811b9 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -8,7 +8,7 @@ use arrow::types::NativeType; use polars_utils::slice::*; use polars_utils::total_ord::{canonical_f32, canonical_f64}; -use crate::row::{EncodingField, RowsEncoded}; +use crate::row::EncodingField; pub(crate) trait FromSlice { fn from_slice(slice: &[u8]) -> Self; @@ -142,7 +142,7 @@ impl FixedLengthEncoding for f64 { } #[inline] -fn encode_value( +unsafe fn encode_value( value: &T, offset: &mut usize, descending: bool, @@ -165,15 +165,34 @@ fn encode_value( *offset = end_offset; } +unsafe fn encode_opt_value( + opt_value: Option, + offset: &mut usize, + field: &EncodingField, + buffer: &mut [MaybeUninit], +) { + if let Some(value) = opt_value { + encode_value(&value, offset, field.descending, buffer); + } else { + unsafe { *buffer.get_unchecked_mut(*offset) = MaybeUninit::new(get_null_sentinel(field)) }; + let end_offset = *offset + T::ENCODED_LEN; + + // initialize remaining bytes + let remainder = unsafe { buffer.get_unchecked_mut(*offset + 1..end_offset) }; + remainder.fill(MaybeUninit::new(0)); + + *offset = end_offset; + } +} + pub(crate) unsafe fn encode_slice( + buffer: &mut [MaybeUninit], input: &[T], - out: &mut RowsEncoded, field: &EncodingField, + row_starts: &mut [usize], ) { - out.values.set_len(0); - let values = out.values.spare_capacity_mut(); - for (offset, value) in out.offsets.iter_mut().skip(1).zip(input) { - encode_value(value, offset, field.descending, values); + for (offset, value) in row_starts.iter_mut().zip(input) { + encode_value(value, offset, field.descending, buffer); } } @@ -187,27 +206,13 @@ pub(crate) fn get_null_sentinel(field: &EncodingField) -> u8 { } pub(crate) unsafe fn encode_iter>, T: FixedLengthEncoding>( + buffer: &mut [MaybeUninit], input: I, - out: &mut RowsEncoded, field: &EncodingField, + row_starts: &mut [usize], ) { - out.values.set_len(0); - let values = out.values.spare_capacity_mut(); - for (offset, opt_value) in out.offsets.iter_mut().skip(1).zip(input) { - if let Some(value) = opt_value { - encode_value(&value, offset, field.descending, values); - } else { - unsafe { - *values.get_unchecked_mut(*offset) = MaybeUninit::new(get_null_sentinel(field)) - }; - let end_offset = *offset + T::ENCODED_LEN; - - // initialize remaining bytes - let remainder = values.get_unchecked_mut(*offset + 1..end_offset); - remainder.fill(MaybeUninit::new(0)); - - *offset = end_offset; - } + for (offset, opt_value) in row_starts.iter_mut().zip(input) { + encode_opt_value(opt_value, offset, field, buffer); } } diff --git a/crates/polars-row/src/variable.rs b/crates/polars-row/src/variable.rs index f7485d44704a..15f25143348f 100644 --- a/crates/polars-row/src/variable.rs +++ b/crates/polars-row/src/variable.rs @@ -19,7 +19,6 @@ use arrow::offset::Offsets; use polars_utils::slice::Slice2Uninit; use crate::fixed::{decode_nulls, get_null_sentinel}; -use crate::row::RowsEncoded; use crate::EncodingField; /// The block size of the variable length encoding @@ -48,11 +47,11 @@ fn padded_length(a: usize) -> usize { } #[inline] -pub fn encoded_len(a: Option<&[u8]>, field: &EncodingField) -> usize { +pub fn encoded_len_from_len(a: Option, field: &EncodingField) -> usize { if field.no_order { - 4 + a.map(|v| v.len()).unwrap_or(0) + 4 + a.unwrap_or(0) } else { - a.map(|v| padded_length(v.len())).unwrap_or(1) + a.map_or(1, padded_length) } } @@ -162,32 +161,26 @@ unsafe fn encode_one( }, } } + pub(crate) unsafe fn encode_iter<'a, I: Iterator>>( + buffer: &mut [MaybeUninit], input: I, - out: &mut RowsEncoded, field: &EncodingField, + row_starts: &mut [usize], ) { - out.values.set_len(0); - let values = out.values.spare_capacity_mut(); - if field.no_order { - for (offset, opt_value) in out.offsets.iter_mut().skip(1).zip(input) { - let dst = values.get_unchecked_mut(*offset..); + for (offset, opt_value) in row_starts.iter_mut().zip(input) { + let dst = buffer.get_unchecked_mut(*offset..); let written_len = encode_one_no_order(dst, opt_value.map(|v| v.as_uninit()), field); *offset += written_len; } } else { - for (offset, opt_value) in out.offsets.iter_mut().skip(1).zip(input) { - let dst = values.get_unchecked_mut(*offset..); + for (offset, opt_value) in row_starts.iter_mut().zip(input) { + let dst = buffer.get_unchecked_mut(*offset..); let written_len = encode_one(dst, opt_value.map(|v| v.as_uninit()), field); *offset += written_len; } } - let offset = out.offsets.last().unwrap(); - let dst = values.get_unchecked_mut(*offset..); - // write remainder as zeros - dst.fill(MaybeUninit::new(0)); - out.values.set_len(out.values.capacity()) } unsafe fn has_nulls(rows: &[&[u8]], null_sentinel: u8) -> bool { @@ -195,6 +188,27 @@ unsafe fn has_nulls(rows: &[&[u8]], null_sentinel: u8) -> bool { .any(|row| *row.get_unchecked(0) == null_sentinel) } +pub(crate) unsafe fn encoded_item_len( + row: &[u8], + non_empty_sentinel: u8, + continuation_token: u8, +) -> usize { + // empty or null + if *row.get_unchecked(0) != non_empty_sentinel { + return 1; + } + + let mut idx = 1; + loop { + let sentinel = *row.get_unchecked(idx + BLOCK_SIZE); + if sentinel == continuation_token { + idx += BLOCK_SIZE + 1; + continue; + } + return idx + BLOCK_SIZE + 1; + } +} + unsafe fn decoded_len( row: &[u8], non_empty_sentinel: u8, diff --git a/py-polars/tests/unit/test_row_encoding.py b/py-polars/tests/unit/test_row_encoding.py index 0ae0cb9c34f4..66428f04250e 100644 --- a/py-polars/tests/unit/test_row_encoding.py +++ b/py-polars/tests/unit/test_row_encoding.py @@ -1,18 +1,22 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Literal + import pytest from hypothesis import given import polars as pl -from polars.testing import assert_frame_equal -from polars.testing.parametric import dataframes +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes + +if TYPE_CHECKING: + from polars._typing import PolarsDataType -# @TODO: Deal with no_order FIELD_COMBS = [ (descending, nulls_last, False) for descending in [False, True] for nulls_last in [False, True] -] +] + [(False, False, True)] def roundtrip_re( @@ -22,20 +26,29 @@ def roundtrip_re( fields = [(False, False, False)] * df.width row_encoded = df._row_encode(fields) + if any(f[2] for f in fields): + return + dtypes = [(c, df.get_column(c).dtype) for c in df.columns] result = row_encoded._row_decode(dtypes, fields) assert_frame_equal(df, result) +def roundtrip_series_re( + values: pl.series.series.ArrayLike, + dtype: PolarsDataType, + field: tuple[bool, bool, bool], +) -> None: + roundtrip_re(pl.Series("series", values, dtype).to_frame(), [field]) + + @given( df=dataframes( excluded_dtypes=[ - pl.List, - pl.Array, - pl.Struct, pl.Categorical, pl.Enum, + pl.Decimal, ] ) ) @@ -48,20 +61,20 @@ def test_row_encoding_parametric( @pytest.mark.parametrize("field", FIELD_COMBS) def test_nulls(field: tuple[bool, bool, bool]) -> None: - roundtrip_re(pl.Series("a", [], pl.Null).to_frame(), [field]) - roundtrip_re(pl.Series("a", [None], pl.Null).to_frame(), [field]) - roundtrip_re(pl.Series("a", [None] * 2, pl.Null).to_frame(), [field]) - roundtrip_re(pl.Series("a", [None] * 13, pl.Null).to_frame(), [field]) - roundtrip_re(pl.Series("a", [None] * 42, pl.Null).to_frame(), [field]) + roundtrip_series_re([], pl.Null, field) + roundtrip_series_re([None], pl.Null, field) + roundtrip_series_re([None] * 2, pl.Null, field) + roundtrip_series_re([None] * 13, pl.Null, field) + roundtrip_series_re([None] * 42, pl.Null, field) @pytest.mark.parametrize("field", FIELD_COMBS) def test_bool(field: tuple[bool, bool, bool]) -> None: - roundtrip_re(pl.Series("a", [], pl.Boolean).to_frame(), [field]) - roundtrip_re(pl.Series("a", [False], pl.Boolean).to_frame(), [field]) - roundtrip_re(pl.Series("a", [True], pl.Boolean).to_frame(), [field]) - roundtrip_re(pl.Series("a", [False, True], pl.Boolean).to_frame(), [field]) - roundtrip_re(pl.Series("a", [True, False], pl.Boolean).to_frame(), [field]) + roundtrip_series_re([], pl.Boolean, field) + roundtrip_series_re([False], pl.Boolean, field) + roundtrip_series_re([True], pl.Boolean, field) + roundtrip_series_re([False, True], pl.Boolean, field) + roundtrip_series_re([True, False], pl.Boolean, field) @pytest.mark.parametrize( @@ -82,14 +95,14 @@ def test_int(dtype: pl.DataType, field: tuple[bool, bool, bool]) -> None: min = pl.select(x=dtype.min()).item() # type: ignore[attr-defined] max = pl.select(x=dtype.max()).item() # type: ignore[attr-defined] - roundtrip_re(pl.Series("a", [], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [0], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [min], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [max], dtype).to_frame(), [field]) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([0], dtype, field) + roundtrip_series_re([min], dtype, field) + roundtrip_series_re([max], dtype, field) - roundtrip_re(pl.Series("a", [1, 2, 3], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [0, 1, 2, 3], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [min, 0, max], dtype).to_frame(), [field]) + roundtrip_series_re([1, 2, 3], dtype, field) + roundtrip_series_re([0, 1, 2, 3], dtype, field) + roundtrip_series_re([min, 0, max], dtype, field) @pytest.mark.parametrize( @@ -104,46 +117,310 @@ def test_float(dtype: pl.DataType, field: tuple[bool, bool, bool]) -> None: inf = float("inf") inf_b = float("-inf") - roundtrip_re(pl.Series("a", [], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [0.0], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [inf], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [-inf_b], dtype).to_frame(), [field]) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([0.0], dtype, field) + roundtrip_series_re([inf], dtype, field) + roundtrip_series_re([-inf_b], dtype, field) - roundtrip_re(pl.Series("a", [1.0, 2.0, 3.0], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [0.0, 1.0, 2.0, 3.0], dtype).to_frame(), [field]) - roundtrip_re(pl.Series("a", [inf, 0, -inf_b], dtype).to_frame(), [field]) + roundtrip_series_re([1.0, 2.0, 3.0], dtype, field) + roundtrip_series_re([0.0, 1.0, 2.0, 3.0], dtype, field) + roundtrip_series_re([inf, 0, -inf_b], dtype, field) @pytest.mark.parametrize("field", FIELD_COMBS) def test_str(field: tuple[bool, bool, bool]) -> None: - roundtrip_re(pl.Series("a", [], pl.String).to_frame(), [field]) - roundtrip_re(pl.Series("a", [""], pl.String).to_frame(), [field]) + dtype = pl.String + roundtrip_series_re([], dtype, field) + roundtrip_series_re([""], dtype, field) - roundtrip_re(pl.Series("a", ["a", "b", "c"], pl.String).to_frame(), [field]) - roundtrip_re(pl.Series("a", ["", "a", "b", "c"], pl.String).to_frame(), [field]) + roundtrip_series_re(["a", "b", "c"], dtype, field) + roundtrip_series_re(["", "a", "b", "c"], dtype, field) - roundtrip_re( - pl.Series("a", ["different", "length", "strings"], pl.String).to_frame(), - [field], + roundtrip_series_re( + ["different", "length", "strings"], + dtype, + field, ) - roundtrip_re( - pl.Series( - "a", ["different", "", "length", "", "strings"], pl.String - ).to_frame(), - [field], + roundtrip_series_re( + ["different", "", "length", "", "strings"], + dtype, + field, ) @pytest.mark.parametrize("field", FIELD_COMBS) def test_struct(field: tuple[bool, bool, bool]) -> None: - roundtrip_re(pl.Series("a", [], pl.Struct({})).to_frame()) - roundtrip_re(pl.Series("a", [{}], pl.Struct({})).to_frame()) - roundtrip_re( - pl.Series("a", [{"x": 1}], pl.Struct({"x": pl.Int32})).to_frame(), [field] + dtype = pl.Struct({}) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([{}], dtype, field) + roundtrip_series_re([{}, {}, {}], dtype, field) + roundtrip_series_re([{}, None, {}], dtype, field) + + dtype = pl.Struct({"x": pl.Int32}) + roundtrip_series_re([{"x": 1}], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([{"x": 1}] * 3, dtype, field) + + dtype = pl.Struct({"x": pl.Int32, "y": pl.Int32}) + roundtrip_series_re( + [{"x": 1}, {"y": 2}], + dtype, + field, + ) + roundtrip_series_re([None], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Int32) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([[1], [2]], dtype, field) + roundtrip_series_re([[1, 2], [3]], dtype, field) + roundtrip_series_re([[1, 2], [], [3]], dtype, field) + roundtrip_series_re([None, [1, 2], None, [], [3]], dtype, field) + + dtype = pl.List(pl.String) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([[""], [""]], dtype, field) + roundtrip_series_re([["abc"], ["xyzw"]], dtype, field) + roundtrip_series_re([["x", "yx"], ["abc"]], dtype, field) + roundtrip_series_re([["wow", "this is"], [], ["cool"]], dtype, field) + roundtrip_series_re( + [None, ["very", "very"], None, [], ["cool"]], + dtype, + field, + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_array(field: tuple[bool, bool, bool]) -> None: + dtype = pl.Array(pl.Int32, 0) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([None, [], None], dtype, field) + roundtrip_series_re([None], dtype, field) + + dtype = pl.Array(pl.Int32, 2) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[5, 6]], dtype, field) + roundtrip_series_re([[1, 2], [2, 3]], dtype, field) + roundtrip_series_re([[1, 2], [3, 7]], dtype, field) + roundtrip_series_re([[1, 2], [13, 11], [3, 7]], dtype, field) + roundtrip_series_re( + [None, [1, 2], None, [13, 11], [5, 7]], + dtype, + field, ) + + dtype = pl.Array(pl.String, 2) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([["a", "b"]], dtype, field) + roundtrip_series_re([["", ""], ["", "a"]], dtype, field) + roundtrip_series_re([["abc", "def"], ["ghi", "xyzw"]], dtype, field) + roundtrip_series_re([["x", "yx"], ["abc", "xxx"]], dtype, field) + roundtrip_series_re( + [["wow", "this is"], ["soo", "so"], ["veryyy", "cool"]], + dtype, + field, + ) + roundtrip_series_re( + [None, ["very", "very"], None, [None, None], ["verryy", "cool"]], + dtype, + field, + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list_arr(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Array(pl.String, 2)) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([[None]], dtype, field) + roundtrip_series_re([[[None, None]]], dtype, field) + roundtrip_series_re([[["a", "b"]]], dtype, field) + roundtrip_series_re([[["a", "b"], ["xyz", "wowie"]]], dtype, field) + roundtrip_series_re([[["a", "b"]], None, [None, None]], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list_struct_arr(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List( + pl.Struct({"x": pl.Array(pl.String, 2), "y": pl.Array(pl.Int64, 3)}) + ) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([[None]], dtype, field) + roundtrip_series_re([[{"x": None, "y": None}]], dtype, field) + roundtrip_series_re([[{"x": ["a", None], "y": [1, None, 3]}]], dtype, field) + roundtrip_series_re([[{"x": ["a", "xyz"], "y": [1, 7, 3]}]], dtype, field) + roundtrip_series_re([[{"x": ["a", "xyz"], "y": [1, 7, 3]}], []], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list_nulls(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Null) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([[None]], dtype, field) + roundtrip_series_re([[None, None, None]], dtype, field) + roundtrip_series_re([[None], [None, None], [None, None, None]], dtype, field) + + +def test_int_after_null() -> None: roundtrip_re( - pl.Series( - "a", [{"x": 1}, {"y": 2}], pl.Struct({"x": pl.Int32, "y": pl.Int32}) - ).to_frame(), - [field], + pl.DataFrame( + [ + pl.Series("a", [None], pl.Null), + pl.Series("b", [None], pl.Int8), + ] + ), + [(False, True, False), (False, True, False)], + ) + + +def assert_order_dataframe( + lhs: pl.DataFrame, + rhs: pl.DataFrame, + order: list[Literal["lt", "eq", "gt"]], + *, + descending: bool = False, + nulls_last: bool = False, +) -> None: + field = (descending, nulls_last, False) + l_re = lhs._row_encode([field] * lhs.width).cast(pl.Binary) + r_re = rhs._row_encode([field] * rhs.width).cast(pl.Binary) + + l_lt_r_s = "gt" if descending else "lt" + l_gt_r_s = "lt" if descending else "gt" + + assert_series_equal( + l_re < r_re, pl.Series([o == l_lt_r_s for o in order]), check_names=False + ) + assert_series_equal( + l_re == r_re, pl.Series([o == "eq" for o in order]), check_names=False + ) + assert_series_equal( + l_re > r_re, pl.Series([o == l_gt_r_s for o in order]), check_names=False + ) + + +def assert_order_series( + lhs: pl.series.series.ArrayLike, + rhs: pl.series.series.ArrayLike, + dtype: pl._typing.PolarsDataType, + order: list[Literal["lt", "eq", "gt"]], + *, + descending: bool = False, + nulls_last: bool = False, +) -> None: + lhs = pl.Series("lhs", lhs, dtype).to_frame() + rhs = pl.Series("rhs", rhs, dtype).to_frame() + assert_order_dataframe( + lhs, rhs, order, descending=descending, nulls_last=nulls_last + ) + + +def parametric_order_base(df: pl.DataFrame) -> None: + lhs = df.get_columns()[0] + rhs = df.get_columns()[1] + + field = (False, False, False) + lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) + rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) + + assert_series_equal(lhs < rhs, lhs_re < rhs_re, check_names=False) + assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) + assert_series_equal(lhs > rhs, lhs_re > rhs_re, check_names=False) + + field = (True, False, False) + lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) + rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) + + assert_series_equal(lhs > rhs, lhs_re < rhs_re, check_names=False) + assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) + assert_series_equal(lhs < rhs, lhs_re > rhs_re, check_names=False) + + +@given( + df=dataframes([column(dtype=pl.Int32), column(dtype=pl.Int32)], allow_null=False) +) +def test_parametric_int_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.UInt32), column(dtype=pl.UInt32)], allow_null=False) +) +def test_parametric_uint_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.String), column(dtype=pl.String)], allow_null=False) +) +def test_parametric_string_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.Binary), column(dtype=pl.Binary)], allow_null=False) +) +def test_parametric_binary_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +def test_order_int() -> None: + dtype = pl.Int32 + assert_order_series([1, 2, 3], [3, 2, 1], dtype, ["lt", "eq", "gt"]) + assert_order_series([-1, 0, 1], [1, 0, -1], dtype, ["lt", "eq", "gt"]) + assert_order_series([None], [None], dtype, ["eq"]) + assert_order_series([None], [1], dtype, ["lt"]) + assert_order_series([None], [1], dtype, ["gt"], nulls_last=True) + + +def test_order_uint() -> None: + dtype = pl.UInt32 + assert_order_series([1, 2, 3], [3, 2, 1], dtype, ["lt", "eq", "gt"]) + assert_order_series([None], [None], dtype, ["eq"]) + assert_order_series([None], [1], dtype, ["lt"]) + assert_order_series([None], [1], dtype, ["gt"], nulls_last=True) + + +def test_order_list() -> None: + dtype = pl.List(pl.Int32) + assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype, ["lt"]) + assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype, ["lt"]) + assert_order_series([None], [None], dtype, ["eq"]) + assert_order_series([None], [[1, 2, 3]], dtype, ["lt"]) + assert_order_series([None], [[1, 2, 3]], dtype, ["gt"], nulls_last=True) + assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype, ["gt"]) + + +def test_order_array() -> None: + dtype = pl.Array(pl.Int32, 3) + assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype, ["lt"]) + assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype, ["lt"]) + assert_order_series([None], [None], dtype, ["eq"]) + assert_order_series([None], [[1, 2, 3]], dtype, ["lt"]) + assert_order_series([None], [[1, 2, 3]], dtype, ["gt"], nulls_last=True) + assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype, ["gt"]) + + +def test_order_masked_array() -> None: + dtype = pl.Array(pl.Int32, 3) + lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) + rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) + assert_order_series(lhs, rhs, dtype, ["gt"]) + + +def test_order_masked_struct() -> None: + dtype = pl.Array(pl.Int32, 3) + lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) + rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) + assert_order_series( + lhs.to_frame().to_struct(), rhs.to_frame().to_struct(), dtype, ["gt"] )