Skip to content

Commit

Permalink
rebase changes
Browse files Browse the repository at this point in the history
  • Loading branch information
siddharth-vi committed Dec 6, 2024
1 parent 6fde6d3 commit c9cc500
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,30 @@ def test_sort_by_exprs() -> None:
assert out.to_list() == [1, -1, 2, -2]


def test_arg_sort_nulls() -> None:
@pytest.mark.parametrize(
("sort_function", "expected"),
[
(lambda x: x, ([0, 1, 2, 3, 4], [3, 4, 0, 1, 2])),
(
lambda x: x.sort(descending=False, nulls_last=True),
([0, 1, 2, 3, 4], [3, 4, 0, 1, 2]),
),
(
lambda x: x.sort(descending=False, nulls_last=False),
([2, 3, 4, 0, 1], [0, 1, 2, 3, 4]),
),
],
)
def test_arg_sort_nulls(
sort_function: Callable[[pl.Series], pl.Series],
expected: tuple[list[int], list[int]],
) -> None:
a = pl.Series("a", [1.0, 2.0, 3.0, None, None])

assert a.arg_sort(nulls_last=True).to_list() == [0, 1, 2, 3, 4]
assert a.arg_sort(nulls_last=False).to_list() == [3, 4, 0, 1, 2]
a = sort_function(a)

assert a.arg_sort(nulls_last=True).to_list() == expected[0]
assert a.arg_sort(nulls_last=False).to_list() == expected[1]

res = a.to_frame().sort(by="a", nulls_last=False).to_series().to_list()
assert res == [None, None, 1.0, 2.0, 3.0]
Expand All @@ -208,6 +227,7 @@ def test_arg_sort_nulls() -> None:
assert res == [1.0, 2.0, 3.0, None, None]



@pytest.mark.parametrize(
("nulls_last", "expected"),
[
Expand Down

0 comments on commit c9cc500

Please sign in to comment.