From c94de803332b5b3a09ae460adee5150b592e59ca Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sat, 4 May 2024 16:42:27 +0100 Subject: [PATCH] feat(python): split Expr.top_k and Expr.top_k_by into separate functions (#16041) --- .../reference/expressions/modify_select.rst | 2 + py-polars/polars/expr/expr.py | 214 +++++++++++------- py-polars/polars/series/series.py | 4 +- py-polars/src/expr/general.rs | 6 +- py-polars/tests/unit/operations/test_sort.py | 70 ++---- 5 files changed, 168 insertions(+), 128 deletions(-) diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index 94b461d4c673..c22ebcccce35 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -11,6 +11,7 @@ Manipulation/selection Expr.arg_true Expr.backward_fill Expr.bottom_k + Expr.bottom_k_by Expr.cast Expr.ceil Expr.clip @@ -61,5 +62,6 @@ Manipulation/selection Expr.take_every Expr.to_physical Expr.top_k + Expr.top_k_by Expr.upper_bound Expr.where diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 7053844c3f98..cd3cc21c8f77 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2035,8 +2035,6 @@ def top_k( self, k: int | IntoExprColumn = 5, *, - by: IntoExpr | Iterable[IntoExpr] | None = None, - descending: bool | Sequence[bool] = False, nulls_last: bool = False, maintain_order: bool = False, multithreaded: bool = True, @@ -2046,19 +2044,12 @@ def top_k( This has time complexity: - .. math:: O(n + k \log{}n - \frac{k}{2}) + .. math:: O(n + k \log{n} - \frac{k}{2}) Parameters ---------- k Number of elements to return. - by - Column(s) included in sort order. Accepts expression input. - Strings are parsed as column names. - If not provided, each column will be treated induvidually. - descending - Return the k smallest. Top-k by multiple columns can be specified per - column by passing a sequence of booleans. nulls_last Place null values last. maintain_order @@ -2068,7 +2059,9 @@ def top_k( See Also -------- + top_k_by bottom_k + bottom_k_by Examples -------- @@ -2097,15 +2090,61 @@ def top_k( │ 3 ┆ 4 │ │ 2 ┆ 98 │ └───────┴──────────┘ + """ + k = parse_as_expression(k) + return self._from_pyexpr(self._pyexpr.top_k(k, nulls_last, multithreaded)) + + def top_k_by( + self, + by: IntoExpr | Iterable[IntoExpr], + k: int | IntoExprColumn = 5, + *, + descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + maintain_order: bool = False, + multithreaded: bool = True, + ) -> Self: + r""" + Return elements corresponding to the `k` largest elements of the `by` column(s). + + This has time complexity: + + .. math:: O(n + k \log{n} - \frac{k}{2}) - >>> df2 = pl.DataFrame( + Parameters + ---------- + by + Column(s) included in sort order. Accepts expression input. + Strings are parsed as column names. + k + Number of elements to return. + descending + If `True`, consider the k smallest (instead of the k largest). Top-k by + multiple columns can be specified per column by passing a sequence of + booleans. + nulls_last + Place null values last. + maintain_order + Whether the order should be maintained if elements are equal. + multithreaded + Sort using multiple threads. + + See Also + -------- + top_k + bottom_k + bottom_k_by + + Examples + -------- + >>> df = pl.DataFrame( ... { ... "a": [1, 2, 3, 4, 5, 6], ... "b": [6, 5, 4, 3, 2, 1], ... "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"], ... } ... ) - >>> df2 + >>> df shape: (6, 3) ┌─────┬─────┬────────┐ │ a ┆ b ┆ c │ @@ -2122,9 +2161,9 @@ def top_k( Get the top 2 rows by column `a` or `b`. - >>> df2.select( - ... pl.all().top_k(2, by="a").name.suffix("_top_by_a"), - ... pl.all().top_k(2, by="b").name.suffix("_top_by_b"), + >>> df.select( + ... pl.all().top_k_by("a", 2).name.suffix("_top_by_a"), + ... pl.all().top_k_by("b", 2).name.suffix("_top_by_b"), ... ) shape: (2, 6) ┌────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐ @@ -2138,12 +2177,12 @@ def top_k( Get the top 2 rows by multiple columns with given order. - >>> df2.select( + >>> df.select( ... pl.all() - ... .top_k(2, by=["c", "a"], descending=[False, True]) + ... .top_k_by(["c", "a"], 2, descending=[False, True]) ... .name.suffix("_by_ca"), ... pl.all() - ... .top_k(2, by=["c", "b"], descending=[False, True]) + ... .top_k_by(["c", "b"], 2, descending=[False, True]) ... .name.suffix("_by_cb"), ... ) shape: (2, 6) @@ -2159,8 +2198,8 @@ def top_k( Get the top 2 rows by column `a` in each group. >>> ( - ... df2.group_by("c", maintain_order=True) - ... .agg(pl.all().top_k(2, by="a")) + ... df.group_by("c", maintain_order=True) + ... .agg(pl.all().top_k_by("a", 2)) ... .explode(pl.all().exclude("c")) ... ) shape: (5, 3) @@ -2177,32 +2216,22 @@ def top_k( └────────┴─────┴─────┘ """ k = parse_as_expression(k) - if by is not None: - by = parse_as_list_of_expressions(by) - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) - return self._from_pyexpr( - self._pyexpr.top_k_by( - k, by, descending, nulls_last, maintain_order, multithreaded - ) - ) - else: - if not isinstance(descending, bool): - msg = "`descending` should be a boolean if no `by` is provided" - raise ValueError(msg) - return self._from_pyexpr( - self._pyexpr.top_k(k, descending, nulls_last, multithreaded) + by = parse_as_list_of_expressions(by) + if isinstance(descending, bool): + descending = [descending] + elif len(by) != len(descending): + msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" + raise ValueError(msg) + return self._from_pyexpr( + self._pyexpr.top_k_by( + k, by, descending, nulls_last, maintain_order, multithreaded ) + ) def bottom_k( self, k: int | IntoExprColumn = 5, *, - by: IntoExpr | Iterable[IntoExpr] | None = None, - descending: bool | Sequence[bool] = False, nulls_last: bool = False, maintain_order: bool = False, multithreaded: bool = True, @@ -2212,19 +2241,12 @@ def bottom_k( This has time complexity: - .. math:: O(n + k \log{}n - \frac{k}{2}) + .. math:: O(n + k \log{n} - \frac{k}{2}) Parameters ---------- k Number of elements to return. - by - Column(s) included in sort order. - Accepts expression input. Strings are parsed as column names. - If not provided, each column will be treated induvidually. - descending - Return the k largest. Bottom-k by multiple columns can be specified per - column by passing a sequence of booleans. nulls_last Place null values last. maintain_order @@ -2235,6 +2257,8 @@ def bottom_k( See Also -------- top_k + top_k_by + bottom_k_by Examples -------- @@ -2261,15 +2285,61 @@ def bottom_k( │ 3 ┆ 4 │ │ 2 ┆ 98 │ └───────┴──────────┘ + """ + k = parse_as_expression(k) + return self._from_pyexpr(self._pyexpr.bottom_k(k, nulls_last, multithreaded)) + + def bottom_k_by( + self, + by: IntoExpr | Iterable[IntoExpr], + k: int | IntoExprColumn = 5, + *, + descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + maintain_order: bool = False, + multithreaded: bool = True, + ) -> Self: + r""" + Return elements corresponding to the `k` smallest elements of `by` column(s). + + This has time complexity: + + .. math:: O(n + k \log{n} - \frac{k}{2}) - >>> df2 = pl.DataFrame( + Parameters + ---------- + by + Column(s) included in sort order. + Accepts expression input. Strings are parsed as column names. + k + Number of elements to return. + descending + If `True`, consider the k largest (instead of the k smallest). Bottom-k by + multiple columns can be specified per column by passing a sequence of + booleans. + nulls_last + Place null values last. + maintain_order + Whether the order should be maintained if elements are equal. + multithreaded + Sort using multiple threads. + + See Also + -------- + top_k + top_k_by + bottom_k + + Examples + -------- + >>> df = pl.DataFrame( ... { ... "a": [1, 2, 3, 4, 5, 6], ... "b": [6, 5, 4, 3, 2, 1], ... "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"], ... } ... ) - >>> df2 + >>> df shape: (6, 3) ┌─────┬─────┬────────┐ │ a ┆ b ┆ c │ @@ -2286,9 +2356,9 @@ def bottom_k( Get the bottom 2 rows by column `a` or `b`. - >>> df2.select( - ... pl.all().bottom_k(2, by="a").name.suffix("_btm_by_a"), - ... pl.all().bottom_k(2, by="b").name.suffix("_btm_by_b"), + >>> df.select( + ... pl.all().bottom_k_by("a", 2).name.suffix("_btm_by_a"), + ... pl.all().bottom_k_by("b", 2).name.suffix("_btm_by_b"), ... ) shape: (2, 6) ┌────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐ @@ -2302,12 +2372,12 @@ def bottom_k( Get the bottom 2 rows by multiple columns with given order. - >>> df2.select( + >>> df.select( ... pl.all() - ... .bottom_k(2, by=["c", "a"], descending=[False, True]) + ... .bottom_k_by(["c", "a"], 2, descending=[False, True]) ... .name.suffix("_by_ca"), ... pl.all() - ... .bottom_k(2, by=["c", "b"], descending=[False, True]) + ... .bottom_k_by(["c", "b"], 2, descending=[False, True]) ... .name.suffix("_by_cb"), ... ) shape: (2, 6) @@ -2323,8 +2393,8 @@ def bottom_k( Get the bottom 2 rows by column `a` in each group. >>> ( - ... df2.group_by("c", maintain_order=True) - ... .agg(pl.all().bottom_k(2, by="a")) + ... df.group_by("c", maintain_order=True) + ... .agg(pl.all().bottom_k_by("a", 2)) ... .explode(pl.all().exclude("c")) ... ) shape: (5, 3) @@ -2341,25 +2411,17 @@ def bottom_k( └────────┴─────┴─────┘ """ k = parse_as_expression(k) - if by is not None: - by = parse_as_list_of_expressions(by) - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) - return self._from_pyexpr( - self._pyexpr.bottom_k_by( - k, by, descending, nulls_last, maintain_order, multithreaded - ) - ) - else: - if not isinstance(descending, bool): - msg = "`descending` should be a boolean if no `by` is provided" - raise ValueError(msg) - return self._from_pyexpr( - self._pyexpr.bottom_k(k, descending, nulls_last, multithreaded) + by = parse_as_list_of_expressions(by) + if isinstance(descending, bool): + descending = [descending] + elif len(by) != len(descending): + msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" + raise ValueError(msg) + return self._from_pyexpr( + self._pyexpr.bottom_k_by( + k, by, descending, nulls_last, maintain_order, multithreaded ) + ) def arg_sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 84338b5016a0..1d8892e08ee2 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3405,7 +3405,7 @@ def top_k(self, k: int | IntoExprColumn = 5) -> Series: This has time complexity: - .. math:: O(n + k \log{}n - \frac{k}{2}) + .. math:: O(n + k \log{n} - \frac{k}{2}) Parameters ---------- @@ -3435,7 +3435,7 @@ def bottom_k(self, k: int | IntoExprColumn = 5) -> Series: This has time complexity: - .. math:: O(n + k \log{}n - \frac{k}{2}) + .. math:: O(n + k \log{n} - \frac{k}{2}) Parameters ---------- diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index deef7c30b573..74ccc97af12b 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -290,13 +290,12 @@ impl PyExpr { } #[cfg(feature = "top_k")] - fn top_k(&self, k: Self, descending: bool, nulls_last: bool, multithreaded: bool) -> Self { + fn top_k(&self, k: Self, nulls_last: bool, multithreaded: bool) -> Self { self.inner .clone() .top_k( k.inner, SortOptions::default() - .with_order_descending(descending) .with_nulls_last(nulls_last) .with_maintain_order(multithreaded), ) @@ -330,13 +329,12 @@ impl PyExpr { } #[cfg(feature = "top_k")] - fn bottom_k(&self, k: Self, descending: bool, nulls_last: bool, multithreaded: bool) -> Self { + fn bottom_k(&self, k: Self, nulls_last: bool, multithreaded: bool) -> Self { self.inner .clone() .bottom_k( k.inner, SortOptions::default() - .with_order_descending(descending) .with_nulls_last(nulls_last) .with_maintain_order(multithreaded), ) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index f502e5cccc2b..c6ccdf0a70ed 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -345,16 +345,6 @@ def test_top_k() -> None: pl.DataFrame({"test": [4, 3, 2, 1]}), ) - assert_frame_equal( - df.select(pl.col("test").top_k(10, descending=True)), - pl.DataFrame({"test": [1, 2, 3, 4]}), - ) - - assert_frame_equal( - df.select(pl.col("test").bottom_k(10, descending=True)), - pl.DataFrame({"test": [4, 3, 2, 1]}), - ) - assert_frame_equal( df.select( top_k=pl.col("test").top_k(pl.col("val").min()), @@ -419,8 +409,8 @@ def test_top_k() -> None: assert_frame_equal( df2.select( - pl.col("a", "b").top_k(2, by="a").name.suffix("_top_by_a"), - pl.col("a", "b").top_k(2, by="b").name.suffix("_top_by_b"), + pl.col("a", "b").top_k_by("a", 2).name.suffix("_top_by_a"), + pl.col("a", "b").top_k_by("b", 2).name.suffix("_top_by_b"), ), pl.DataFrame( { @@ -434,8 +424,8 @@ def test_top_k() -> None: assert_frame_equal( df2.select( - pl.col("a", "b").top_k(2, by="a", descending=True).name.suffix("_top_by_a"), - pl.col("a", "b").top_k(2, by="b", descending=True).name.suffix("_top_by_b"), + pl.col("a", "b").top_k_by("a", 2, descending=True).name.suffix("_top_by_a"), + pl.col("a", "b").top_k_by("b", 2, descending=True).name.suffix("_top_by_b"), ), pl.DataFrame( { @@ -449,8 +439,8 @@ def test_top_k() -> None: assert_frame_equal( df2.select( - pl.col("a", "b").bottom_k(2, by="a").name.suffix("_bottom_by_a"), - pl.col("a", "b").bottom_k(2, by="b").name.suffix("_bottom_by_b"), + pl.col("a", "b").bottom_k_by("a", 2).name.suffix("_bottom_by_a"), + pl.col("a", "b").bottom_k_by("b", 2).name.suffix("_bottom_by_b"), ), pl.DataFrame( { @@ -465,10 +455,10 @@ def test_top_k() -> None: assert_frame_equal( df2.select( pl.col("a", "b") - .bottom_k(2, by="a", descending=True) + .bottom_k_by("a", 2, descending=True) .name.suffix("_bottom_by_a"), pl.col("a", "b") - .bottom_k(2, by="b", descending=True) + .bottom_k_by("b", 2, descending=True) .name.suffix("_bottom_by_b"), ), pl.DataFrame( @@ -483,7 +473,7 @@ def test_top_k() -> None: assert_frame_equal( df2.group_by("c", maintain_order=True) - .agg(pl.all().top_k(2, by="a")) + .agg(pl.all().top_k_by("a", 2)) .explode(pl.all().exclude("c")), pl.DataFrame( { @@ -496,7 +486,7 @@ def test_top_k() -> None: assert_frame_equal( df2.group_by("c", maintain_order=True) - .agg(pl.all().bottom_k(2, by="a")) + .agg(pl.all().bottom_k_by("a", 2)) .explode(pl.all().exclude("c")), pl.DataFrame( { @@ -509,8 +499,8 @@ def test_top_k() -> None: assert_frame_equal( df2.select( - pl.col("a", "b", "c").top_k(2, by=["c", "a"]).name.suffix("_top_by_ca"), - pl.col("a", "b", "c").top_k(2, by=["c", "b"]).name.suffix("_top_by_cb"), + pl.col("a", "b", "c").top_k_by(["c", "a"], 2).name.suffix("_top_by_ca"), + pl.col("a", "b", "c").top_k_by(["c", "b"], 2).name.suffix("_top_by_cb"), ), pl.DataFrame( { @@ -527,10 +517,10 @@ def test_top_k() -> None: assert_frame_equal( df2.select( pl.col("a", "b", "c") - .bottom_k(2, by=["c", "a"]) + .bottom_k_by(["c", "a"], 2) .name.suffix("_bottom_by_ca"), pl.col("a", "b", "c") - .bottom_k(2, by=["c", "b"]) + .bottom_k_by(["c", "b"], 2) .name.suffix("_bottom_by_cb"), ), pl.DataFrame( @@ -548,10 +538,10 @@ def test_top_k() -> None: assert_frame_equal( df2.select( pl.col("a", "b", "c") - .top_k(2, by=["c", "a"], descending=[True, False]) + .top_k_by(["c", "a"], 2, descending=[True, False]) .name.suffix("_top_by_ca"), pl.col("a", "b", "c") - .top_k(2, by=["c", "b"], descending=[True, False]) + .top_k_by(["c", "b"], 2, descending=[True, False]) .name.suffix("_top_by_cb"), ), pl.DataFrame( @@ -569,10 +559,10 @@ def test_top_k() -> None: assert_frame_equal( df2.select( pl.col("a", "b", "c") - .bottom_k(2, by=["c", "a"], descending=[True, False]) + .bottom_k_by(["c", "a"], 2, descending=[True, False]) .name.suffix("_bottom_by_ca"), pl.col("a", "b", "c") - .bottom_k(2, by=["c", "b"], descending=[True, False]) + .bottom_k_by(["c", "b"], 2, descending=[True, False]) .name.suffix("_bottom_by_cb"), ), pl.DataFrame( @@ -590,10 +580,10 @@ def test_top_k() -> None: assert_frame_equal( df2.select( pl.col("a", "b", "c") - .top_k(2, by=["c", "a"], descending=[False, True]) + .top_k_by(["c", "a"], 2, descending=[False, True]) .name.suffix("_top_by_ca"), pl.col("a", "b", "c") - .top_k(2, by=["c", "b"], descending=[False, True]) + .top_k_by(["c", "b"], 2, descending=[False, True]) .name.suffix("_top_by_cb"), ), pl.DataFrame( @@ -611,10 +601,10 @@ def test_top_k() -> None: assert_frame_equal( df2.select( pl.col("a", "b", "c") - .top_k(2, by=["c", "a"], descending=[False, True]) + .top_k_by(["c", "a"], 2, descending=[False, True]) .name.suffix("_bottom_by_ca"), pl.col("a", "b", "c") - .top_k(2, by=["c", "b"], descending=[False, True]) + .top_k_by(["c", "b"], 2, descending=[False, True]) .name.suffix("_bottom_by_cb"), ), pl.DataFrame( @@ -633,25 +623,13 @@ def test_top_k() -> None: ValueError, match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", ): - df2.select(pl.all().top_k(2, by="a", descending=[True, False])) + df2.select(pl.all().top_k_by("a", 2, descending=[True, False])) with pytest.raises( ValueError, match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", ): - df2.select(pl.all().bottom_k(2, by="a", descending=[True, False])) - - with pytest.raises( - ValueError, - match=r"`descending` should be a boolean if no `by` is provided", - ): - df2.select(pl.all().top_k(2, descending=[True, False])) - - with pytest.raises( - ValueError, - match=r"`descending` should be a boolean if no `by` is provided", - ): - df2.select(pl.all().bottom_k(2, descending=[True, False])) + df2.select(pl.all().bottom_k_by("a", 2, descending=[True, False])) def test_sorted_flag_unset_by_arithmetic_4937() -> None: