@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Plan
26
26
import org .apache .spark .sql .catalyst .rules .Rule
27
27
import org .apache .spark .sql .catalyst .trees .TreePattern .{COMMON_EXPR_REF , WITH_EXPRESSION }
28
28
import org .apache .spark .sql .internal .SQLConf
29
+ import org .apache .spark .util .Utils
29
30
30
31
/**
31
32
* Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
@@ -66,11 +67,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
66
67
}
67
68
68
69
private def applyInternal (p : LogicalPlan ): LogicalPlan = {
69
- val inputPlans = p.children.toArray
70
+ val inputPlans = p.children
71
+ val commonExprsPerChild = Array .fill(inputPlans.length)(mutable.ListBuffer .empty[(Alias , Long )])
70
72
var newPlan : LogicalPlan = p.mapExpressions { expr =>
71
- rewriteWithExprAndInputPlans(expr, inputPlans)
73
+ rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild )
72
74
}
73
- newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
75
+ val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
76
+ if (commonExprs.isEmpty) {
77
+ inputPlan
78
+ } else {
79
+ Project (inputPlan.output ++ commonExprs.map(_._1), inputPlan)
80
+ }
81
+ }
82
+ newPlan = newPlan.withNewChildren(newChildren)
74
83
// Since we add extra Projects with extra columns to pre-evaluate the common expressions,
75
84
// the current operator may have extra columns if it inherits the output columns from its
76
85
// child, and we need to project away the extra columns to keep the plan schema unchanged.
@@ -85,17 +94,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
85
94
86
95
private def rewriteWithExprAndInputPlans (
87
96
e : Expression ,
88
- inputPlans : Array [LogicalPlan ],
97
+ inputPlans : Seq [LogicalPlan ],
98
+ commonExprsPerChild : Array [mutable.ListBuffer [(Alias , Long )]],
89
99
isNestedWith : Boolean = false ): Expression = {
90
100
if (! e.containsPattern(WITH_EXPRESSION )) return e
91
101
e match {
92
102
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
93
103
case w : With if ! isNestedWith =>
94
104
// Rewrite nested With expressions first
95
- val child = rewriteWithExprAndInputPlans(w.child, inputPlans, isNestedWith = true )
96
- val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith = true ))
105
+ val child = rewriteWithExprAndInputPlans(
106
+ w.child, inputPlans, commonExprsPerChild, isNestedWith = true )
107
+ val defs = w.defs.map(rewriteWithExprAndInputPlans(
108
+ _, inputPlans, commonExprsPerChild, isNestedWith = true ))
97
109
val refToExpr = mutable.HashMap .empty[CommonExpressionId , Expression ]
98
- val childProjections = Array .fill(inputPlans.length)(mutable.ArrayBuffer .empty[Alias ])
99
110
100
111
defs.zipWithIndex.foreach { case (CommonExpressionDef (child, id), index) =>
101
112
if (id.canonicalized) {
@@ -106,10 +117,10 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
106
117
if (CollapseProject .isCheap(child)) {
107
118
refToExpr(id) = child
108
119
} else {
109
- val childProjectionIndex = inputPlans.indexWhere(
120
+ val childPlanIndex = inputPlans.indexWhere(
110
121
c => child.references.subsetOf(c.outputSet)
111
122
)
112
- if (childProjectionIndex == - 1 ) {
123
+ if (childPlanIndex == - 1 ) {
113
124
// When we cannot rewrite the common expressions, force to inline them so that the
114
125
// query can still run. This can happen if the join condition contains `With` and
115
126
// the common expression references columns from both join sides.
@@ -120,31 +131,33 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
120
131
// if it's ref count is 1.
121
132
refToExpr(id) = child
122
133
} else {
123
- val aliasName = if (SQLConf .get.getConf(SQLConf .USE_COMMON_EXPR_ID_FOR_ALIAS )) {
124
- s " _common_expr_ ${id.id}"
125
- } else {
126
- s " _common_expr_ $index"
127
- }
128
- val alias = Alias (child, aliasName)()
129
- val fakeProj = Project (Seq (alias), inputPlans(childProjectionIndex))
130
- if (PlanHelper .specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
131
- // We have to inline the common expression if it cannot be put in a Project.
132
- refToExpr(id) = child
134
+ val commonExprs = commonExprsPerChild(childPlanIndex)
135
+ val existingCommonExpr = commonExprs.find(_._2 == id.id)
136
+ if (existingCommonExpr.isDefined) {
137
+ if (Utils .isTesting) {
138
+ assert(existingCommonExpr.get._1.child.semanticEquals(child))
139
+ }
140
+ refToExpr(id) = existingCommonExpr.get._1.toAttribute
133
141
} else {
134
- childProjections(childProjectionIndex) += alias
135
- refToExpr(id) = alias.toAttribute
142
+ val aliasName = if (SQLConf .get.getConf(SQLConf .USE_COMMON_EXPR_ID_FOR_ALIAS )) {
143
+ s " _common_expr_ ${id.id}"
144
+ } else {
145
+ s " _common_expr_ $index"
146
+ }
147
+ val alias = Alias (child, aliasName)()
148
+ val fakeProj = Project (Seq (alias), inputPlans(childPlanIndex))
149
+ if (PlanHelper .specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
150
+ // We have to inline the common expression if it cannot be put in a Project.
151
+ refToExpr(id) = child
152
+ } else {
153
+ commonExprs.append((alias, id.id))
154
+ refToExpr(id) = alias.toAttribute
155
+ }
136
156
}
137
157
}
138
158
}
139
159
}
140
160
141
- for (i <- inputPlans.indices) {
142
- val projectList = childProjections(i)
143
- if (projectList.nonEmpty) {
144
- inputPlans(i) = Project (inputPlans(i).output ++ projectList, inputPlans(i))
145
- }
146
- }
147
-
148
161
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF )) {
149
162
// `child` may contain nested With and we only replace `CommonExpressionRef` that
150
163
// references common expressions in the current `With`.
@@ -158,7 +171,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
158
171
159
172
case c : ConditionalExpression =>
160
173
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
161
- rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
174
+ rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
162
175
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
163
176
// Use transformUp to handle nested With.
164
177
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION )) {
@@ -171,7 +184,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
171
184
}
172
185
}
173
186
174
- case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
187
+ case other => other.mapChildren(
188
+ rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
189
+ )
175
190
}
176
191
}
177
192
}
0 commit comments