Skip to content

Commit

Permalink
feat: Add AhoCorasick backed 'find_many' (#19952)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Nov 24, 2024
1 parent ce32455 commit 8d9ed64
Show file tree
Hide file tree
Showing 14 changed files with 334 additions and 14 deletions.
4 changes: 2 additions & 2 deletions crates/polars-arrow/src/array/primitive/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ impl<T: NativeType> MutablePrimitiveArray<T> {
where
I: TrustedLen<Item = T>,
{
unsafe { self.extend_trusted_len_values_unchecked(iterator) }
unsafe { self.extend_values(iterator) }
}

/// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len.
Expand All @@ -229,7 +229,7 @@ impl<T: NativeType> MutablePrimitiveArray<T> {
/// # Safety
/// The iterator must be trusted len.
#[inline]
pub unsafe fn extend_trusted_len_values_unchecked<I>(&mut self, iterator: I)
pub fn extend_values<I>(&mut self, iterator: I)
where
I: Iterator<Item = T>,
{
Expand Down
18 changes: 16 additions & 2 deletions crates/polars-core/src/chunked_array/builder/list/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,29 @@ where
}
/// Appends from an iterator over values
#[inline]
pub fn append_iter_values<I: Iterator<Item = T::Native> + TrustedLen>(&mut self, iter: I) {
pub fn append_values_iter_trusted_len<I: Iterator<Item = T::Native> + TrustedLen>(
&mut self,
iter: I,
) {
let values = self.builder.mut_values();

if iter.size_hint().0 == 0 {
self.fast_explode = false;
}
// SAFETY:
// trusted len, trust the type system
unsafe { values.extend_trusted_len_values_unchecked(iter) };
values.extend_values(iter);
self.builder.try_push_valid().unwrap();
}

#[inline]
pub fn append_values_iter<I: Iterator<Item = T::Native>>(&mut self, iter: I) {
let values = self.builder.mut_values();

if iter.size_hint().0 == 0 {
self.fast_explode = false;
}
values.extend_values(iter);
self.builder.try_push_valid().unwrap();
}

Expand Down
73 changes: 70 additions & 3 deletions crates/polars-ops/src/chunked_array/strings/find_many.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ pub fn replace_all(
}))
}

fn push(val: &str, builder: &mut ListStringChunkedBuilder, ac: &AhoCorasick, overlapping: bool) {
fn push_str(
val: &str,
builder: &mut ListStringChunkedBuilder,
ac: &AhoCorasick,
overlapping: bool,
) {
if overlapping {
let iter = ac.find_overlapping_iter(val);
let iter = iter.map(|m| &val[m.start()..m.end()]);
Expand Down Expand Up @@ -92,7 +97,7 @@ pub fn extract_many(
(Some(val), Some(pat)) => {
let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
let ac = build_ac_arr(pat, ascii_case_insensitive)?;
push(val, &mut builder, &ac, overlapping);
push_str(val, &mut builder, &ac, overlapping);
},
}
}
Expand All @@ -108,7 +113,69 @@ pub fn extract_many(
for arr in ca.downcast_iter() {
for opt_val in arr.into_iter() {
if let Some(val) = opt_val {
push(val, &mut builder, &ac, overlapping);
push_str(val, &mut builder, &ac, overlapping);
} else {
builder.append_null();
}
}
}
Ok(builder.finish())
},
_ => {
polars_bail!(InvalidOperation: "expected 'String/List<String>' datatype for 'patterns' argument")
},
}
}

type B = ListPrimitiveChunkedBuilder<UInt32Type>;
fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) {
if overlapping {
let iter = ac.find_overlapping_iter(val);
let iter = iter.map(|m| m.start() as u32);
builder.append_values_iter(iter);
} else {
let iter = ac.find_iter(val);
let iter = iter.map(|m| m.start() as u32);
builder.append_values_iter(iter);
}
}

