Skip to content

Commit de6d6f5

Browse files
committed
feat: add downcast_integer_array macro helper
1 parent c1cb61d commit de6d6f5

File tree

2 files changed

+95
-13
lines changed

2 files changed

+95
-13
lines changed

arrow-array/src/cast.rs

+88-2
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,50 @@ macro_rules! downcast_integer {
103103
};
104104
}
105105

106+
/// Downcast an [`Array`] to an integer [`PrimitiveArray`] based on its [`DataType`]
107+
/// accepts a number of subsequent patterns to match the data type
108+
///
109+
/// ```
110+
/// # use arrow_array::{Array, downcast_integer_array, cast::as_string_array};
111+
/// # use arrow_schema::DataType;
112+
///
113+
/// fn print_integer(array: &dyn Array) {
114+
/// downcast_integer_array!(
115+
/// array => {
116+
/// for v in array {
117+
/// println!("{:?}", v);
118+
/// }
119+
/// }
120+
/// DataType::Utf8 => {
121+
/// for v in as_string_array(array) {
122+
/// println!("{:?}", v);
123+
/// }
124+
/// }
125+
/// t => println!("Unsupported datatype {}", t)
126+
/// )
127+
/// }
128+
/// ```
129+
///
130+
/// [`DataType`]: arrow_schema::DataType
131+
#[macro_export]
132+
macro_rules! downcast_integer_array {
133+
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
134+
$crate::downcast_integer_array!($values => {$e} $($p => $fallback)*)
135+
};
136+
(($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
137+
$crate::downcast_integer_array!($($values),+ => {$e} $($p => $fallback)*)
138+
};
139+
($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
140+
$crate::downcast_integer_array!(($($values),+) => $e $($p => $fallback)*)
141+
};
142+
(($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
143+
$crate::downcast_integer!{
144+
$($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e),
145+
$($p => $fallback,)*
146+
}
147+
};
148+
}
149+
106150
/// Given one or more expressions evaluating to an integer [`DataType`] invokes the provided macro
107151
/// `m` with the corresponding integer [`RunEndIndexType`], followed by any additional arguments
108152
///
@@ -999,11 +1043,11 @@ impl AsArray for ArrayRef {
9991043

10001044
#[cfg(test)]
10011045
mod tests {
1046+
use super::*;
10021047
use arrow_buffer::i256;
1048+
use arrow_schema::DataType;
10031049
use std::sync::Arc;
10041050

1005-
use super::*;
1006-
10071051
#[test]
10081052
fn test_as_primitive_array_ref() {
10091053
let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect();
@@ -1035,4 +1079,46 @@ mod tests {
10351079
let a = Decimal256Array::from_iter_values([1, 2, 4, 5].into_iter().map(i256::from_i128));
10361080
assert!(!as_primitive_array::<Decimal256Type>(&a).is_empty());
10371081
}
1082+
1083+
#[test]
1084+
fn downcast_integer_array_should_match_only_integers() {
1085+
let i32_array: ArrayRef = Arc::new(Int32Array::new_null(1));
1086+
let i32_array_ref = &i32_array;
1087+
downcast_integer_array!(
1088+
i32_array_ref => {
1089+
assert_eq!(i32_array_ref.null_count(), 1);
1090+
},
1091+
_ => panic!("unexpected data type")
1092+
);
1093+
}
1094+
1095+
#[test]
1096+
fn downcast_integer_array_should_not_match_primitive_that_are_not_integers() {
1097+
let array: ArrayRef = Arc::new(Float32Array::new_null(1));
1098+
let array_ref = &array;
1099+
downcast_integer_array!(
1100+
array_ref => {
1101+
panic!("unexpected data type {}", array_ref.data_type())
1102+
},
1103+
DataType::Float32 => {
1104+
assert_eq!(array_ref.null_count(), 1);
1105+
},
1106+
_ => panic!("unexpected data type")
1107+
);
1108+
}
1109+
1110+
#[test]
1111+
fn downcast_integer_array_should_not_match_non_primitive() {
1112+
let array: ArrayRef = Arc::new(StringArray::new_null(1));
1113+
let array_ref = &array;
1114+
downcast_integer_array!(
1115+
array_ref => {
1116+
panic!("unexpected data type {}", array_ref.data_type())
1117+
},
1118+
DataType::Utf8 => {
1119+
assert_eq!(array_ref.null_count(), 1);
1120+
},
1121+
_ => panic!("unexpected data type")
1122+
);
1123+
}
10381124
}

arrow-select/src/take.rs

+7-11
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,16 @@ pub fn take(
8282
options: Option<TakeOptions>,
8383
) -> Result<ArrayRef, ArrowError> {
8484
let options = options.unwrap_or_default();
85-
macro_rules! helper {
86-
($t:ty, $values:expr, $indices:expr, $options:expr) => {{
87-
let indices = indices.as_primitive::<$t>();
88-
if $options.check_bounds {
89-
check_bounds($values.len(), indices)?;
85+
downcast_integer_array!(
86+
indices => {
87+
if options.check_bounds {
88+
check_bounds(values.len(), indices)?;
9089
}
9190
let indices = indices.to_indices();
92-
take_impl($values, &indices)
93-
}};
94-
}
95-
downcast_integer! {
96-
indices.data_type() => (helper, values, indices, options),
91+
take_impl(values, &indices)
92+
},
9793
d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
98-
}
94+
)
9995
}
10096

10197
/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new

0 commit comments

Comments
 (0)