Skip to content

Commit

Permalink
fix(rust): Allow duration * primitive to propagate in IR (#21394)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Feb 22, 2025
1 parent 3719977 commit 237b506
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
8 changes: 8 additions & 0 deletions crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,14 @@ fn get_arithmetic_field(
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
},
},
(Duration(_), r) if r.is_primitive_numeric() => match op {
Operator::Multiply => {
return Ok(left_field);
},
_ => {
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
},
},
#[cfg(feature = "dtype-decimal")]
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
let scale = match op {
Expand Down
26 changes: 24 additions & 2 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,8 @@ def test_arithmetic_duration_div_multiply() -> None:
("a", pl.Duration(time_unit="us")),
("b", pl.Duration(time_unit="us")),
("c", pl.Duration(time_unit="us")),
("d", pl.Unknown()),
("e", pl.Unknown()),
("d", pl.Duration(time_unit="us")),
("e", pl.Duration(time_unit="us")),
("f", pl.Float64()),
]
)
Expand Down Expand Up @@ -824,3 +824,25 @@ def test_raise_invalid_shape() -> None:
def test_integer_divide_scalar_zero_lhs_19142() -> None:
assert_series_equal(pl.Series([0]) // pl.Series([1, 0]), pl.Series([0, None]))
assert_series_equal(pl.Series([0]) % pl.Series([1, 0]), pl.Series([0, None]))


def test_compound_duration_21389() -> None:
# test add
lf = pl.LazyFrame(
{
"ts": datetime(2024, 1, 1, 1, 2, 3),
"duration": timedelta(days=1),
}
)
result = lf.select(pl.col("ts") + pl.col("duration") * 2)
expected_schema = pl.Schema({"ts": pl.Datetime(time_unit="us", time_zone=None)})
expected = pl.DataFrame({"ts": datetime(2024, 1, 3, 1, 2, 3)})
assert result.collect_schema() == expected_schema
assert_frame_equal(result.collect(), expected)

# test subtract
result = lf.select(pl.col("ts") - pl.col("duration") * 2)
expected_schema = pl.Schema({"ts": pl.Datetime(time_unit="us", time_zone=None)})
expected = pl.DataFrame({"ts": datetime(2023, 12, 30, 1, 2, 3)})
assert result.collect_schema() == expected_schema
assert_frame_equal(result.collect(), expected)

0 comments on commit 237b506

Please sign in to comment.