@@ -28,13 +28,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
28
28
class RewriteWithExpressionSuite extends PlanTest {
29
29
30
30
object Optimizer extends RuleExecutor [LogicalPlan ] {
31
- val batches =
32
- Batch (" Rewrite With expression" , FixedPoint (5 ),
33
- PullOutGroupingExpressions ,
34
- RewriteWithExpression ) ::
35
- Batch (" Optimize after RewriteWithExpression" , FixedPoint (5 ),
36
- CollapseProject ,
37
- ColumnPruning ) :: Nil
31
+ val batches = Batch (" Rewrite With expression" , FixedPoint (5 ),
32
+ PullOutGroupingExpressions ,
33
+ RewriteWithExpression ) :: Nil
38
34
}
39
35
40
36
private val testRelation = LocalRelation ($" a" .int, $" b" .int)
@@ -72,7 +68,7 @@ class RewriteWithExpressionSuite extends PlanTest {
72
68
comparePlans(
73
69
Optimizer .execute(plan),
74
70
testRelation
75
- .select((a + a).as(" _common_expr_0" ))
71
+ .select((testRelation.output :+ ( a + a).as(" _common_expr_0" )) : _* )
76
72
.select(($" _common_expr_0" * $" _common_expr_0" ).as(" col" ))
77
73
.analyze
78
74
)
@@ -90,8 +86,8 @@ class RewriteWithExpressionSuite extends PlanTest {
90
86
comparePlans(
91
87
Optimizer .execute(testRelation.select(outerExpr.as(" col" ))),
92
88
testRelation
93
- .select(b , (a + a).as(" _common_expr_0" ))
94
- .select(($" _common_expr_0" + $" _common_expr_0" + b).as(" _common_expr_1" ))
89
+ .select(star() , (a + a).as(" _common_expr_0" ))
90
+ .select(a, b, ($" _common_expr_0" + $" _common_expr_0" + b).as(" _common_expr_1" ))
95
91
.select(($" _common_expr_1" * $" _common_expr_1" ).as(" col" ))
96
92
.analyze
97
93
)
@@ -109,7 +105,8 @@ class RewriteWithExpressionSuite extends PlanTest {
109
105
comparePlans(
110
106
Optimizer .execute(testRelation.select(outerExpr.as(" col" ))),
111
107
testRelation
112
- .select((b + b).as(" _common_expr_1" ), (a + a).as(" _common_expr_0" ))
108
+ .select(star(), (b + b).as(" _common_expr_1" ))
109
+ .select(star(), (a + a).as(" _common_expr_0" ))
113
110
.select(finalExpr.as(" col" ))
114
111
.analyze
115
112
)
@@ -130,10 +127,10 @@ class RewriteWithExpressionSuite extends PlanTest {
130
127
Optimizer .execute(testRelation.select(outerExpr1.as(" col" ))),
131
128
testRelation
132
129
// The first Project contains the common expression of the outer With
133
- .select(a , rewrittenOuterExpr)
130
+ .select(star() , rewrittenOuterExpr)
134
131
// The second Project contains the common expression of the inner With, which references
135
132
// the common expression of the outer With.
136
- .select($ " _common_expr_0 " , (a + a + $" _common_expr_0" ).as(" _common_expr_1" ))
133
+ .select(star() , (a + a + $" _common_expr_0" ).as(" _common_expr_1" ))
137
134
// The final Project contains the final result expression, which references both common
138
135
// expressions.
139
136
.select(($" _common_expr_0" + ($" _common_expr_1" + $" _common_expr_1" )).as(" col" ))
@@ -148,7 +145,11 @@ class RewriteWithExpressionSuite extends PlanTest {
148
145
comparePlans(
149
146
Optimizer .execute(testRelation.select(outerExpr2.as(" col" ))),
150
147
testRelation
151
- .select(rewrittenOuterExpr, (a + a).as(" _common_expr_2" ))
148
+ // The first Project contains the common expression of the outer With
149
+ .select(star(), rewrittenOuterExpr)
150
+ // The second Project contains the common expression of the inner With, which does not
151
+ // reference the common expression of the outer With.
152
+ .select(star(), (a + a).as(" _common_expr_2" ))
152
153
// The final Project contains the final result expression, which references both common
153
154
// expressions.
154
155
.select(($" _common_expr_0" +
@@ -243,7 +244,7 @@ class RewriteWithExpressionSuite extends PlanTest {
243
244
comparePlans(
244
245
Optimizer .execute(plan2),
245
246
testRelation
246
- .select(a, (a + a).as(" _common_expr_0" ))
247
+ .select((testRelation.output :+ (a + a).as(" _common_expr_0" )) : _* )
247
248
.select(Coalesce (Seq (($" _common_expr_0" * $" _common_expr_0" ), a)).as(" col" ))
248
249
.analyze
249
250
)
@@ -267,8 +268,24 @@ class RewriteWithExpressionSuite extends PlanTest {
267
268
comparePlans(
268
269
Optimizer .execute(plan),
269
270
testRelation
270
- .select(a, (a + 1 ).as(" _common_expr_0" ))
271
- .select(Seq (
271
+ .select(testRelation.output :+ (a + 1 ).as(" _common_expr_0" ): _* )
272
+ .select(testRelation.output :+
273
+ ($" _common_expr_0" * $" _common_expr_0" ).as(" _groupingexpression" ): _* )
274
+ .select(testRelation.output ++ Seq ($" _groupingexpression" ,
275
+ (a + 1 ).as(" _common_expr_1" )): _* )
276
+ .groupBy($" _groupingexpression" )(
277
+ $" _groupingexpression" ,
278
+ count($" _common_expr_1" * $" _common_expr_1" - 3 ).as(" _aggregateexpression" )
279
+ )
280
+ .select(($" _groupingexpression" + 2 ).as(" col1" ), $" _aggregateexpression" .as(" col2" ))
281
+ .analyze
282
+ )
283
+ // Running CollapseProject after the rule cleans up the unnecessary projections.
284
+ comparePlans(
285
+ CollapseProject (Optimizer .execute(plan)),
286
+ testRelation
287
+ .select(testRelation.output :+ (a + 1 ).as(" _common_expr_0" ): _* )
288
+ .select(testRelation.output ++ Seq (
272
289
($" _common_expr_0" * $" _common_expr_0" ).as(" _groupingexpression" ),
273
290
(a + 1 ).as(" _common_expr_1" )): _* )
274
291
.groupBy($" _groupingexpression" )(
@@ -295,9 +312,9 @@ class RewriteWithExpressionSuite extends PlanTest {
295
312
comparePlans(
296
313
Optimizer .execute(plan),
297
314
testRelation
298
- .select(a, (b + 2 ).as(" _common_expr_0" ))
299
- .groupBy(a)(a, max($" _common_expr_0" * $" _common_expr_0" ).as(" _aggregateexpression" ),
300
- (a + 1 ).as(" _common_expr_1" ))
315
+ .select(testRelation.output :+ (b + 2 ).as(" _common_expr_0" ): _* )
316
+ .groupBy(a)(a, max($" _common_expr_0" * $" _common_expr_0" ).as(" _aggregateexpression" ))
317
+ .select(a, $ " _aggregateexpression " , (a + 1 ).as(" _common_expr_1" ))
301
318
.select(
302
319
(a + 3 ).as(" col1" ),
303
320
($" _common_expr_1" * $" _common_expr_1" ).as(" col2" ),
@@ -319,7 +336,6 @@ class RewriteWithExpressionSuite extends PlanTest {
319
336
comparePlans(
320
337
Optimizer .execute(plan),
321
338
testRelation
322
- .select(a)
323
339
.groupBy(a)(a, count(a - 1 ).as(" _aggregateexpression" ))
324
340
.select(
325
341
(a - 1 ).as(" col1" ),
@@ -355,9 +371,9 @@ class RewriteWithExpressionSuite extends PlanTest {
355
371
comparePlans(
356
372
Optimizer .execute(plan),
357
373
testRelation
358
- .select(a, (a + 1 ).as(" _common_expr_0" ))
359
- .groupBy(a)(max($" _common_expr_0" * $" _common_expr_0" ).as(" _aggregateexpression" ),
360
- (a - 1 ).as(" _common_expr_1" ))
374
+ .select(testRelation.output :+ (a + 1 ).as(" _common_expr_0" ): _* )
375
+ .groupBy(a)(a, max($" _common_expr_0" * $" _common_expr_0" ).as(" _aggregateexpression" ))
376
+ .select($ " a " , $ " _aggregateexpression " , (a - 1 ).as(" _common_expr_1" ))
361
377
.select(($" _common_expr_1" * $" _aggregateexpression" + $" _common_expr_1" ).as(" col" ))
362
378
.analyze
363
379
)
@@ -388,20 +404,22 @@ class RewriteWithExpressionSuite extends PlanTest {
388
404
comparePlans(
389
405
Optimizer .execute(plan),
390
406
testRelation
391
- .select(a, (b + 2 ).as(" _common_expr_0" ))
407
+ .select(a, b, (b + 2 ).as(" _common_expr_0" ))
392
408
.window(
393
409
Seq (windowExpr(count(a), windowSpec(Seq ($" _common_expr_0" * $" _common_expr_0" ), Nil ,
394
410
frame)).as(" col2" )),
395
411
Seq ($" _common_expr_0" * $" _common_expr_0" ),
396
412
Nil
397
413
)
398
- .select(a, $" col2" , (a + 1 ).as(" _common_expr_1" ))
414
+ .select(a, b, $" col2" )
415
+ .select(a, b, $" col2" , (a + 1 ).as(" _common_expr_1" ))
399
416
.window(
400
417
Seq (windowExpr(sum($" _common_expr_1" * $" _common_expr_1" ),
401
418
windowSpec(Seq (a), Nil , frame)).as(" col3" )),
402
419
Seq (a),
403
420
Nil
404
421
)
422
+ .select(a, b, $" col2" , $" col3" )
405
423
.select((a - 1 ).as(" col1" ), $" col2" , $" col3" )
406
424
.analyze
407
425
)
@@ -420,7 +438,8 @@ class RewriteWithExpressionSuite extends PlanTest {
420
438
testRelation
421
439
.select(a)
422
440
.window(Seq (winExpr.as(" _we0" )), Seq (a), Nil )
423
- .select(($" _we0" * $" _we0" ).as(" col" ))
441
+ .select(a, $" _we0" , ($" _we0" * $" _we0" ).as(" col" ))
442
+ .select($" col" )
424
443
.analyze
425
444
)
426
445
}
@@ -445,7 +464,7 @@ class RewriteWithExpressionSuite extends PlanTest {
445
464
val plan = testRelation.having($" b" )(avg(" a" ).as(" a" ))(expr).analyze
446
465
comparePlans(
447
466
Optimizer .execute(plan),
448
- testRelation.select($ " b " ). groupBy($" b" )(avg(" a" ).as(" a" )).where($" a" === 1 ).analyze
467
+ testRelation.groupBy($" b" )(avg(" a" ).as(" a" )).where($" a" === 1 ).analyze
449
468
)
450
469
}
451
470
@@ -459,7 +478,7 @@ class RewriteWithExpressionSuite extends PlanTest {
459
478
comparePlans(
460
479
Optimizer .execute(plan),
461
480
testRelation
462
- .select((a + a).as(" _common_expr_0" ))
481
+ .select(star(), (a + a).as(" _common_expr_0" ))
463
482
.select(
464
483
($" _common_expr_0" * $" _common_expr_0" ).as(" c1" ),
465
484
($" _common_expr_0" - $" _common_expr_0" ).as(" c2" ))
0 commit comments