Skip to content

Commit 2372bc0

Browse files
committed
[SPARK-50679][SQL] Duplicated common expressions in different With should be projected only once
### What changes were proposed in this pull request?

Sometimes we may need to duplicate a `With` expression, so that several `With` expressions share the same common expressions. This PR improves the `With` rewriting rule to recognize this case and project the duplicated common expressions only once.

### Why are the changes needed?

To produce better plans.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A new test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49303 from cloud-fan/with.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent b8a8e0d commit 2372bc0

File tree

2 files changed

+72
-34
lines changed

2 files changed

+72
-34
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala

+45-30
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Plan
2626
import org.apache.spark.sql.catalyst.rules.Rule
2727
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
2828
import org.apache.spark.sql.internal.SQLConf
29+
import org.apache.spark.util.Utils
2930

3031
/**
3132
* Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
@@ -66,11 +67,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
6667
}
6768

6869
private def applyInternal(p: LogicalPlan): LogicalPlan = {
69-
val inputPlans = p.children.toArray
70+
val inputPlans = p.children
71+
val commonExprsPerChild = Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
7072
var newPlan: LogicalPlan = p.mapExpressions { expr =>
71-
rewriteWithExprAndInputPlans(expr, inputPlans)
73+
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
7274
}
73-
newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
75+
val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
76+
if (commonExprs.isEmpty) {
77+
inputPlan
78+
} else {
79+
Project(inputPlan.output ++ commonExprs.map(_._1), inputPlan)
80+
}
81+
}
82+
newPlan = newPlan.withNewChildren(newChildren)
7483
// Since we add extra Projects with extra columns to pre-evaluate the common expressions,
7584
// the current operator may have extra columns if it inherits the output columns from its
7685
// child, and we need to project away the extra columns to keep the plan schema unchanged.
@@ -85,17 +94,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
8594

8695
private def rewriteWithExprAndInputPlans(
8796
e: Expression,
88-
inputPlans: Array[LogicalPlan],
97+
inputPlans: Seq[LogicalPlan],
98+
commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
8999
isNestedWith: Boolean = false): Expression = {
90100
if (!e.containsPattern(WITH_EXPRESSION)) return e
91101
e match {
92102
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
93103
case w: With if !isNestedWith =>
94104
// Rewrite nested With expressions first
95-
val child = rewriteWithExprAndInputPlans(w.child, inputPlans, isNestedWith = true)
96-
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith = true))
105+
val child = rewriteWithExprAndInputPlans(
106+
w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
107+
val defs = w.defs.map(rewriteWithExprAndInputPlans(
108+
_, inputPlans, commonExprsPerChild, isNestedWith = true))
97109
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
98-
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
99110

100111
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
101112
if (id.canonicalized) {
@@ -106,10 +117,10 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
106117
if (CollapseProject.isCheap(child)) {
107118
refToExpr(id) = child
108119
} else {
109-
val childProjectionIndex = inputPlans.indexWhere(
120+
val childPlanIndex = inputPlans.indexWhere(
110121
c => child.references.subsetOf(c.outputSet)
111122
)
112-
if (childProjectionIndex == -1) {
123+
if (childPlanIndex == -1) {
113124
// When we cannot rewrite the common expressions, force to inline them so that the
114125
// query can still run. This can happen if the join condition contains `With` and
115126
// the common expression references columns from both join sides.
@@ -120,31 +131,33 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
120131
// if its ref count is 1.
121132
refToExpr(id) = child
122133
} else {
123-
val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
124-
s"_common_expr_${id.id}"
125-
} else {
126-
s"_common_expr_$index"
127-
}
128-
val alias = Alias(child, aliasName)()
129-
val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex))
130-
if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
131-
// We have to inline the common expression if it cannot be put in a Project.
132-
refToExpr(id) = child
134+
val commonExprs = commonExprsPerChild(childPlanIndex)
135+
val existingCommonExpr = commonExprs.find(_._2 == id.id)
136+
if (existingCommonExpr.isDefined) {
137+
if (Utils.isTesting) {
138+
assert(existingCommonExpr.get._1.child.semanticEquals(child))
139+
}
140+
refToExpr(id) = existingCommonExpr.get._1.toAttribute
133141
} else {
134-
childProjections(childProjectionIndex) += alias
135-
refToExpr(id) = alias.toAttribute
142+
val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
143+
s"_common_expr_${id.id}"
144+
} else {
145+
s"_common_expr_$index"
146+
}
147+
val alias = Alias(child, aliasName)()
148+
val fakeProj = Project(Seq(alias), inputPlans(childPlanIndex))
149+
if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
150+
// We have to inline the common expression if it cannot be put in a Project.
151+
refToExpr(id) = child
152+
} else {
153+
commonExprs.append((alias, id.id))
154+
refToExpr(id) = alias.toAttribute
155+
}
136156
}
137157
}
138158
}
139159
}
140160

141-
for (i <- inputPlans.indices) {
142-
val projectList = childProjections(i)
143-
if (projectList.nonEmpty) {
144-
inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i))
145-
}
146-
}
147-
148161
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
149162
// `child` may contain nested With and we only replace `CommonExpressionRef` that
150163
// references common expressions in the current `With`.
@@ -158,7 +171,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
158171

159172
case c: ConditionalExpression =>
160173
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
161-
rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
174+
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
162175
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
163176
// Use transformUp to handle nested With.
164177
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
@@ -171,7 +184,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
171184
}
172185
}
173186

174-
case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
187+
case other => other.mapChildren(
188+
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
189+
)
175190
}
176191
}
177192
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala

+27-4
Original file line numberDiff line numberDiff line change
@@ -404,17 +404,16 @@ class RewriteWithExpressionSuite extends PlanTest {
404404
Optimizer.execute(plan),
405405
testRelation
406406
.select(a, b, (b + 2).as("_common_expr_0"))
407-
.select(a, b, $"_common_expr_0", (b + 2).as("_common_expr_1"))
408407
.window(
409408
Seq(windowExpr(count(a), windowSpec(Seq($"_common_expr_0" * $"_common_expr_0"), Nil,
410409
frame)).as("col2")),
411-
Seq($"_common_expr_1" * $"_common_expr_1"),
410+
Seq($"_common_expr_0" * $"_common_expr_0"),
412411
Nil
413412
)
414413
.select(a, b, $"col2")
415-
.select(a, b, $"col2", (a + 1).as("_common_expr_2"))
414+
.select(a, b, $"col2", (a + 1).as("_common_expr_1"))
416415
.window(
417-
Seq(windowExpr(sum($"_common_expr_2" * $"_common_expr_2"),
416+
Seq(windowExpr(sum($"_common_expr_1" * $"_common_expr_1"),
418417
windowSpec(Seq(a), Nil, frame)).as("col3")),
419418
Seq(a),
420419
Nil
@@ -467,4 +466,28 @@ class RewriteWithExpressionSuite extends PlanTest {
467466
testRelation.groupBy($"b")(avg("a").as("a")).where($"a" === 1).analyze
468467
)
469468
}
469+
470+
test("SPARK-50679: duplicated common expressions in different With") {
471+
val a = testRelation.output.head
472+
val exprDef = CommonExpressionDef(a + a)
473+
val exprRef = new CommonExpressionRef(exprDef)
474+
val expr1 = With(exprRef * exprRef, Seq(exprDef))
475+
val expr2 = With(exprRef - exprRef, Seq(exprDef))
476+
val plan = testRelation.select(expr1.as("c1"), expr2.as("c2")).analyze
477+
comparePlans(
478+
Optimizer.execute(plan),
479+
testRelation
480+
.select(star(), (a + a).as("_common_expr_0"))
481+
.select(
482+
($"_common_expr_0" * $"_common_expr_0").as("c1"),
483+
($"_common_expr_0" - $"_common_expr_0").as("c2"))
484+
.analyze
485+
)
486+
487+
val wrongExprDef = CommonExpressionDef(a * a, exprDef.id)
488+
val wrongExprRef = new CommonExpressionRef(wrongExprDef)
489+
val expr3 = With(wrongExprRef + wrongExprRef, Seq(wrongExprDef))
490+
val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
491+
intercept[AssertionError](Optimizer.execute(wrongPlan))
492+
}
470493
}

0 commit comments

Comments
 (0)