From 8d9ed64f0c7f1c7f6b52749c29b0c0b8a79b1fb1 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 24 Nov 2024 11:44:22 +0100 Subject: [PATCH] feat: Add AhoCorasick backed 'find_many' (#19952) --- .../src/array/primitive/mutable.rs | 4 +- .../chunked_array/builder/list/primitive.rs | 18 ++++- .../src/chunked_array/strings/find_many.rs | 73 ++++++++++++++++- .../src/dsl/function_expr/range/int_range.rs | 6 +- .../src/dsl/function_expr/strings.rs | 34 ++++++++ crates/polars-plan/src/dsl/string.rs | 25 ++++++ crates/polars-python/src/expr/string.rs | 14 ++++ crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 4 + .../source/reference/expressions/string.rst | 1 + .../docs/source/reference/series/string.rst | 1 + py-polars/polars/expr/string.py | 81 ++++++++++++++++++- py-polars/polars/series/string.py | 77 +++++++++++++++++- .../namespaces/string/test_string.py | 8 ++ 14 files changed, 334 insertions(+), 14 deletions(-) diff --git a/crates/polars-arrow/src/array/primitive/mutable.rs b/crates/polars-arrow/src/array/primitive/mutable.rs index 287fe52f0534..4c6a5802c691 100644 --- a/crates/polars-arrow/src/array/primitive/mutable.rs +++ b/crates/polars-arrow/src/array/primitive/mutable.rs @@ -220,7 +220,7 @@ impl MutablePrimitiveArray { where I: TrustedLen, { - unsafe { self.extend_trusted_len_values_unchecked(iterator) } + self.extend_values(iterator) } /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. @@ -229,7 +229,7 @@ impl MutablePrimitiveArray { /// # Safety /// The iterator must be trusted len. 
#[inline] - pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + pub fn extend_values(&mut self, iterator: I) where I: Iterator, { diff --git a/crates/polars-core/src/chunked_array/builder/list/primitive.rs b/crates/polars-core/src/chunked_array/builder/list/primitive.rs index a6a700635d6a..d9e74536bf13 100644 --- a/crates/polars-core/src/chunked_array/builder/list/primitive.rs +++ b/crates/polars-core/src/chunked_array/builder/list/primitive.rs @@ -76,7 +76,10 @@ where } /// Appends from an iterator over values #[inline] - pub fn append_iter_values + TrustedLen>(&mut self, iter: I) { + pub fn append_values_iter_trusted_len + TrustedLen>( + &mut self, + iter: I, + ) { let values = self.builder.mut_values(); if iter.size_hint().0 == 0 { @@ -84,7 +87,18 @@ where } // SAFETY: // trusted len, trust the type system - unsafe { values.extend_trusted_len_values_unchecked(iter) }; + values.extend_values(iter); + self.builder.try_push_valid().unwrap(); + } + + #[inline] + pub fn append_values_iter>(&mut self, iter: I) { + let values = self.builder.mut_values(); + + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + values.extend_values(iter); self.builder.try_push_valid().unwrap(); } diff --git a/crates/polars-ops/src/chunked_array/strings/find_many.rs b/crates/polars-ops/src/chunked_array/strings/find_many.rs index d56d8b3e014d..478f1b05b777 100644 --- a/crates/polars-ops/src/chunked_array/strings/find_many.rs +++ b/crates/polars-ops/src/chunked_array/strings/find_many.rs @@ -60,7 +60,12 @@ pub fn replace_all( })) } -fn push(val: &str, builder: &mut ListStringChunkedBuilder, ac: &AhoCorasick, overlapping: bool) { +fn push_str( + val: &str, + builder: &mut ListStringChunkedBuilder, + ac: &AhoCorasick, + overlapping: bool, +) { if overlapping { let iter = ac.find_overlapping_iter(val); let iter = iter.map(|m| &val[m.start()..m.end()]); @@ -92,7 +97,7 @@ pub fn extract_many( (Some(val), Some(pat)) => { let pat = 
pat.as_any().downcast_ref::().unwrap(); let ac = build_ac_arr(pat, ascii_case_insensitive)?; - push(val, &mut builder, &ac, overlapping); + push_str(val, &mut builder, &ac, overlapping); }, } } @@ -108,7 +113,69 @@ pub fn extract_many( for arr in ca.downcast_iter() { for opt_val in arr.into_iter() { if let Some(val) = opt_val { - push(val, &mut builder, &ac, overlapping); + push_str(val, &mut builder, &ac, overlapping); + } else { + builder.append_null(); + } + } + } + Ok(builder.finish()) + }, + _ => { + polars_bail!(InvalidOperation: "expected 'String/List' datatype for 'patterns' argument") + }, + } +} + +type B = ListPrimitiveChunkedBuilder; +fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) { + if overlapping { + let iter = ac.find_overlapping_iter(val); + let iter = iter.map(|m| m.start() as u32); + builder.append_values_iter(iter); + } else { + let iter = ac.find_iter(val); + let iter = iter.map(|m| m.start() as u32); + builder.append_values_iter(iter); + } +} + +pub fn find_many( + ca: &StringChunked, + patterns: &Series, + ascii_case_insensitive: bool, + overlapping: bool, +) -> PolarsResult { + type B = ListPrimitiveChunkedBuilder; + match patterns.dtype() { + DataType::List(inner) if inner.is_string() => { + let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32); + let patterns = patterns.list().unwrap(); + let (ca, patterns) = align_chunks_binary(ca, patterns); + + for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) { + for z in arr.into_iter().zip(pat_arr.into_iter()) { + match z { + (None, _) | (_, None) => builder.append_null(), + (Some(val), Some(pat)) => { + let pat = pat.as_any().downcast_ref::().unwrap(); + let ac = build_ac_arr(pat, ascii_case_insensitive)?; + push_idx(val, &mut builder, &ac, overlapping); + }, + } + } + } + Ok(builder.finish()) + }, + DataType::String => { + let patterns = patterns.str().unwrap(); + let ac = build_ac(patterns, ascii_case_insensitive)?; 
+ let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32); + + for arr in ca.downcast_iter() { + for opt_val in arr.into_iter() { + if let Some(val) = opt_val { + push_idx(val, &mut builder, &ac, overlapping); } else { builder.append_null(); } diff --git a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs index f9b524cfe481..12bd9082c7a6 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -67,9 +67,9 @@ pub(super) fn int_ranges(s: &[Column]) -> PolarsResult { let range_impl = |start, end, step: i64, builder: &mut ListPrimitiveChunkedBuilder| { match step { - 1 => builder.append_iter_values(start..end), - 2.. => builder.append_iter_values((start..end).step_by(step as usize)), - _ => builder.append_iter_values( + 1 => builder.append_values_iter_trusted_len(start..end), + 2.. => builder.append_values_iter_trusted_len((start..end).step_by(step as usize)), + _ => builder.append_values_iter_trusted_len( (end..start) .step_by(step.unsigned_abs() as usize) .map(|x| start - (x - end)), diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 48fd35725fde..c79a5a9f2d37 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -130,6 +130,11 @@ pub enum StringFunction { ascii_case_insensitive: bool, overlapping: bool, }, + #[cfg(feature = "find_many")] + FindMany { + ascii_case_insensitive: bool, + overlapping: bool, + }, #[cfg(feature = "regex")] EscapeRegex, } @@ -199,6 +204,8 @@ impl StringFunction { ReplaceMany { .. } => mapper.with_same_dtype(), #[cfg(feature = "find_many")] ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))), + #[cfg(feature = "find_many")] + FindMany { .. 
} => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))), #[cfg(feature = "regex")] EscapeRegex => mapper.with_same_dtype(), } @@ -289,6 +296,8 @@ impl Display for StringFunction { ReplaceMany { .. } => "replace_many", #[cfg(feature = "find_many")] ExtractMany { .. } => "extract_many", + #[cfg(feature = "find_many")] + FindMany { .. } => "find_many", #[cfg(feature = "regex")] EscapeRegex => "escape_regex", }; @@ -406,6 +415,13 @@ impl From for SpecialEq> { } => { map_as_slice!(extract_many, ascii_case_insensitive, overlapping) }, + #[cfg(feature = "find_many")] + FindMany { + ascii_case_insensitive, + overlapping, + } => { + map_as_slice!(find_many, ascii_case_insensitive, overlapping) + }, #[cfg(feature = "regex")] EscapeRegex => map!(escape_regex), } @@ -452,6 +468,24 @@ fn extract_many( .map(|out| out.into_column()) } +#[cfg(feature = "find_many")] +fn find_many( + s: &[Column], + ascii_case_insensitive: bool, + overlapping: bool, +) -> PolarsResult { + let ca = s[0].str()?; + let patterns = &s[1]; + + polars_ops::chunked_array::strings::find_many( + ca, + patterns.as_materialized_series(), + ascii_case_insensitive, + overlapping, + ) + .map(|out| out.into_column()) +} + fn uppercase(s: &Column) -> PolarsResult { let ca = s.str()?; Ok(ca.to_uppercase().into_column()) diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 2514d1a5f6a4..c9ae23bdbb04 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -100,6 +100,31 @@ impl StringNameSpace { ) } + /// Uses aho-corasick to find many patterns. + /// # Arguments + /// - `patterns`: an expression that evaluates to a String column + /// - `ascii_case_insensitive`: Enable ASCII-aware case-insensitive matching. + /// When this option is enabled, searching will be performed without respect to case for + /// ASCII letters (a-z and A-Z) only. + /// - `overlapping`: Whether matches may overlap. 
+ #[cfg(feature = "find_many")] + pub fn find_many( + self, + patterns: Expr, + ascii_case_insensitive: bool, + overlapping: bool, + ) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::FindMany { + ascii_case_insensitive, + overlapping, + }), + &[patterns], + false, + None, + ) + } + /// Check if a string value ends with the `sub` string. pub fn ends_with(self, sub: Expr) -> Expr { self.0.map_many_private( diff --git a/crates/polars-python/src/expr/string.rs b/crates/polars-python/src/expr/string.rs index 87521a2b7aa1..8876ccea1d84 100644 --- a/crates/polars-python/src/expr/string.rs +++ b/crates/polars-python/src/expr/string.rs @@ -340,6 +340,20 @@ impl PyExpr { .into() } + #[cfg(feature = "find_many")] + fn str_find_many( + &self, + patterns: PyExpr, + ascii_case_insensitive: bool, + overlapping: bool, + ) -> Self { + self.inner + .clone() + .str() + .find_many(patterns.inner, ascii_case_insensitive, overlapping) + .into() + } + #[cfg(feature = "regex")] fn str_escape_regex(&self) -> Self { self.inner.clone().str().escape_regex().into() diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index bc4cebb360a2..b698f68a47c7 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). 
- const VERSION: Version = (3, 1); + const VERSION: Version = (3, 2); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index aa611f08c551..c5cda028f74b 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -973,6 +973,10 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::ExtractMany { .. } => { return Err(PyNotImplementedError::new_err("extract_many")) }, + #[cfg(feature = "find_many")] + StringFunction::FindMany { .. } => { + return Err(PyNotImplementedError::new_err("find_many")) + }, #[cfg(feature = "regex")] StringFunction::EscapeRegex => { (PyStringFunction::EscapeRegex.into_py(py),).to_object(py) diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index 7c1358b480f6..833917061657 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -23,6 +23,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.extract_groups Expr.str.extract_many Expr.str.find + Expr.str.find_many Expr.str.head Expr.str.join Expr.str.json_decode diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst index 85dcf4b1b2d6..0d2ad76959e1 100644 --- a/py-polars/docs/source/reference/series/string.rst +++ b/py-polars/docs/source/reference/series/string.rst @@ -23,6 +23,7 @@ The following methods are available under the `Series.str` attribute. 
Series.str.extract_groups Series.str.extract_many Series.str.find + Series.str.find_many Series.str.head Series.str.join Series.str.json_decode diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 37481a925e68..aa34dec49d1d 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -995,7 +995,7 @@ def find( self, pattern: str | Expr, *, literal: bool = False, strict: bool = True ) -> Expr: """ - Return the index position of the first substring matching a pattern. + Return the byte offset of the first substring matching a pattern. If the pattern is not found, returns None. @@ -2731,6 +2731,85 @@ def extract_many( self._pyexpr.str_extract_many(patterns, ascii_case_insensitive, overlapping) ) + @unstable() + def find_many( + self, + patterns: IntoExpr, + *, + ascii_case_insensitive: bool = False, + overlapping: bool = False, + ) -> Expr: + """ + Use the Aho-Corasick algorithm to find many matches. + + The function will return the byte offset of the start of each match. + The return type will be `List<UInt32>` + + Parameters + ---------- + patterns + String patterns to search. + ascii_case_insensitive + Enable ASCII-aware case-insensitive matching. + When this option is enabled, searching will be performed without respect + to case for ASCII letters (a-z and A-Z) only. + overlapping + Whether matches may overlap. + + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + + Examples + -------- + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> df = pl.DataFrame({"values": ["discontent"]}) + >>> patterns = ["winter", "disco", "onte", "discontent"] + >>> df.with_columns( + ... pl.col("values") + ... .str.find_many(patterns, overlapping=False) + ... .alias("matches"), + ... pl.col("values") + ... .str.find_many(patterns, overlapping=True) + ... .alias("matches_overlapping"), + ... 
) + shape: (1, 3) + ┌────────────┬───────────┬─────────────────────┐ + │ values ┆ matches ┆ matches_overlapping │ + │ --- ┆ --- ┆ --- │ + │ str ┆ list[u32] ┆ list[u32] │ + ╞════════════╪═══════════╪═════════════════════╡ + │ discontent ┆ [0] ┆ [0, 4, 0] │ + └────────────┴───────────┴─────────────────────┘ + >>> df = pl.DataFrame( + ... { + ... "values": ["discontent", "rhapsody"], + ... "patterns": [ + ... ["winter", "disco", "onte", "discontent"], + ... ["rhap", "ody", "coalesce"], + ... ], + ... } + ... ) + >>> df.select(pl.col("values").str.find_many("patterns")) + shape: (2, 1) + ┌───────────┐ + │ values │ + │ --- │ + │ list[u32] │ + ╞═══════════╡ + │ [0] │ + │ [0, 5] │ + └───────────┘ + + """ + patterns = parse_into_expression( + patterns, str_as_lit=False, list_as_series=True + ) + return wrap_expr( + self._pyexpr.str_find_many(patterns, ascii_case_insensitive, overlapping) + ) + def join(self, delimiter: str = "", *, ignore_nulls: bool = True) -> Expr: """ Vertically concatenate the string values in the column to a single string value. diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 064cfc580b21..1072997ce8e6 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -444,9 +444,9 @@ def contains( def find( self, pattern: str | Expr, *, literal: bool = False, strict: bool = True - ) -> Expr: + ) -> Series: """ - Return the index of the first substring in Series strings matching a pattern. + Return the byte offset of the first substring matching a pattern. If the pattern is not found, returns None. @@ -2049,6 +2049,79 @@ def extract_many( """ + @unstable() + def find_many( + self, + patterns: IntoExpr, + *, + ascii_case_insensitive: bool = False, + overlapping: bool = False, + ) -> Series: + """ + Use the Aho-Corasick algorithm to find many matches. + + The function will return the byte offset of the start of each match. 
+ The return type will be `List<UInt32>` + + Parameters + ---------- + patterns + String patterns to search. + ascii_case_insensitive + Enable ASCII-aware case-insensitive matching. + When this option is enabled, searching will be performed without respect + to case for ASCII letters (a-z and A-Z) only. + overlapping + Whether matches may overlap. + + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + + Examples + -------- + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> df = pl.DataFrame({"values": ["discontent"]}) + >>> patterns = ["winter", "disco", "onte", "discontent"] + >>> df.with_columns( + ... pl.col("values") + ... .str.find_many(patterns, overlapping=False) + ... .alias("matches"), + ... pl.col("values") + ... .str.find_many(patterns, overlapping=True) + ... .alias("matches_overlapping"), + ... ) + shape: (1, 3) + ┌────────────┬───────────┬─────────────────────┐ + │ values ┆ matches ┆ matches_overlapping │ + │ --- ┆ --- ┆ --- │ + │ str ┆ list[u32] ┆ list[u32] │ + ╞════════════╪═══════════╪═════════════════════╡ + │ discontent ┆ [0] ┆ [0, 4, 0] │ + └────────────┴───────────┴─────────────────────┘ + >>> df = pl.DataFrame( + ... { + ... "values": ["discontent", "rhapsody"], + ... "patterns": [ + ... ["winter", "disco", "onte", "discontent"], + ... ["rhap", "ody", "coalesce"], + ... ], + ... } + ... ) + >>> df.select(pl.col("values").str.find_many("patterns")) + shape: (2, 1) + ┌───────────┐ + │ values │ + │ --- │ + │ list[u32] │ + ╞═══════════╡ + │ [0] │ + │ [0, 5] │ + └───────────┘ + + """ + def join(self, delimiter: str = "", *, ignore_nulls: bool = True) -> Series: """ Vertically concatenate the string values in the column to a single string value. 
diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py index fcca5c5987b1..d041035563d8 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_string.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -1784,10 +1784,18 @@ def test_extract_many() -> None: } ) + # extract_many assert df.select(pl.col("values").str.extract_many("patterns")).to_dict( as_series=False ) == {"values": [["disco"], ["rhap", "ody"]]} + # find_many + f1 = df.select(pl.col("values").str.find_many("patterns")) + f2 = df["values"].str.find_many(df["patterns"]) + + assert_series_equal(f1["values"], f2) + assert f2.to_list() == [[0], [0, 5]] + def test_json_decode_raise_on_data_type_mismatch_13061() -> None: assert_series_equal(