17
17
18
18
package org .apache .spark .sql .catalyst .optimizer
19
19
20
- import org .apache .spark .SparkException
21
20
import org .apache .spark .sql .catalyst .analysis .TempResolvedColumn
22
21
import org .apache .spark .sql .catalyst .dsl .expressions ._
23
22
import org .apache .spark .sql .catalyst .dsl .plans ._
@@ -113,14 +112,14 @@ class RewriteWithExpressionSuite extends PlanTest {
113
112
comparePlans(
114
113
Optimizer .execute(plan),
115
114
testRelation
116
- .select((testRelation.output :+ rewrittenInnerExpr ): _* )
117
- .select((testRelation.output :+ rewrittenInnerExpr .toAttribute :+ rewrittenOuterExpr ): _* )
115
+ .select((testRelation.output :+ rewrittenOuterExpr ): _* )
116
+ .select((testRelation.output :+ rewrittenOuterExpr .toAttribute :+ rewrittenInnerExpr ): _* )
118
117
.select(finalExpr.as(" col" ))
119
118
.analyze
120
119
)
121
120
}
122
121
123
- test(" correlated nested WITH expression is not supported " ) {
122
+ test(" related nested WITH expression" ) {
124
123
val Seq (a, b) = testRelation.output
125
124
val outerCommonExprDef = CommonExpressionDef (b + b, CommonExpressionId (0 ))
126
125
val outerRef = new CommonExpressionRef (outerCommonExprDef)
@@ -129,18 +128,39 @@ class RewriteWithExpressionSuite extends PlanTest {
129
128
val commonExprDef1 = CommonExpressionDef (a + a + outerRef, CommonExpressionId (1 ))
130
129
val ref1 = new CommonExpressionRef (commonExprDef1)
131
130
val innerExpr1 = With (ref1 + ref1, Seq (commonExprDef1))
132
-
133
131
val outerExpr1 = With (outerRef + innerExpr1, Seq (outerCommonExprDef))
134
- intercept[SparkException ](Optimizer .execute(testRelation.select(outerExpr1.as(" col" ))))
132
+ val rewrittenInnerExpr1 = (a + a + $" _common_expr_0" ).as(" _common_expr_1" )
133
+ val rewrittenOuterExpr1 = (b + b).as(" _common_expr_0" )
134
+ val finalExpr1 = rewrittenOuterExpr1.toAttribute +
135
+ (rewrittenInnerExpr1.toAttribute + rewrittenInnerExpr1.toAttribute)
136
+ val plan1 = testRelation.select(outerExpr1.as(" col" ))
137
+ comparePlans(
138
+ Optimizer .execute(plan1),
139
+ testRelation
140
+ .select(testRelation.output :+ rewrittenOuterExpr1 : _* )
141
+ .select(testRelation.output :+ rewrittenOuterExpr1.toAttribute :+ rewrittenInnerExpr1 : _* )
142
+ .select(finalExpr1.as(" col" ))
143
+ .analyze
144
+ )
135
145
146
+ // The inner main expression references the outer expression
136
147
val commonExprDef2 = CommonExpressionDef (a + a)
137
148
val ref2 = new CommonExpressionRef (commonExprDef2)
138
- // The inner main expression references the outer expression
139
- val innerExpr2 = With (ref2 + outerRef, Seq (commonExprDef1))
140
-
149
+ val innerExpr2 = With (ref2 + outerRef, Seq (commonExprDef2))
141
150
val outerExpr2 = With (outerRef + innerExpr2, Seq (outerCommonExprDef))
142
- intercept[SparkException ](Optimizer .execute(testRelation.select(outerExpr2.as(" col" ))))
143
- }
151
+ val rewrittenInnerExpr2 = (a + a).as(" _common_expr_1" )
152
+ val rewrittenOuterExpr2 = (b + b).as(" _common_expr_0" )
153
+ val finalExpr2 = rewrittenOuterExpr2.toAttribute +
154
+ (rewrittenInnerExpr2.toAttribute + rewrittenOuterExpr2.toAttribute)
155
+ val plan2 = testRelation.select(outerExpr2.as(" col" ))
156
+ comparePlans(
157
+ Optimizer .execute(plan2),
158
+ testRelation
159
+ .select(testRelation.output :+ rewrittenOuterExpr2 : _* )
160
+ .select(testRelation.output :+ rewrittenOuterExpr2.toAttribute :+ rewrittenInnerExpr2 : _* )
161
+ .select(finalExpr2.as(" col" ))
162
+ .analyze
163
+ ) }
144
164
145
165
test(" WITH expression in filter" ) {
146
166
val a = testRelation.output.head
0 commit comments