Skip to content

Commit 76d9b6d

Browse files
committed
optimize partial pushdown: extend attribute when child already contains it
1 parent 32f9964 commit 76d9b6d

File tree

5 files changed

+107
-67
lines changed

5 files changed

+107
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

+9-8
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,6 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
11321132
r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList)))
11331133
case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
11341134
s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
1135-
case p @ Project(_, child) if child.output == p.output => child
11361135
}
11371136
}
11381137

@@ -1812,7 +1811,10 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
18121811
maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key)
18131812
val batches = Seq(
18141813
Batch("RewriteWithExpression", fixedPoint, RewriteWithExpression),
1815-
Batch("Optimize after RewriteWithExpression", fixedPoint, CollapseProject, ColumnPruning)
1814+
Batch("Optimize after RewriteWithExpression", fixedPoint,
1815+
CollapseProject,
1816+
ColumnPruning,
1817+
RemoveNoopOperators)
18161818
)
18171819
}
18181820

@@ -1830,9 +1832,9 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
18301832
val aliasMap = getAliasMap(project)
18311833
var newProjectList = fields ++ getWithAttributes(condition)
18321834
val newCondition = rewriteCondition(condition, aliasMap)
1833-
val rewriteAlias = getWithAlias(newCondition).toSet
1835+
val exprIdSet = getWithAlias(newCondition).map(_.exprId).toSet
18341836
newProjectList = newProjectList.map {
1835-
case a: Alias if rewriteAlias.contains(a) => a.toAttribute
1837+
case a: Alias if exprIdSet.contains(a.exprId) => a.toAttribute
18361838
case e => e
18371839
}
18381840
project.copy(child = Filter(newCondition, grandChild), projectList = newProjectList)
@@ -1859,8 +1861,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
18591861
// propagate the attributes directly need add the groupingExpressions may cause regression.
18601862
// So Aggregate only needs to inline common expressions from parent for original project
18611863
// inheritance.
1862-
val newAggregateExpressions =
1863-
aggregate.aggregateExpressions ++ getWithAlias(pushDown.reduce(And))
1864+
val newAggregateExpressions = aggregate.aggregateExpressions ++
1865+
getWithAlias(pushDown.reduce(And)).map(replaceAliasButKeepName(_, aliasMap))
18641866
val replaced = removeOriginAlias(rewriteCondition(pushDown.reduce(And), aliasMap))
18651867
val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child),
18661868
aggregateExpressions = newAggregateExpressions)
@@ -2038,8 +2040,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
20382040
cond: Expression,
20392041
aliasMap: AttributeMap[Alias]): Expression = {
20402042
if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) {
2041-
val replaceWithMap = cond.collect {case a: Attribute => a }
2042-
.distinct
2043+
val replaceWithMap = cond.references.toSeq
20432044
.filter(attr => aliasMap.contains(attr))
20442045
.map(attr => attr -> aliasMap(attr))
20452046
.filter(m => !CollapseProject.isCheap(m._2))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala

+10-7
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,11 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
9090
// the current operator may have extra columns if it inherits the output columns from its
9191
// child, and we need to project away the extra columns to keep the plan schema unchanged.
9292
assert(p.output.length <= newPlan.output.length)
93-
newPlan.output.diff(p.output).forall(_.name.startsWith("_"))
94-
def hasOriginAlias(expr: Expression): Boolean = {
95-
expr match {
96-
case w: With =>
97-
if (w.defs.exists(_.originAlias.nonEmpty)) true else false
98-
case e => e.children.exists(hasOriginAlias)
99-
}
93+
94+
def hasOriginAlias(expr: Expression): Boolean = expr match {
95+
case w: With =>
96+
if (w.defs.exists(_.originAlias.nonEmpty)) true else false
97+
case e => e.children.exists(hasOriginAlias)
10098
}
10199
// If this iteration has attributes that need propagating, the columns cannot be pruned.
102100
val needPropagate = p.expressions.exists(hasOriginAlias)
@@ -148,6 +146,11 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
148146
// TODO: we should calculate the ref count and also inline the common expression
149147
// if its ref count is 1.
150148
refToExpr(id) = child
149+
} else if (originAlias.nonEmpty &&
150+
inputPlans.head.output.contains(originAlias.get.toAttribute)) {
151+
// originAlias only exists in Project or Filter. If the child already contains this
152+
// attribute, reuse that attribute as the reference expression.
153+
refToExpr(id) = originAlias.get.toAttribute
151154
} else {
152155
val commonExprs = commonExprsPerChild(childPlanIndex)
153156
val existingCommonExpr = commonExprs.find(_._2 == id.id)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

+39-21
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class FilterPushdownSuite extends PlanTest {
5252

5353
val batches = Batch("ColumnPruning and CollapseProject", FixedPoint(10),
5454
ColumnPruning,
55-
CollapseProject) :: Nil
55+
CollapseProject,
56+
RemoveNoopOperators) :: Nil
5657
}
5758

5859
val attrA = $"a".int
@@ -1556,29 +1557,46 @@ class FilterPushdownSuite extends PlanTest {
15561557
}
15571558

15581559
test("SPARK-50589: avoid extra expression duplication when push filter") {
1560+
// withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
1561+
// // through project
1562+
// val originalQuery1 = testRelation
1563+
// .select($"a" + $"b" as "add", $"a" - $"b" as "sub")
1564+
// .where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
1565+
// .analyze
1566+
// val optimized1 = Optimize.execute(originalQuery1)
1567+
// comparePlans(optimized1, originalQuery1)
1568+
//
1569+
// // through aggregate
1570+
// val originalQuery2 = testRelation
1571+
// .groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
1572+
// .where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
1573+
// val optimized2 = Optimize.execute(originalQuery2.analyze)
1574+
// val correctAnswer2 = testRelation
1575+
// .select($"a", $"a" + $"a" as "_common_expr_0")
1576+
// .where($"_common_expr_0" < 10 &&
1577+
// $"_common_expr_0" + $"_common_expr_0" > 10 &&
1578+
// abs($"a") > 5)
1579+
// .select($"a")
1580+
// .groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
1581+
// .analyze
1582+
// comparePlans(optimized2, correctAnswer2)
1583+
// }
15591584
withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
1560-
// through project
1561-
val originalQuery1 = testRelation
1562-
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
1563-
.where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
1564-
.analyze
1565-
val optimized1 = Optimize.execute(originalQuery1)
1566-
comparePlans(optimized1, originalQuery1)
1567-
1568-
// through aggregate
1569-
val originalQuery2 = testRelation
1570-
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
1571-
.where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
1572-
val optimized2 = Optimize.execute(originalQuery2.analyze)
1573-
val correctAnswer2 = testRelation
1585+
// partial push down
1586+
val originalQuery3 = testRelation
1587+
.groupBy($"a")($"a", count(1) as "ct")
1588+
.select($"a" + $"a" as "add", $"ct")
1589+
.where($"add" + $"add" > 10 && $"add" > $"ct")
1590+
val optimized3 = Optimize.execute(originalQuery3.analyze)
1591+
val correctAnswer3 = testRelation
15741592
.select($"a", $"a" + $"a" as "_common_expr_0")
1575-
.where($"_common_expr_0" < 10 &&
1576-
$"_common_expr_0" + $"_common_expr_0" > 10 &&
1577-
abs($"a") > 5)
1593+
.where($"_common_expr_0" + $"_common_expr_0" > 10)
15781594
.select($"a")
1579-
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
1595+
.groupBy($"a")(count(1) as "ct", $"a" + $"a" as "add")
1596+
.where($"add" > $"ct")
1597+
.select($"add", $"ct")
15801598
.analyze
1581-
comparePlans(optimized2, correctAnswer2)
1582-
}
1599+
comparePlans(optimized3, correctAnswer3)
15831600
}
1601+
}
15841602
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
161161
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
162162
.analyze
163163
val correctAnswer = t1
164-
.select($"a", Coalesce(Seq($"a", $"b")) as "int_col")
164+
.select($"a", Coalesce(Seq($"a", $"b")) as "int_col").as("t")
165165
.where(IsNotNull($"a") && IsNotNull($"int_col") &&
166166
$"a" === $"int_col")
167-
.select($"a", $"int_col").as("t")
168167
.join(t2.where(IsNotNull($"a")), Inner,
169168
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
170169
.analyze

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala

+48-29
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
2828
class RewriteWithExpressionSuite extends PlanTest {
2929

3030
object Optimizer extends RuleExecutor[LogicalPlan] {
31-
val batches =
32-
Batch("Rewrite With expression", FixedPoint(5),
33-
PullOutGroupingExpressions,
34-
RewriteWithExpression) ::
35-
Batch("Optimize after RewriteWithExpression", FixedPoint(5),
36-
CollapseProject,
37-
ColumnPruning) :: Nil
31+
val batches = Batch("Rewrite With expression", FixedPoint(5),
32+
PullOutGroupingExpressions,
33+
RewriteWithExpression) :: Nil
3834
}
3935

4036
private val testRelation = LocalRelation($"a".int, $"b".int)
@@ -72,7 +68,7 @@ class RewriteWithExpressionSuite extends PlanTest {
7268
comparePlans(
7369
Optimizer.execute(plan),
7470
testRelation
75-
.select((a + a).as("_common_expr_0"))
71+
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
7672
.select(($"_common_expr_0" * $"_common_expr_0").as("col"))
7773
.analyze
7874
)
@@ -90,8 +86,8 @@ class RewriteWithExpressionSuite extends PlanTest {
9086
comparePlans(
9187
Optimizer.execute(testRelation.select(outerExpr.as("col"))),
9288
testRelation
93-
.select(b, (a + a).as("_common_expr_0"))
94-
.select(($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))
89+
.select(star(), (a + a).as("_common_expr_0"))
90+
.select(a, b, ($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))
9591
.select(($"_common_expr_1" * $"_common_expr_1").as("col"))
9692
.analyze
9793
)
@@ -109,7 +105,8 @@ class RewriteWithExpressionSuite extends PlanTest {
109105
comparePlans(
110106
Optimizer.execute(testRelation.select(outerExpr.as("col"))),
111107
testRelation
112-
.select((b + b).as("_common_expr_1"), (a + a).as("_common_expr_0"))
108+
.select(star(), (b + b).as("_common_expr_1"))
109+
.select(star(), (a + a).as("_common_expr_0"))
113110
.select(finalExpr.as("col"))
114111
.analyze
115112
)
@@ -130,10 +127,10 @@ class RewriteWithExpressionSuite extends PlanTest {
130127
Optimizer.execute(testRelation.select(outerExpr1.as("col"))),
131128
testRelation
132129
// The first Project contains the common expression of the outer With
133-
.select(a, rewrittenOuterExpr)
130+
.select(star(), rewrittenOuterExpr)
134131
// The second Project contains the common expression of the inner With, which references
135132
// the common expression of the outer With.
136-
.select($"_common_expr_0", (a + a + $"_common_expr_0").as("_common_expr_1"))
133+
.select(star(), (a + a + $"_common_expr_0").as("_common_expr_1"))
137134
// The final Project contains the final result expression, which references both common
138135
// expressions.
139136
.select(($"_common_expr_0" + ($"_common_expr_1" + $"_common_expr_1")).as("col"))
@@ -148,7 +145,11 @@ class RewriteWithExpressionSuite extends PlanTest {
148145
comparePlans(
149146
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
150147
testRelation
151-
.select(rewrittenOuterExpr, (a + a).as("_common_expr_2"))
148+
// The first Project contains the common expression of the outer With
149+
.select(star(), rewrittenOuterExpr)
150+
// The second Project contains the common expression of the inner With, which does not
151+
// reference the common expression of the outer With.
152+
.select(star(), (a + a).as("_common_expr_2"))
152153
// The final Project contains the final result expression, which references both common
153154
// expressions.
154155
.select(($"_common_expr_0" +
@@ -243,7 +244,7 @@ class RewriteWithExpressionSuite extends PlanTest {
243244
comparePlans(
244245
Optimizer.execute(plan2),
245246
testRelation
246-
.select(a, (a + a).as("_common_expr_0"))
247+
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
247248
.select(Coalesce(Seq(($"_common_expr_0" * $"_common_expr_0"), a)).as("col"))
248249
.analyze
249250
)
@@ -267,8 +268,24 @@ class RewriteWithExpressionSuite extends PlanTest {
267268
comparePlans(
268269
Optimizer.execute(plan),
269270
testRelation
270-
.select(a, (a + 1).as("_common_expr_0"))
271-
.select(Seq(
271+
.select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
272+
.select(testRelation.output :+
273+
($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"): _*)
274+
.select(testRelation.output ++ Seq($"_groupingexpression",
275+
(a + 1).as("_common_expr_1")): _*)
276+
.groupBy($"_groupingexpression")(
277+
$"_groupingexpression",
278+
count($"_common_expr_1" * $"_common_expr_1" - 3).as("_aggregateexpression")
279+
)
280+
.select(($"_groupingexpression" + 2).as("col1"), $"_aggregateexpression".as("col2"))
281+
.analyze
282+
)
283+
// Running CollapseProject after the rule cleans up the unnecessary projections.
284+
comparePlans(
285+
CollapseProject(Optimizer.execute(plan)),
286+
testRelation
287+
.select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
288+
.select(testRelation.output ++ Seq(
272289
($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"),
273290
(a + 1).as("_common_expr_1")): _*)
274291
.groupBy($"_groupingexpression")(
@@ -295,9 +312,9 @@ class RewriteWithExpressionSuite extends PlanTest {
295312
comparePlans(
296313
Optimizer.execute(plan),
297314
testRelation
298-
.select(a, (b + 2).as("_common_expr_0"))
299-
.groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression"),
300-
(a + 1).as("_common_expr_1"))
315+
.select(testRelation.output :+ (b + 2).as("_common_expr_0"): _*)
316+
.groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression"))
317+
.select(a, $"_aggregateexpression", (a + 1).as("_common_expr_1"))
301318
.select(
302319
(a + 3).as("col1"),
303320
($"_common_expr_1" * $"_common_expr_1").as("col2"),
@@ -319,7 +336,6 @@ class RewriteWithExpressionSuite extends PlanTest {
319336
comparePlans(
320337
Optimizer.execute(plan),
321338
testRelation
322-
.select(a)
323339
.groupBy(a)(a, count(a - 1).as("_aggregateexpression"))
324340
.select(
325341
(a - 1).as("col1"),
@@ -355,9 +371,9 @@ class RewriteWithExpressionSuite extends PlanTest {
355371
comparePlans(
356372
Optimizer.execute(plan),
357373
testRelation
358-
.select(a, (a + 1).as("_common_expr_0"))
359-
.groupBy(a)(max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression"),
360-
(a - 1).as("_common_expr_1"))
374+
.select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
375+
.groupBy(a)(a, max($"_common_expr_0" * $"_common_expr_0").as("_aggregateexpression"))
376+
.select($"a", $"_aggregateexpression", (a - 1).as("_common_expr_1"))
361377
.select(($"_common_expr_1" * $"_aggregateexpression" + $"_common_expr_1").as("col"))
362378
.analyze
363379
)
@@ -388,20 +404,22 @@ class RewriteWithExpressionSuite extends PlanTest {
388404
comparePlans(
389405
Optimizer.execute(plan),
390406
testRelation
391-
.select(a, (b + 2).as("_common_expr_0"))
407+
.select(a, b, (b + 2).as("_common_expr_0"))
392408
.window(
393409
Seq(windowExpr(count(a), windowSpec(Seq($"_common_expr_0" * $"_common_expr_0"), Nil,
394410
frame)).as("col2")),
395411
Seq($"_common_expr_0" * $"_common_expr_0"),
396412
Nil
397413
)
398-
.select(a, $"col2", (a + 1).as("_common_expr_1"))
414+
.select(a, b, $"col2")
415+
.select(a, b, $"col2", (a + 1).as("_common_expr_1"))
399416
.window(
400417
Seq(windowExpr(sum($"_common_expr_1" * $"_common_expr_1"),
401418
windowSpec(Seq(a), Nil, frame)).as("col3")),
402419
Seq(a),
403420
Nil
404421
)
422+
.select(a, b, $"col2", $"col3")
405423
.select((a - 1).as("col1"), $"col2", $"col3")
406424
.analyze
407425
)
@@ -420,7 +438,8 @@ class RewriteWithExpressionSuite extends PlanTest {
420438
testRelation
421439
.select(a)
422440
.window(Seq(winExpr.as("_we0")), Seq(a), Nil)
423-
.select(($"_we0" * $"_we0").as("col"))
441+
.select(a, $"_we0", ($"_we0" * $"_we0").as("col"))
442+
.select($"col")
424443
.analyze
425444
)
426445
}
@@ -445,7 +464,7 @@ class RewriteWithExpressionSuite extends PlanTest {
445464
val plan = testRelation.having($"b")(avg("a").as("a"))(expr).analyze
446465
comparePlans(
447466
Optimizer.execute(plan),
448-
testRelation.select($"b").groupBy($"b")(avg("a").as("a")).where($"a" === 1).analyze
467+
testRelation.groupBy($"b")(avg("a").as("a")).where($"a" === 1).analyze
449468
)
450469
}
451470

@@ -459,7 +478,7 @@ class RewriteWithExpressionSuite extends PlanTest {
459478
comparePlans(
460479
Optimizer.execute(plan),
461480
testRelation
462-
.select((a + a).as("_common_expr_0"))
481+
.select(star(), (a + a).as("_common_expr_0"))
463482
.select(
464483
($"_common_expr_0" * $"_common_expr_0").as("c1"),
465484
($"_common_expr_0" - $"_common_expr_0").as("c2"))

0 commit comments

Comments
 (0)