Skip to content

Commit

Permalink
feat(python): Add top-level nth(n) method, to go with existing `fir…
Browse files Browse the repository at this point in the history
…st` and `last` (#16112)
  • Loading branch information
alexander-beedie authored May 8, 2024
1 parent a3ebdfc commit ddc30ab
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 18 deletions.
9 changes: 7 additions & 2 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1885,12 +1885,17 @@ pub fn len() -> Expr {
Expr::Len
}

/// First column in DataFrame.
/// First column in a DataFrame.
pub fn first() -> Expr {
Expr::Nth(0)
}

/// Last column in DataFrame.
/// Last column in a DataFrame.
pub fn last() -> Expr {
Expr::Nth(-1)
}

/// Nth column in a DataFrame.
pub fn nth(n: i64) -> Expr {
Expr::Nth(n)
}
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ These functions are available from the Polars module root and can be used as exp
min
min_horizontal
n_unique
nth
ones
quantile
reduce
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
min,
min_horizontal,
n_unique,
nth,
ones,
quantile,
reduce,
Expand Down Expand Up @@ -401,6 +402,7 @@
"mean",
"median",
"n_unique",
"nth",
"quantile",
"reduce",
"rolling_corr",
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
mean,
median,
n_unique,
nth,
quantile,
reduce,
rolling_corr,
Expand Down Expand Up @@ -162,6 +163,7 @@
"mean_horizontal",
"median",
"n_unique",
"nth",
"quantile",
"reduce",
"rolling_corr",
Expand Down
96 changes: 81 additions & 15 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr


if TYPE_CHECKING:
from typing import Awaitable, Collection, Literal

Expand Down Expand Up @@ -518,18 +517,16 @@ def approx_n_unique(*columns: str) -> Expr:
@deprecate_parameter_as_positional("column", version="0.20.4")
def first(*columns: str) -> Expr:
"""
Get the first value.
This function has different behavior depending on the input type:
Get the first column or value.
- `None` -> Takes first column of a context (equivalent to `cs.first()`).
- `str` or `[str,]` -> Syntactic sugar for `pl.col(columns).first()`.
This function has different behavior depending on the presence of `columns`
values. If none given (the default), returns an expression that takes the first
column of the context; otherwise, takes the first value of the given column(s).
Parameters
----------
*columns
One or more column names. If not provided (default), returns an expression
to take the first column of the context instead.
One or more column names.
Examples
--------
Expand All @@ -540,6 +537,9 @@ def first(*columns: str) -> Expr:
... "c": ["foo", "bar", "baz"],
... }
... )
Return the first column:
>>> df.select(pl.first())
shape: (3, 1)
┌─────┐
Expand All @@ -551,6 +551,9 @@ def first(*columns: str) -> Expr:
│ 8 │
│ 3 │
└─────┘
Return the first value for the given column(s):
>>> df.select(pl.first("b"))
shape: (1, 1)
┌─────┐
Expand Down Expand Up @@ -580,18 +583,16 @@ def first(*columns: str) -> Expr:
@deprecate_parameter_as_positional("column", version="0.20.4")
def last(*columns: str) -> Expr:
"""
Get the last value.
Get the last column or value.
This function has different behavior depending on the input type:
- `None` -> Takes last column of a context (equivalent to `cs.last()`).
- `str` or `[str,]` -> Syntactic sugar for `pl.col(columns).last()`.
This function has different behavior depending on the presence of `columns`
values. If none given (the default), returns an expression that takes the last
column of the context; otherwise, takes the last value of the given column(s).
Parameters
----------
*columns
One or more column names. If set to `None` (default), returns an expression
to take the last column of the context instead.
One or more column names.
Examples
--------
Expand All @@ -602,6 +603,9 @@ def last(*columns: str) -> Expr:
... "c": ["foo", "bar", "baz"],
... }
... )
Return the last column:
>>> df.select(pl.last())
shape: (3, 1)
┌─────┐
Expand All @@ -613,6 +617,9 @@ def last(*columns: str) -> Expr:
│ bar │
│ baz │
└─────┘
Return the last value for the given column(s):
>>> df.select(pl.last("a"))
shape: (1, 1)
┌─────┐
Expand All @@ -639,6 +646,65 @@ def last(*columns: str) -> Expr:
return F.col(*columns).last()