pub fn find_many(
ca: &StringChunked,
patterns: &Series,
ascii_case_insensitive: bool,
overlapping: bool,
) -> PolarsResult<ListChunked> {
type B = ListPrimitiveChunkedBuilder<UInt32Type>;
match patterns.dtype() {
DataType::List(inner) if inner.is_string() => {
let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);
let patterns = patterns.list().unwrap();
let (ca, patterns) = align_chunks_binary(ca, patterns);

for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {
for z in arr.into_iter().zip(pat_arr.into_iter()) {
match z {
(None, _) | (_, None) => builder.append_null(),
(Some(val), Some(pat)) => {
let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
let ac = build_ac_arr(pat, ascii_case_insensitive)?;
push_idx(val, &mut builder, &ac, overlapping);
},
}
}
}
Ok(builder.finish())
},
DataType::String => {
let patterns = patterns.str().unwrap();
let ac = build_ac(patterns, ascii_case_insensitive)?;
let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);

for arr in ca.downcast_iter() {
for opt_val in arr.into_iter() {
if let Some(val) = opt_val {
push_idx(val, &mut builder, &ac, overlapping);
} else {
builder.append_null();
}
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/range/int_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ pub(super) fn int_ranges(s: &[Column]) -> PolarsResult<Column> {
let range_impl =
|start, end, step: i64, builder: &mut ListPrimitiveChunkedBuilder<Int64Type>| {
match step {
1 => builder.append_iter_values(start..end),
2.. => builder.append_iter_values((start..end).step_by(step as usize)),
_ => builder.append_iter_values(
1 => builder.append_values_iter_trusted_len(start..end),
2.. => builder.append_values_iter_trusted_len((start..end).step_by(step as usize)),
_ => builder.append_values_iter_trusted_len(
(end..start)
.step_by(step.unsigned_abs() as usize)
.map(|x| start - (x - end)),
Expand Down
34 changes: 34 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ pub enum StringFunction {
ascii_case_insensitive: bool,
overlapping: bool,
},
#[cfg(feature = "find_many")]
FindMany {
ascii_case_insensitive: bool,
overlapping: bool,
},
#[cfg(feature = "regex")]
EscapeRegex,
}
Expand Down Expand Up @@ -199,6 +204,8 @@ impl StringFunction {
ReplaceMany { .. } => mapper.with_same_dtype(),
#[cfg(feature = "find_many")]
ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
#[cfg(feature = "find_many")]
FindMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))),
#[cfg(feature = "regex")]
EscapeRegex => mapper.with_same_dtype(),
}
Expand Down Expand Up @@ -289,6 +296,8 @@ impl Display for StringFunction {
ReplaceMany { .. } => "replace_many",
#[cfg(feature = "find_many")]
ExtractMany { .. } => "extract_many",
#[cfg(feature = "find_many")]
FindMany { .. } => "extract_many",
#[cfg(feature = "regex")]
EscapeRegex => "escape_regex",
};
Expand Down Expand Up @@ -406,6 +415,13 @@ impl From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
} => {
map_as_slice!(extract_many, ascii_case_insensitive, overlapping)
},
#[cfg(feature = "find_many")]
FindMany {
ascii_case_insensitive,
overlapping,
} => {
map_as_slice!(find_many, ascii_case_insensitive, overlapping)
},
#[cfg(feature = "regex")]
EscapeRegex => map!(escape_regex),
}
Expand Down Expand Up @@ -452,6 +468,24 @@ fn extract_many(
.map(|out| out.into_column())
}

#[cfg(feature = "find_many")]
fn find_many(
s: &[Column],
ascii_case_insensitive: bool,
overlapping: bool,
) -> PolarsResult<Column> {
let ca = s[0].str()?;
let patterns = &s[1];

polars_ops::chunked_array::strings::find_many(
ca,
patterns.as_materialized_series(),
ascii_case_insensitive,
overlapping,
)
.map(|out| out.into_column())
}

fn uppercase(s: &Column) -> PolarsResult<Column> {
let ca = s.str()?;
Ok(ca.to_uppercase().into_column())
Expand Down
25 changes: 25 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,31 @@ impl StringNameSpace {
)
}

/// Uses aho-corasick to find many patterns.
/// # Arguments
/// - `patterns`: an expression that evaluates to a String column
/// - `ascii_case_insensitive`: Enable ASCII-aware case-insensitive matching.
/// When this option is enabled, searching will be performed without respect to case for
/// ASCII letters (a-z and A-Z) only.
/// - `overlapping`: Whether matches may overlap.
#[cfg(feature = "find_many")]
pub fn find_many(
self,
patterns: Expr,
ascii_case_insensitive: bool,
overlapping: bool,
) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::FindMany {
ascii_case_insensitive,
overlapping,
}),
&[patterns],
false,
None,
)
}

/// Check if a string value ends with the `sub` string.
pub fn ends_with(self, sub: Expr) -> Expr {
self.0.map_many_private(
Expand Down
14 changes: 14 additions & 0 deletions crates/polars-python/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,20 @@ impl PyExpr {
.into()
}

#[cfg(feature = "find_many")]
fn str_find_many(
&self,
patterns: PyExpr,
ascii_case_insensitive: bool,
overlapping: bool,
) -> Self {
self.inner
.clone()
.str()
.find_many(patterns.inner, ascii_case_insensitive, overlapping)
.into()
}

#[cfg(feature = "regex")]
fn str_escape_regex(&self) -> Self {
self.inner.clone().str().escape_regex().into()
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-python/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl NodeTraverser {
// Increment major on breaking changes to the IR (e.g. renaming
// fields, reordering tuples), minor on backwards compatible
// changes (e.g. exposing a new expression node).
const VERSION: Version = (3, 1);
const VERSION: Version = (3, 2);

pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
Self {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,10 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
StringFunction::ExtractMany { .. } => {
return Err(PyNotImplementedError::new_err("extract_many"))
},
#[cfg(feature = "find_many")]
StringFunction::FindMany { .. } => {
return Err(PyNotImplementedError::new_err("find_many"))
},
#[cfg(feature = "regex")]
StringFunction::EscapeRegex => {
(PyStringFunction::EscapeRegex.into_py(py),).to_object(py)
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.extract_groups
Expr.str.extract_many
Expr.str.find
Expr.str.find_many
Expr.str.head
Expr.str.join
Expr.str.json_decode
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.extract_groups
Series.str.extract_many
Series.str.find
Series.str.find_many
Series.str.head
Series.str.join
Series.str.json_decode
Expand Down
Loading

0 comments on commit 8d9ed64

Please sign in to comment.