Skip to content

Commit

Permalink
fix: Add scalar checks for n and fill_value parameters in shift (
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Feb 22, 2025
1 parent ad7fdf1 commit 789f38b
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 15 deletions.
10 changes: 2 additions & 8 deletions crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,14 @@ fn shift_and_fill_with_mask(s: &Column, n: i64, fill_value: &Column) -> PolarsRe

pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult<Column> {
let s = &args[0];
let n_s = &args[1];

polars_ensure!(
n_s.len() == 1,
ComputeError: "n must be a single value."
);
let n_s = n_s.cast(&DataType::Int64)?;
let n_s = &args[1].cast(&DataType::Int64)?;
let n = n_s.i64()?;

if let Some(n) = n.get(0) {
let logical = s.dtype();
let physical = s.to_physical_repr();
let fill_value_s = &args[2];
let fill_value = fill_value_s.get(0)?;
let fill_value = fill_value_s.get(0).unwrap();

use DataType::*;
match logical {
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/plans/conversion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ pub(super) fn convert_functions(

let e = to_expr_irs(input, arena)?;

// Validate inputs.
if function == FunctionExpr::ShiftAndFill {
polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value");
polars_ensure!(&e[2].is_scalar(arena), ComputeError: "'fill_value' must be scalar value");
}

if state.output_name.is_none() {
// Handles special case functions like `struct.field`.
if let Some(name) = function.output_name() {
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9302,8 +9302,8 @@ def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> DataFrame:
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value. Accepts expression input.
Non-expression inputs are parsed as literals.
Fill the resulting null values with this value. Accepts scalar expression
input. Non-expression inputs are parsed as literals.
Notes
-----
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2638,7 +2638,7 @@ def shift(
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value.
Fill the resulting null values with this scalar value.
Notes
-----
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5621,8 +5621,8 @@ def shift(
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value. Accepts expression input.
Non-expression inputs are parsed as literals.
Fill the resulting null values with this value. Accepts scalar expression
input. Non-expression inputs are parsed as literals.
Notes
-----
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5436,8 +5436,8 @@ def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> Series:
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value. Accepts expression input.
Non-expression inputs are parsed as literals.
Fill the resulting null values with this value. Accepts scalar expression
input. Non-expression inputs are parsed as literals.
Notes
-----
Expand Down
59 changes: 59 additions & 0 deletions py-polars/tests/unit/operations/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -119,3 +120,61 @@ def test_shift_fill_value_group_logicals() -> None:
result = df.select(pl.col("d").shift(fill_value=pl.col("d").max(), n=-1).over("s"))

assert result.dtypes == [pl.Date]


def test_shift_n_null() -> None:
df = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
out = df.shift(None) # type: ignore[arg-type]
expected = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Int32)})
assert_frame_equal(out, expected)

out = df.shift(None, fill_value=1) # type: ignore[arg-type]
assert_frame_equal(out, expected)

out = df.select(pl.col("a").shift(None)) # type: ignore[arg-type]
assert_frame_equal(out, expected)

out = df.select(pl.col("a").shift(None, fill_value=1)) # type: ignore[arg-type]
assert_frame_equal(out, expected)


def test_shift_n_nonscalar() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3],
"b": [4, 5, 6],
}
)
with pytest.raises(
ComputeError,
match="'n' must be scalar value",
):
# Note: Expressions are not in the signature for `n`, but they work.
# We can still verify that n is scalar up-front.
df.shift(pl.col("b"), fill_value=1) # type: ignore[arg-type]

with pytest.raises(
ComputeError,
match="'n' must be scalar value",
):
df.select(pl.col("a").shift(pl.col("b"), fill_value=1))


def test_shift_fill_value_nonscalar() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3],
"b": [4, 5, 6],
}
)
with pytest.raises(
ComputeError,
match="'fill_value' must be scalar value",
):
df.shift(1, fill_value=pl.col("b"))

with pytest.raises(
ComputeError,
match="'fill_value' must be scalar value",
):
df.select(pl.col("a").shift(1, fill_value=pl.col("b")))

0 comments on commit 789f38b

Please sign in to comment.