Skip to content

Commit

Permalink
feat(rust): Implement unpack dtypes function with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin-Patyk committed Mar 3, 2025
1 parent b9a5bc0 commit 6cc5642
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,44 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType>
})
}

fn collect_nested_types(
dtype: &DataType,
result: &mut PlHashSet<DataType>,
include_compound_types: bool,
) {
match dtype {
DataType::List(inner) => {
if include_compound_types {
result.insert(dtype.clone());
}
collect_nested_types(inner, result, include_compound_types);
},
DataType::Array(inner, _) => {
if include_compound_types {
result.insert(dtype.clone());
}
collect_nested_types(inner, result, include_compound_types);
},
DataType::Struct(fields) => {
if include_compound_types {
result.insert(dtype.clone());
}
for field in fields {
collect_nested_types(field.dtype(), result, include_compound_types);
}
},
_ => {
result.insert(dtype.clone());
},
}
}

pub fn unpack_dtypes(dtype: &DataType, include_compound_types: bool) -> PlHashSet<DataType> {
let mut result = PlHashSet::new();
collect_nested_types(dtype, &mut result, include_compound_types);
result
}

#[cfg(feature = "dtype-categorical")]
pub fn create_enum_dtype(categories: Utf8ViewArray) -> DataType {
let rev_map = RevMapping::build_local(categories);
Expand Down Expand Up @@ -1078,3 +1116,38 @@ impl CompatLevel {
self.0
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_unpack_primitive_dtypes() {
let inner_type = DataType::Float64;
let array_type = DataType::Array(Box::new(inner_type), 10);
let list_type = DataType::List(Box::new(array_type.clone()));

let result = unpack_dtypes(&list_type, false);

let mut expected = PlHashSet::new();
expected.insert(DataType::Float64);

assert_eq!(result, expected)
}

#[test]
fn test_unpack_compound_dtypes() {
let inner_type = DataType::Float64;
let array_type = DataType::Array(Box::new(inner_type), 10);
let list_type = DataType::List(Box::new(array_type.clone()));

let result = unpack_dtypes(&list_type, true);

let mut expected = PlHashSet::new();
expected.insert(list_type.clone());
expected.insert(array_type.clone());
expected.insert(DataType::Float64);

assert_eq!(result, expected)
}
}

0 comments on commit 6cc5642

Please sign in to comment.