Skip to content

Commit

Permalink
fix: Fix incorrect result from inequality filter after join on LazyFrame (#19898)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Nov 21, 2024
1 parent 3925085 commit bbb4b2b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
30 changes: 15 additions & 15 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,25 @@ fn should_block_join_specific(
// any operation that checks for equality or ordering can be wrong because
// the join can produce null values
// TODO! check if we can be less conservative here
BinaryExpr { op, left, right } => match op {
Operator::NotEq => LeftRight(false, false),
Operator::Eq => {
let LeftRight(bleft, bright) = join_produces_null(how);
BinaryExpr {
op: Operator::Eq | Operator::NotEq,
left,
right,
} => {
let LeftRight(bleft, bright) = join_produces_null(how);

let l_name = aexpr_output_name(*left, expr_arena).unwrap();
let r_name = aexpr_output_name(*right, expr_arena).unwrap();
let l_name = aexpr_output_name(*left, expr_arena).unwrap();
let r_name = aexpr_output_name(*right, expr_arena).unwrap();

let is_in_on = on_names.contains(&l_name) || on_names.contains(&r_name);
let is_in_on = on_names.contains(&l_name) || on_names.contains(&r_name);

let block_left =
is_in_on && (schema_left.contains(&l_name) || schema_left.contains(&r_name));
let block_right =
is_in_on && (schema_right.contains(&l_name) || schema_right.contains(&r_name));
LeftRight(block_left | bleft, block_right | bright)
},
_ => join_produces_null(how),
let block_left =
is_in_on && (schema_left.contains(&l_name) || schema_left.contains(&r_name));
let block_right =
is_in_on && (schema_right.contains(&l_name) || schema_right.contains(&r_name));
LeftRight(block_left | bleft, block_right | bright)
},
_ => LeftRight(false, false),
_ => join_produces_null(how),
}
}

Expand Down
35 changes: 30 additions & 5 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,16 +555,41 @@ def test_predicate_pushdown_struct_unnest_19632() -> None:
)


def test_predicate_pushdown_right_join_19772() -> None:
left = pl.LazyFrame({"k": [1], "v": [7]})
right = pl.LazyFrame({"k": [1, 2]})
@pytest.mark.parametrize(
"predicate",
[
pl.col("v") == 7,
pl.col("v") != 99,
pl.col("v") > 0,
pl.col("v") < 999,
pl.col("v").is_in([7]),
pl.col("v").cast(pl.Boolean),
pl.col("b"),
],
)
@pytest.mark.parametrize("alias", [True, False])
@pytest.mark.parametrize("join_type", ["left", "right"])
def test_predicate_pushdown_join_19772(
predicate: pl.Expr, join_type: str, alias: bool
) -> None:
left = pl.LazyFrame({"k": [1, 2]})
right = pl.LazyFrame({"k": [1], "v": [7], "b": True})

if join_type == "right":
[left, right] = [right, left]

q = left.join(right, on="k", how="right").filter(pl.col("v") == 7)
if alias:
predicate = predicate.alias(":V")

q = left.join(right, on="k", how=join_type).filter(predicate) # type: ignore[arg-type]

plan = q.explain()
assert plan.startswith("FILTER")

expect = pl.DataFrame({"v": 7, "k": 1})
expect = pl.DataFrame({"k": 1, "v": 7, "b": True})

if join_type == "right":
expect = expect.select("v", "b", "k")

assert_frame_equal(q.collect(no_optimization=True), expect)
assert_frame_equal(q.collect(), expect)

0 comments on commit bbb4b2b

Please sign in to comment.