Skip to content

Commit 641a7ed

Browse files
committed
Support related nested WITH expression
1 parent 7ae939a commit 641a7ed

File tree

2 files changed

+55
-29
lines changed

2 files changed

+55
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala

+24-18
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,11 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
9090
e match {
9191
case w: With =>
9292
// Rewrite nested With expressions first
93-
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
9493
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
95-
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
9694
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
95+
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
9796

9897
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
99-
if (child.containsPattern(COMMON_EXPR_REF)) {
100-
throw SparkException.internalError(
101-
"Common expression definition cannot reference other Common expression definitions")
102-
}
10398
if (id.canonicalized) {
10499
throw SparkException.internalError(
105100
"Cannot rewrite canonicalized Common expression definitions")
@@ -140,25 +135,23 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
140135
}
141136
}
142137

138+
// Expression definitions apply to the entire With child contains internally nested with,
139+
// so need to generate the project first, then generate child's project
143140
for (i <- inputPlans.indices) {
144141
val projectList = childProjections(i)
145142
if (projectList.nonEmpty) {
146143
inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i))
147144
}
148145
}
149-
150-
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
151-
case ref: CommonExpressionRef =>
152-
if (!refToExpr.contains(ref.id)) {
153-
throw SparkException.internalError("Undefined common expression id " + ref.id)
154-
}
155-
if (ref.id.canonicalized) {
156-
throw SparkException.internalError(
157-
"Cannot rewrite canonicalized Common expression references")
158-
}
159-
refToExpr(ref.id)
146+
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
147+
// Internally nested with may also contain outer expression references
148+
for (i <- inputPlans.indices) {
149+
val projectList = childProjections(i)
150+
if (projectList.nonEmpty) {
151+
inputPlans(i) = inputPlans(i).mapExpressions(replaceRef(_, refToExpr))
152+
}
160153
}
161-
154+
replaceRef(child, refToExpr)
162155
case c: ConditionalExpression =>
163156
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
164157
rewriteWithExprAndInputPlans(_, inputPlans))
@@ -177,4 +170,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
177170
case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans))
178171
}
179172
}
173+
174+
private def replaceRef(
175+
expr: Expression,
176+
refToExpr: mutable.Map[CommonExpressionId, Expression]): Expression =
177+
expr.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
178+
case ref: CommonExpressionRef if refToExpr.contains(ref.id) =>
179+
if (ref.id.canonicalized) {
180+
throw SparkException.internalError(
181+
"Cannot rewrite canonicalized Common expression references")
182+
} else {
183+
refToExpr(ref.id)
184+
}
185+
}
180186
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala

+31-11
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import org.apache.spark.SparkException
2120
import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn
2221
import org.apache.spark.sql.catalyst.dsl.expressions._
2322
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -113,14 +112,14 @@ class RewriteWithExpressionSuite extends PlanTest {
113112
comparePlans(
114113
Optimizer.execute(plan),
115114
testRelation
116-
.select((testRelation.output :+ rewrittenInnerExpr): _*)
117-
.select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+ rewrittenOuterExpr): _*)
115+
.select((testRelation.output :+ rewrittenOuterExpr): _*)
116+
.select((testRelation.output :+ rewrittenOuterExpr.toAttribute :+ rewrittenInnerExpr): _*)
118117
.select(finalExpr.as("col"))
119118
.analyze
120119
)
121120
}
122121

123-
test("correlated nested WITH expression is not supported") {
122+
test("related nested WITH expression") {
124123
val Seq(a, b) = testRelation.output
125124
val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
126125
val outerRef = new CommonExpressionRef(outerCommonExprDef)
@@ -129,18 +128,39 @@ class RewriteWithExpressionSuite extends PlanTest {
129128
val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1))
130129
val ref1 = new CommonExpressionRef(commonExprDef1)
131130
val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))
132-
133131
val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
134-
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))
132+
val rewrittenInnerExpr1 = (a + a + $"_common_expr_0").as("_common_expr_1")
133+
val rewrittenOuterExpr1 = (b + b).as("_common_expr_0")
134+
val finalExpr1 = rewrittenOuterExpr1.toAttribute +
135+
(rewrittenInnerExpr1.toAttribute + rewrittenInnerExpr1.toAttribute)
136+
val plan1 = testRelation.select(outerExpr1.as("col"))
137+
comparePlans(
138+
Optimizer.execute(plan1),
139+
testRelation
140+
.select(testRelation.output :+ rewrittenOuterExpr1: _*)
141+
.select(testRelation.output :+ rewrittenOuterExpr1.toAttribute :+ rewrittenInnerExpr1: _*)
142+
.select(finalExpr1.as("col"))
143+
.analyze
144+
)
135145

146+
// The inner main expression references the outer expression
136147
val commonExprDef2 = CommonExpressionDef(a + a)
137148
val ref2 = new CommonExpressionRef(commonExprDef2)
138-
// The inner main expression references the outer expression
139-
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1))
140-
149+
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
141150
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
142-
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
143-
}
151+
val rewrittenInnerExpr2 = (a + a).as("_common_expr_1")
152+
val rewrittenOuterExpr2 = (b + b).as("_common_expr_0")
153+
val finalExpr2 = rewrittenOuterExpr2.toAttribute +
154+
(rewrittenInnerExpr2.toAttribute + rewrittenOuterExpr2.toAttribute)
155+
val plan2 = testRelation.select(outerExpr2.as("col"))
156+
comparePlans(
157+
Optimizer.execute(plan2),
158+
testRelation
159+
.select(testRelation.output :+ rewrittenOuterExpr2: _*)
160+
.select(testRelation.output :+ rewrittenOuterExpr2.toAttribute :+ rewrittenInnerExpr2: _*)
161+
.select(finalExpr2.as("col"))
162+
.analyze
163+
) }
144164

145165
test("WITH expression in filter") {
146166
val a = testRelation.output.head

0 commit comments

Comments
 (0)