diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index 060733681de1..59eddd9f655b 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -94,17 +94,25 @@ fn dtype_and_data_to_encoded_item_len( use ArrowDataType as D; match dtype { - D::Binary - | D::LargeBinary - | D::Utf8 - | D::LargeUtf8 - | D::List(_) - | D::LargeList(_) - | D::BinaryView - | D::Utf8View => unsafe { + D::Binary | D::LargeBinary | D::Utf8 | D::LargeUtf8 | D::BinaryView | D::Utf8View => unsafe { crate::variable::encoded_item_len(data, non_empty_sentinel, continuation_token) }, + D::List(list_field) | D::LargeList(list_field) => { + let mut data = data; + let mut item_len = 0; + + let list_continuation_token = field.list_continuation_token(); + + while data[0] == list_continuation_token { + data = &data[1..]; + let len = dtype_and_data_to_encoded_item_len(list_field.dtype(), data, field); + data = &data[len..]; + item_len += 1 + len; + } + 1 + item_len + }, + D::FixedSizeBinary(_) => todo!(), D::FixedSizeList(fsl_field, width) => { let mut data = &data[1..]; @@ -162,130 +170,20 @@ fn rows_for_fixed_size_list<'a>( return; } - use ArrowDataType as D; - match dtype { - D::FixedSizeBinary(_) => todo!(), - D::BinaryView - | D::Utf8View - | D::Binary - | D::LargeBinary - | D::Utf8 - | D::LargeUtf8 - | D::List(_) - | D::LargeList(_) => { - let (non_empty_sentinel, continuation_token) = if field.descending { - ( - !variable::NON_EMPTY_SENTINEL, - !variable::BLOCK_CONTINUATION_TOKEN, - ) - } else { - ( - variable::NON_EMPTY_SENTINEL, - variable::BLOCK_CONTINUATION_TOKEN, - ) - }; - - for row in rows.iter_mut() { - for _ in 0..width { - let length = unsafe { - crate::variable::encoded_item_len( - row, - non_empty_sentinel, - continuation_token, - ) - }; - let v; - (v, *row) = row.split_at(length); - nested_rows.push(v); - } - } - }, - _ => { - // @TODO: This is quite slow since we need to dispatch for possibly every nested type - for row in rows.iter_mut() { - for _ in 0..width { - let length = dtype_and_data_to_encoded_item_len(dtype, row, field); - let v; - (v, *row) = row.split_at(length); - nested_rows.push(v); - } - } - }, - } -} - -fn offsets_from_dtype_and_data( - dtype: &ArrowDataType, - field: &EncodingField, - data: &[u8], - offsets: &mut Vec, -) { - offsets.clear(); - - // Fast path: if the size is fixed, we can just divide. - if let Some(size) = fixed_size(dtype) { - assert!(size == 0 || data.len() % size == 0); - offsets.extend((0..data.len() / size).map(|i| i * size)); - return; - } - - use ArrowDataType as D; - match dtype { - D::FixedSizeBinary(_) => todo!(), - D::BinaryView - | D::Utf8View - | D::Binary - | D::LargeBinary - | D::Utf8 - | D::LargeUtf8 - | D::List(_) - | D::LargeList(_) => { - let mut data = data; - let (non_empty_sentinel, continuation_token) = if field.descending { - ( - !variable::NON_EMPTY_SENTINEL, - !variable::BLOCK_CONTINUATION_TOKEN, - ) - } else { - ( - variable::NON_EMPTY_SENTINEL, - variable::BLOCK_CONTINUATION_TOKEN, - ) - }; - let mut offset = 0; - while !data.is_empty() { - let length = unsafe { - crate::variable::encoded_item_len(data, non_empty_sentinel, continuation_token) - }; - offsets.push(offset); - data = &data[length..]; - offset += length; - } - }, - _ => { - // @TODO: This is quite slow since we need to dispatch for possibly every nested type - let mut data = data; - let mut offset = 0; - while !data.is_empty() { - let length = dtype_and_data_to_encoded_item_len(dtype, data, field); - offsets.push(offset); - data = &data[length..]; - offset += length; - } - }, + // @TODO: This is quite slow since we need to dispatch for possibly every nested type + for row in rows.iter_mut() { + for _ in 0..width { + let length = dtype_and_data_to_encoded_item_len(dtype, row, field); + let v; + (v, *row) = row.split_at(length); + nested_rows.push(v); + } } } unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataType) -> ArrayRef { match dtype { - ArrowDataType::Null => { - // Temporary: remove when list encoding is better. - for row in rows.iter_mut() { - *row = &row[1..]; - } - - NullArray::new(ArrowDataType::Null, rows.len()).to_boxed() - }, + ArrowDataType::Null => NullArray::new(ArrowDataType::Null, rows.len()).to_boxed(), ArrowDataType::Boolean => decode_bool(rows, field).to_boxed(), ArrowDataType::BinaryView | ArrowDataType::LargeBinary => { decode_binview(rows, field).to_boxed() @@ -323,36 +221,51 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataTyp FixedSizeListArray::new(dtype.clone(), rows.len(), values, validity).to_boxed() }, ArrowDataType::List(list_field) | ArrowDataType::LargeList(list_field) => { - let arr = decode_binary(rows, field); + let mut validity = MutableBitmap::new(); - let mut offsets = Vec::with_capacity(rows.len()); // @TODO: we could consider making this into a scratchpad - let mut nested_offsets = Vec::new(); - offsets_from_dtype_and_data( - list_field.dtype(), - field, - arr.values().as_ref(), - &mut nested_offsets, - ); - // @TODO: This might cause realloc, fix. - nested_offsets.push(arr.values().len()); - let mut nested_rows = nested_offsets - .windows(2) - .map(|vs| &arr.values()[vs[0]..vs[1]]) - .collect::>(); - - let mut i = 0; - for offset in arr.offsets().iter() { - while nested_offsets[i] != offset.as_usize() { - i += 1; + let num_rows = rows.len(); + let mut nested_rows = Vec::new(); + let mut offsets = Vec::with_capacity(rows.len() + 1); + offsets.push(0); + + let list_null_sentinel = field.list_null_sentinel(); + let list_continuation_token = field.list_continuation_token(); + let list_termination_token = field.list_termination_token(); + + // @TODO: make a specialized loop for fixed size list_field.dtype() + for (i, row) in rows.iter_mut().enumerate() { + while row[0] == list_continuation_token { + *row = &row[1..]; + let len = dtype_and_data_to_encoded_item_len(list_field.dtype(), row, field); + nested_rows.push(&row[..len]); + *row = &row[len..]; } - offsets.push(i as i64); + offsets.push(nested_rows.len() as i64); + + // @TODO: Might be better to make this a 2-loop system. + if row[0] == list_null_sentinel { + *row = &row[1..]; + validity.reserve(num_rows); + validity.extend_constant(i - validity.len(), true); + validity.push(false); + continue; + } + + assert_eq!(row[0], list_termination_token); + *row = &row[1..]; } + + let validity = if validity.is_empty() { + None + } else { + validity.extend_constant(num_rows - validity.len(), true); + Some(validity.freeze()) + }; assert_eq!(offsets.len(), rows.len() + 1); let values = decode(&mut nested_rows, field, list_field.dtype()); - let (_, _, _, validity) = arr.into_inner(); ListArray::::new( dtype.clone(), diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index a288dc67659d..61a4819b2c8b 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -266,18 +266,14 @@ fn list_num_column_bytes( let mut list_row_widths = RowWidths::new(values.len()); let encoder = get_encoder(values.as_ref(), field, &mut list_row_widths); - // @TODO: make specialized implementation for list_row_widths is RowWidths::Constant - let mut offsets = Vec::with_capacity(list_row_widths.num_rows() + 1); - list_row_widths.extend_with_offsets(&mut offsets); - offsets.push(encoder.widths.sum()); - let widths = match array.validity() { None => row_widths.append_iter(array.offsets().offset_and_length_iter().map( |(offset, length)| { - crate::variable::encoded_len_from_len( - Some(offsets[offset + length] - offsets[offset]), - field, - ) + let mut sum = 0; + for i in offset..offset + length { + sum += list_row_widths.get(i); + } + 1 + length + sum }, )), Some(validity) => row_widths.append_iter( @@ -286,10 +282,15 @@ fn list_num_column_bytes( .offset_and_length_iter() .zip(validity.iter()) .map(|((offset, length), is_valid)| { - crate::variable::encoded_len_from_len( - is_valid.then_some(offsets[offset + length] - offsets[offset]), - field, - ) + if !is_valid { + return 1; + } + + let mut sum = 0; + for i in offset..offset + length { + sum += list_row_widths.get(i); + } + 1 + length + sum }), ), }; @@ -543,14 +544,7 @@ unsafe fn encode_flat_array( ) { use ArrowDataType as D; match array.dtype() { - D::Null => { - // @NOTE: This is an artifact of the list encoding, this can be removed when we have a - // better list encoding. - for offset in offsets.iter_mut() { - buffer[*offset] = MaybeUninit::new(0); - *offset += 1; - } - }, + D::Null => {}, D::Boolean => { let array = array.as_any().downcast_ref::().unwrap(); crate::fixed::encode_bool_iter(buffer, array.iter(), field, offsets); @@ -667,68 +661,69 @@ unsafe fn encode_array( scratches.clear(); - let total_num_bytes = nested_encoder.widths.sum(); - scratches.nested_buffer.reserve(total_num_bytes); scratches .nested_offsets - .reserve(1 + nested_encoder.widths.num_rows()); - - let nested_buffer = - &mut scratches.nested_buffer.spare_capacity_mut()[..total_num_bytes]; + .reserve(nested_encoder.widths.num_rows()); let nested_offsets = &mut scratches.nested_offsets; - nested_offsets.push(0); - nested_encoder.widths.extend_with_offsets(nested_offsets); + + let list_null_sentinel = field.list_null_sentinel(); + let list_continuation_token = field.list_continuation_token(); + let list_termination_token = field.list_termination_token(); + + match array.validity() { + None => { + for (i, (offset, length)) in + array.offsets().offset_and_length_iter().enumerate() + { + for j in offset..offset + length { + buffer[offsets[i]] = MaybeUninit::new(list_continuation_token); + offsets[i] += 1; + + nested_offsets.push(offsets[i]); + offsets[i] += nested_encoder.widths.get(j); + } + buffer[offsets[i]] = MaybeUninit::new(list_termination_token); + offsets[i] += 1; + } + }, + Some(validity) => { + for (i, ((offset, length), is_valid)) in array + .offsets() + .offset_and_length_iter() + .zip(validity.iter()) + .enumerate() + { + if !is_valid { + buffer[offsets[i]] = MaybeUninit::new(list_null_sentinel); + offsets[i] += 1; + continue; + } + + for j in offset..offset + length { + buffer[offsets[i]] = MaybeUninit::new(list_continuation_token); + offsets[i] += 1; + + nested_offsets.push(offsets[i]); + offsets[i] += nested_encoder.widths.get(j); + } + buffer[offsets[i]] = MaybeUninit::new(list_termination_token); + offsets[i] += 1; + } + }, + } // Lists have the row encoding of the elements again encoded by the variable encoding. // This is not ideal ([["a", "b"]] produces 100 bytes), but this is sort of how // arrow-row works and is good enough for now. unsafe { encode_array( - nested_buffer, + buffer, nested_encoder, field, - &mut nested_offsets[1..], + nested_offsets, &mut EncodeScratches::default(), ) }; - let nested_buffer: &[u8] = unsafe { std::mem::transmute(nested_buffer) }; - - // @TODO: Differentiate between empty values and empty list. - match encoder.array.validity() { - None => { - crate::variable::encode_iter( - buffer, - array - .offsets() - .offset_and_length_iter() - .map(|(offset, length)| { - Some( - &nested_buffer - [nested_offsets[offset]..nested_offsets[offset + length]], - ) - }), - field, - offsets, - ); - }, - Some(validity) => { - crate::variable::encode_iter( - buffer, - array - .offsets() - .offset_and_length_iter() - .zip(validity.iter()) - .map(|((offset, length), is_valid)| { - is_valid.then(|| { - &nested_buffer - [nested_offsets[offset]..nested_offsets[offset + length]] - }) - }), - field, - offsets, - ); - }, - } }, EncoderState::Dictionary(_) => todo!(), EncoderState::FixedSizeList(array, width) => { @@ -824,7 +819,7 @@ pub fn fixed_size(dtype: &ArrowDataType) -> Option { } 1 + sum }, - Null => 1, + Null => 0, _ => return None, }) } @@ -832,8 +827,6 @@ pub fn fixed_size(dtype: &ArrowDataType) -> Option { #[cfg(test)] mod test { use arrow::array::Int32Array; - use arrow::legacy::prelude::LargeListArray; - use arrow::offset::Offsets; use super::*; use crate::decode::decode_rows_from_binary; @@ -934,28 +927,4 @@ mod test { assert_eq!(decoded, &a); } } - - #[test] - fn test_list_encode() { - let values = Utf8ViewArray::from_slice_values([ - "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", - ]); - let dtype = LargeListArray::default_datatype(values.dtype().clone()); - let array = LargeListArray::new( - dtype, - Offsets::::try_from(vec![0i64, 1, 4, 7, 7, 9, 10]) - .unwrap() - .into(), - values.boxed(), - None, - ); - let fields = &[EncodingField::new_sorted(true, false)]; - - let out = convert_columns(array.len(), &[array.boxed()], fields); - let out = out.into_array(); - assert_eq!( - out.values().iter().map(|v| *v as usize).sum::(), - 42774 - ); - } } diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index 6eb8deceebad..9ee5a74143ee 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -18,6 +18,8 @@ pub struct EncodingField { pub no_order: bool, } +const LIST_CONTINUATION_TOKEN: u8 = 0xFE; + impl EncodingField { pub fn new_sorted(descending: bool, nulls_last: bool) -> Self { EncodingField { @@ -49,6 +51,22 @@ impl EncodingField { BOOLEAN_FALSE_SENTINEL } } + + pub fn list_null_sentinel(self) -> u8 { + crate::fixed::get_null_sentinel(&self) + } + + pub fn list_continuation_token(self) -> u8 { + if self.descending { + !LIST_CONTINUATION_TOKEN + } else { + LIST_CONTINUATION_TOKEN + } + } + + pub fn list_termination_token(self) -> u8 { + !self.list_continuation_token() + } } #[derive(Default, Clone)] diff --git a/py-polars/tests/unit/test_row_encoding.py b/py-polars/tests/unit/test_row_encoding.py index 2dc34e136bbd..4971ccb59ff8 100644 --- a/py-polars/tests/unit/test_row_encoding.py +++ b/py-polars/tests/unit/test_row_encoding.py @@ -427,6 +427,13 @@ def test_order_list() -> None: assert_order_series([None], [[1, 2, 3]], dtype, ["gt"], nulls_last=True) assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype, ["gt"]) + assert_order_series([[]], [[None]], dtype, ["lt"]) + assert_order_series([[]], [[None]], dtype, ["lt"], descending=True) + assert_order_series([[]], [[1]], dtype, ["lt"]) + assert_order_series([[]], [[1]], dtype, ["lt"], descending=True) + assert_order_series([[1]], [[1, 2]], dtype, ["lt"]) + assert_order_series([[1]], [[1, 2]], dtype, ["lt"], descending=True) + def test_order_array() -> None: dtype = pl.Array(pl.Int32, 3)