Avoid extra expression duplication when pushing filters
fix
update
optimize split condition
optimize Subquery
fix related InferFilters
rewrite With immediately
zml1206 committed Jan 2, 2025
1 parent 492fcd8 commit 8fd06da
Showing 13 changed files with 678 additions and 593 deletions.
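In short: when PushPredicateThroughNonJoin pushes a filter below a Project or an Aggregate, it no longer expands a non-cheap alias at every reference via replaceAlias. The pushed condition is wrapped in a With expression instead, and a "Rewrite With expression" batch (RewriteWithExpression followed by CollapseProject) is scheduled after the pushdown batches to materialize the shared sub-expression as a _common_expr_N project column. A rough sketch of the effect, written with the Catalyst test DSL used by the suites below (the plan shapes follow the expected answers in FilterPushdownSuite; the snippet is illustrative, not part of the patch):

// First test case added in FilterPushdownSuite; testRelation has columns a, b, c.
val query = testRelation
  .select($"a" + $"b" as "add", $"a" - $"b" as "sub")
  .where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)

// Previously, pushing the filter below the Project expanded `add` at every reference:
//   Project [a + b AS add, a - b AS sub]
//   +- Filter ((a + b) < 10 && (a + b) + (a + b) > 10 && (a - b) > 0)
//      +- testRelation
//
// Now the repeated, non-cheap expression is computed once; `sub`, consumed only once,
// is still simply inlined:
//   Project [a + b AS add, a - b AS sub]
//   +- Filter (_common_expr_0 < 10 && _common_expr_0 + _common_expr_0 > 10 && (a - b) > 0)
//      +- Project [a, b, c, a + b AS _common_expr_0]
//         +- testRelation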
@@ -161,14 +161,20 @@ abstract class Optimizer(catalogManager: CatalogManager)
val operatorOptimizationBatch: Seq[Batch] = Seq(
Batch("Operator Optimization before Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject),
Batch("Infer Filters", Once,
InferFiltersFromGenerate,
InferFiltersFromConstraints),
Batch("Operator Optimization after Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*),
Batch("Push extra predicate through join", fixedPoint,
PushExtraPredicateThroughJoin,
PushDownPredicates))
PushDownPredicates),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject))

val batches: Seq[Batch] = flattenBatches(Seq(
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis),
@@ -1811,7 +1817,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelper
case Filter(condition, project @ Project(fields, grandChild))
if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
val aliasMap = getAliasMap(project)
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
project.copy(child = Filter(rewriteCondition(condition, aliasMap), grandChild))

// We can push down deterministic predicates through Aggregate, including throwable predicates.
// If we can push down a filter through Aggregate, it means the filter only references the
@@ -1831,8 +1837,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelper
}

if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val replaced = replaceAlias(pushDownPredicate, aliasMap)
val replaced = rewriteCondition(pushDown.reduce(And), aliasMap)
val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
// If there is no more filter to stay up, just eliminate the filter.
// Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
Expand Down Expand Up @@ -1978,6 +1983,64 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
case _ => false
}
}

private def rewriteCondition(
cond: Expression,
aliasMap: AttributeMap[Alias]): Expression = {
replaceAlias(rewriteConditionByWith(cond, aliasMap), aliasMap)
}

/**
* Use [[With]] to rewrite a condition that contains attributes which are not cheap and are
* consumed multiple times. Each predicate generates at most one [[With]]. To facilitate
* subsequently merging the [[With]] expressions, the same CommonExpressionDef ids are used
* across the different [[With]] expressions.
*/
private def rewriteConditionByWith(
cond: Expression,
aliasMap: AttributeMap[Alias]): Expression = {
if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) {
// A SubqueryExpression can't contain a common expression ref, so replace aliases inside it first.
val newCond = replaceAliasForSubqueryExpression(cond, aliasMap)
val replaceWithMap = newCond.collect { case a: Attribute => a }
.groupBy(identity)
.transform((_, v) => v.size)
.filter(m => aliasMap.contains(m._1) && m._2 > 1)
.map(m => m._1 -> trimAliases(aliasMap.getOrElse(m._1, m._1)))
.filter(m => !CollapseProject.isCheap(m._2))
if (replaceWithMap.isEmpty) {
newCond
} else {
val defsMap = AttributeMap(replaceWithMap.map(m => m._1 -> CommonExpressionDef(m._2)))
val refsMap = AttributeMap(defsMap.map(m => m._1 -> new CommonExpressionRef(m._2)))
splitConjunctivePredicates(newCond)
.map(rewriteByWith(_, defsMap, refsMap))
.reduce(And)
}
} else cond
}

private def replaceAliasForSubqueryExpression(
expr: Expression,
aliasMap: AttributeMap[Alias]): Expression = {
expr.transform {
case s: SubqueryExpression => replaceAlias(s, aliasMap)
}
}

private def rewriteByWith(
expr: Expression,
defsMap: AttributeMap[CommonExpressionDef],
refsMap: AttributeMap[CommonExpressionRef]): Expression = {
val defs = mutable.HashSet.empty[CommonExpressionDef]
val replaced = expr.transform {
case a: Attribute if refsMap.contains(a) =>
defs.add(defsMap(a))
refsMap(a)
}
if (defs.nonEmpty) {
With(replaced, defs.toSeq)
} else expr
}
}

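To make the new helpers concrete: the rewrite only kicks in when SQLConf.ALWAYS_INLINE_COMMON_EXPR is disabled, and only for aliases that are consumed more than once and are not cheap per CollapseProject.isCheap; subqueries get plain alias replacement up front because a SubqueryExpression cannot hold a common expression ref. Below is a minimal sketch of the condition produced by rewriteConditionByWith, before replaceAlias and the later RewriteWithExpression batch run. It is illustrative only; the DSL imports and local names are assumptions, not code from the patch.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{CommonExpressionDef, CommonExpressionRef, With}

// Alias defined by the child Project: add = a + b (not cheap), consumed three times
// in the filter condition `add < 10 AND add + add > 10`.
val add = $"a".int + $"b".int
val addDef = CommonExpressionDef(add)        // a single def; its id is shared below
val addRef = new CommonExpressionRef(addDef)

// The condition is split into conjuncts, and each conjunct that consumes `add`
// is wrapped in its own With over the same def:
val rewritten =
  With(addRef < 10, Seq(addDef)) &&
  With(addRef + addRef > 10, Seq(addDef))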
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

@@ -45,7 +46,10 @@ class FilterPushdownSuite extends PlanTest {
CollapseProject) ::
Batch("Push extra predicate through join", FixedPoint(10),
PushExtraPredicateThroughJoin,
PushDownPredicates) :: Nil
PushDownPredicates) ::
Batch("Rewrite With expression", FixedPoint(10),
RewriteWithExpression,
CollapseProject) :: Nil
}

val attrA = $"a".int
@@ -1539,4 +1543,37 @@ class FilterPushdownSuite extends PlanTest {
.analyze
comparePlans(optimizedQueryWithoutStep, correctAnswer)
}

test("SPARK-50589: avoid extra expression duplication when push filter") {
withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
// through project
val originalQuery1 = testRelation
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
val optimized1 = Optimize.execute(originalQuery1.analyze)
val correctAnswer1 = testRelation
.select($"a", $"b", $"c", $"a" + $"b" as "_common_expr_0")
.where($"_common_expr_0" < 10 &&
$"_common_expr_0" + $"_common_expr_0" > 10 &&
$"a" - $"b" > 0)
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.analyze
comparePlans(optimized1, correctAnswer1)

// through aggregate
val originalQuery2 = testRelation
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
val optimized2 = Optimize.execute(originalQuery2.analyze)
val correctAnswer2 = testRelation
.select($"a", $"b", $"c", $"a" + $"a" as "_common_expr_0")
.where($"_common_expr_0" < 10 &&
$"_common_expr_0" + $"_common_expr_0" > 10 &&
abs($"a") > 5)
.select($"a", $"b", $"c")
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.analyze
comparePlans(optimized2, correctAnswer2)
}
}
}
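Two details of the suite changes above: USE_COMMON_EXPR_ID_FOR_ALIAS is pinned to false so the materialized column gets the stable index-based name _common_expr_0 rather than a name derived from the common-expression id, and a "Rewrite With expression" batch is appended because the pushdown now leaves With expressions inside the Filter. The intermediate plan between those two batches looks roughly like this (a sketch for illustration; the suite only compares the final plans):

// After PushPredicateThroughNonJoin, before RewriteWithExpression (first test case):
//   Project [a + b AS add, a - b AS sub]
//   +- Filter (With(ref < 10, [def: a + b]) && With(ref + ref > 10, [def: a + b]) && a - b > 0)
//      +- LocalRelation [a, b, c]
//
// RewriteWithExpression materializes the shared def as _common_expr_0 in a Project
// below the Filter, and CollapseProject merges the extra Projects it introduces,
// yielding correctAnswer1 above.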
@@ -30,14 +30,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("InferAndPushDownFilters", FixedPoint(100),
Batch("PushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushPredicateThroughNonJoin,
PushPredicateThroughNonJoin) ::
Batch("Rewrite With expression", FixedPoint(10),
RewriteWithExpression,
CollapseProject) ::
Batch("InferFilters", Once,
InferFiltersFromConstraints,
CombineFilters,
SimplifyBinaryComparison,
BooleanSimplification,
PruneFilters) :: Nil
PruneFilters) ::
Batch("PushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushPredicateThroughNonJoin) ::
Batch("Rewrite With expression", FixedPoint(10),
RewriteWithExpression,
CollapseProject) :: Nil
}

val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
@@ -144,21 +154,24 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
}

test("inner join with alias: alias contains multiple attributes") {
val t1 = testRelation.subquery("t1")
val t2 = testRelation.subquery("t2")

val originalQuery = t1.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull($"a") && IsNotNull(Coalesce(Seq($"a", $"b"))) &&
$"a" === Coalesce(Seq($"a", $"b")))
.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2.where(IsNotNull($"a")), Inner,
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
val t1 = testRelation.subquery("t1")
val t2 = testRelation.subquery("t2")

val originalQuery = t1.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.select($"a", $"b", $"c", Coalesce(Seq($"a", $"b")) as "_common_expr_0")
.where(IsNotNull($"a") && IsNotNull($"_common_expr_0") &&
$"a" === $"_common_expr_0")
.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2.where(IsNotNull($"a")), Inner,
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
}

test("inner join with alias: alias contains single attributes") {
@@ -66,6 +66,9 @@ class SparkOptimizer(
RewriteDistinctAggregates),
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
PushDownPredicates),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject),
Batch("Cleanup filters that cannot be pushed down", Once,
CleanupDynamicPruningFilters,
// cleanup the unnecessary TrueLiteral predicates
@@ -89,6 +92,9 @@
PushPredicateThroughNonJoin,
PushProjectionThroughLimit,
RemoveNoopOperators),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject),
Batch("Infer window group limit", Once,
InferWindowGroupLimit,
LimitPushDown,
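SparkOptimizer mirrors the change in the abstract Optimizer: each batch that can run the pushdown rules is now followed by a "Rewrite With expression" batch, because PushPredicateThroughNonJoin may leave With expressions behind. Any custom RuleExecutor exercising these rules needs the same pairing; a minimal sketch, mirroring the updated test suites above (the object name and batch names here are assumed, not taken from the patch):

import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

object Optimize extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("Push down filters", FixedPoint(10),
      PushPredicateThroughJoin,
      PushPredicateThroughNonJoin) ::
    // Must run after the pushdown so that the With expressions it introduces are
    // rewritten into explicit _common_expr_N project columns.
    Batch("Rewrite With expression", FixedPoint(10),
      RewriteWithExpression,
      CollapseProject) :: Nil
}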
@@ -1,28 +1,30 @@
== Physical Plan ==
TakeOrderedAndProject (24)
+- * Filter (23)
+- * HashAggregate (22)
+- Exchange (21)
+- * HashAggregate (20)
+- * Project (19)
+- * BroadcastHashJoin Inner BuildRight (18)
:- * Project (13)
: +- * BroadcastHashJoin Inner BuildRight (12)
: :- * Project (10)
: : +- * BroadcastHashJoin Inner BuildRight (9)
: : :- * Filter (3)
: : : +- * ColumnarToRow (2)
: : : +- Scan parquet spark_catalog.default.inventory (1)
: : +- BroadcastExchange (8)
: : +- * Project (7)
: : +- * Filter (6)
: : +- * ColumnarToRow (5)
: : +- Scan parquet spark_catalog.default.item (4)
: +- ReusedExchange (11)
+- BroadcastExchange (17)
+- * Filter (16)
+- * ColumnarToRow (15)
+- Scan parquet spark_catalog.default.warehouse (14)
TakeOrderedAndProject (26)
+- * Project (25)
+- * Filter (24)
+- * Project (23)
+- * HashAggregate (22)
+- Exchange (21)
+- * HashAggregate (20)
+- * Project (19)
+- * BroadcastHashJoin Inner BuildRight (18)
:- * Project (13)
: +- * BroadcastHashJoin Inner BuildRight (12)
: :- * Project (10)
: : +- * BroadcastHashJoin Inner BuildRight (9)
: : :- * Filter (3)
: : : +- * ColumnarToRow (2)
: : : +- Scan parquet spark_catalog.default.inventory (1)
: : +- BroadcastExchange (8)
: : +- * Project (7)
: : +- * Filter (6)
: : +- * ColumnarToRow (5)
: : +- Scan parquet spark_catalog.default.item (4)
: +- ReusedExchange (11)
+- BroadcastExchange (17)
+- * Filter (16)
+- * ColumnarToRow (15)
+- Scan parquet spark_catalog.default.warehouse (14)


(1) Scan parquet spark_catalog.default.inventory
Expand Down Expand Up @@ -72,7 +74,7 @@ Join condition: None
Output [4]: [inv_warehouse_sk#2, inv_quantity_on_hand#3, inv_date_sk#4, i_item_id#7]
Input [6]: [inv_item_sk#1, inv_warehouse_sk#2, inv_quantity_on_hand#3, inv_date_sk#4, i_item_sk#6, i_item_id#7]

(11) ReusedExchange [Reuses operator id: 28]
(11) ReusedExchange [Reuses operator id: 30]
Output [2]: [d_date_sk#9, d_date#10]

(12) BroadcastHashJoin [codegen id : 4]
@@ -131,38 +133,46 @@ Functions [2]: [sum(CASE WHEN (d_date#10 < 2000-03-11) THEN inv_quantity_on_hand
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#10 < 2000-03-11) THEN inv_quantity_on_hand#3 ELSE 0 END)#17, sum(CASE WHEN (d_date#10 >= 2000-03-11) THEN inv_quantity_on_hand#3 ELSE 0 END)#18]
Results [4]: [w_warehouse_name#12, i_item_id#7, sum(CASE WHEN (d_date#10 < 2000-03-11) THEN inv_quantity_on_hand#3 ELSE 0 END)#17 AS inv_before#19, sum(CASE WHEN (d_date#10 >= 2000-03-11) THEN inv_quantity_on_hand#3 ELSE 0 END)#18 AS inv_after#20]

(23) Filter [codegen id : 5]
(23) Project [codegen id : 5]
Output [5]: [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20, CASE WHEN (inv_before#19 > 0) THEN (cast(inv_after#20 as double) / cast(inv_before#19 as double)) END AS _common_expr_0#21]
Input [4]: [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20]
Condition : (CASE WHEN (inv_before#19 > 0) THEN ((cast(inv_after#20 as double) / cast(inv_before#19 as double)) >= 0.666667) END AND CASE WHEN (inv_before#19 > 0) THEN ((cast(inv_after#20 as double) / cast(inv_before#19 as double)) <= 1.5) END)

(24) TakeOrderedAndProject
(24) Filter [codegen id : 5]
Input [5]: [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20, _common_expr_0#21]
Condition : ((isnotnull(_common_expr_0#21) AND (_common_expr_0#21 >= 0.666667)) AND (_common_expr_0#21 <= 1.5))

(25) Project [codegen id : 5]
Output [4]: [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20]
Input [5]: [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20, _common_expr_0#21]

(26) TakeOrderedAndProject
Input [4]: [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20]
Arguments: 100, [w_warehouse_name#12 ASC NULLS FIRST, i_item_id#7 ASC NULLS FIRST], [w_warehouse_name#12, i_item_id#7, inv_before#19, inv_after#20]

===== Subqueries =====

Subquery:1 Hosting operator id = 1 Hosting Expression = inv_date_sk#4 IN dynamicpruning#5
BroadcastExchange (28)
+- * Filter (27)
+- * ColumnarToRow (26)
+- Scan parquet spark_catalog.default.date_dim (25)
BroadcastExchange (30)
+- * Filter (29)
+- * ColumnarToRow (28)
+- Scan parquet spark_catalog.default.date_dim (27)


(25) Scan parquet spark_catalog.default.date_dim
(27) Scan parquet spark_catalog.default.date_dim
Output [2]: [d_date_sk#9, d_date#10]
Batched: true
Location [not included in comparison]/{warehouse_dir}/date_dim]
PushedFilters: [IsNotNull(d_date), GreaterThanOrEqual(d_date,2000-02-10), LessThanOrEqual(d_date,2000-04-10), IsNotNull(d_date_sk)]
ReadSchema: struct<d_date_sk:int,d_date:date>

(26) ColumnarToRow [codegen id : 1]
(28) ColumnarToRow [codegen id : 1]
Input [2]: [d_date_sk#9, d_date#10]

(27) Filter [codegen id : 1]
(29) Filter [codegen id : 1]
Input [2]: [d_date_sk#9, d_date#10]
Condition : (((isnotnull(d_date#10) AND (d_date#10 >= 2000-02-10)) AND (d_date#10 <= 2000-04-10)) AND isnotnull(d_date_sk#9))

(28) BroadcastExchange
(30) BroadcastExchange
Input [2]: [d_date_sk#9, d_date#10]
Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=4]