def nth(n: int, *columns: str) -> Expr:
"""
Get the nth column or value.
This function has different behavior depending on the presence of `columns`
values. If none given (the default), returns an expression that takes the nth
column of the context; otherwise, takes the nth value of the given column(s).
Parameters
----------
n
Index of the column (or value) to get.
*columns
One or more column names. If omitted (the default), returns an
expression that takes the nth column of the context. Otherwise,
returns takes the nth value of the given column(s).
Examples
--------
>>> df = pl.DataFrame(
... {
... "a": [1, 8, 3],
... "b": [4, 5, 2],
... "c": ["foo", "bar", "baz"],
... }
... )
Return the "nth" column:
>>> df.select(pl.nth(1))
shape: (3, 1)
┌─────┐
│ b │
│ --- │
│ i64 │
╞═════╡
│ 4 │
│ 5 │
│ 2 │
└─────┘
Return the "nth" value for the given columns:
>>> df.select(pl.nth(-2, "b", "c"))
shape: (1, 2)
┌─────┬─────┐
│ b ┆ c │
│ --- ┆ --- │
│ i64 ┆ str │
╞═════╪═════╡
│ 5 ┆ bar │
└─────┴─────┘
"""
if not columns:
return wrap_expr(plr.nth(n))

return F.col(*columns).get(n)


def head(column: str, n: int = 10) -> Expr:
"""
Get the first `n` rows.
Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,11 @@ pub fn last() -> PyExpr {
dsl::last().into()
}

#[pyfunction]
pub fn nth(n: i64) -> PyExpr {
dsl::nth(n).into()
}

#[pyfunction]
pub fn lit(value: &PyAny, allow_object: bool) -> PyResult<PyExpr> {
if value.is_instance_of::<PyBool>() {
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ fn polars(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(functions::last)).unwrap();
m.add_wrapped(wrap_pyfunction!(functions::lit)).unwrap();
m.add_wrapped(wrap_pyfunction!(functions::map_mul)).unwrap();
m.add_wrapped(wrap_pyfunction!(functions::nth)).unwrap();
m.add_wrapped(wrap_pyfunction!(functions::pearson_corr))
.unwrap();
m.add_wrapped(wrap_pyfunction!(functions::rolling_corr))
Expand Down
11 changes: 10 additions & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,14 +2178,23 @@ def test_product() -> None:
assert_frame_equal(out, expected, check_dtype=False)


def test_first_last_expression(fruits_cars: pl.DataFrame) -> None:
def test_first_last_nth_expressions(fruits_cars: pl.DataFrame) -> None:
df = fruits_cars
out = df.select(pl.first())
assert out.columns == ["A"]

out = df.select(pl.last())
assert out.columns == ["cars"]

out = df.select(pl.nth(0))
assert out.columns == ["A"]

out = df.select(pl.nth(1))
assert out.columns == ["fruits"]

out = df.select(pl.nth(-2))
assert out.columns == ["B"]


def test_is_between(fruits_cars: pl.DataFrame) -> None:
result = fruits_cars.select(pl.col("A").is_between(2, 4)).to_series()
Expand Down
3 changes: 3 additions & 0 deletions py-polars/tests/unit/functions/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def test_lazy_functions() -> None:
pl.first("a").name.suffix("_first"),
pl.first("b", "c").name.suffix("_first"),
pl.last("c", "b", "a").name.suffix("_last"),
pl.nth(1, "c", "a").name.suffix("_nth1"),
)
expected: dict[str, list[Any]] = {
"b_var": [1.0],
Expand All @@ -469,6 +470,8 @@ def test_lazy_functions() -> None:
"c_last": [4.0],
"b_last": [3],
"a_last": ["foo"],
"c_nth1": [2.0],
"a_nth1": ["bar"],
}
assert_frame_equal(
out,
Expand Down

0 comments on commit ddc30ab

Please sign in to comment.