@@ -88,70 +88,20 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
88
88
inputPlans : Array [LogicalPlan ]): Expression = {
89
89
if (! e.containsPattern(WITH_EXPRESSION )) return e
90
90
e match {
91
+ // Rewrite `With` consists of two steps:
92
+ // 1. Convert defs into `Project` of inputPlans
93
+ // 2. Replace the references in inputPlans and child then output the new child
94
+ // `With` support nestd, defs and child can all contain sub-With, the `Alias` generated by
95
+ // the current `With` acts on the child and the child's internal `With`, so generate
96
+ // `Project` order is:
97
+ // 1. internally nested with of main expression definitions
98
+ // 2. main expression
99
+ // 3. internally nested with of main expression
91
100
case w : With =>
92
- // Rewrite nested With expressions first
93
101
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
94
- val childProjections = Array .fill(inputPlans.length)(mutable.ArrayBuffer .empty[Alias ])
95
- val refToExpr = mutable.HashMap .empty[CommonExpressionId , Expression ]
96
-
97
- defs.zipWithIndex.foreach { case (CommonExpressionDef (child, id), index) =>
98
- if (id.canonicalized) {
99
- throw SparkException .internalError(
100
- " Cannot rewrite canonicalized Common expression definitions" )
101
- }
102
-
103
- if (CollapseProject .isCheap(child)) {
104
- refToExpr(id) = child
105
- } else {
106
- val childProjectionIndex = inputPlans.indexWhere(
107
- c => child.references.subsetOf(c.outputSet)
108
- )
109
- if (childProjectionIndex == - 1 ) {
110
- // When we cannot rewrite the common expressions, force to inline them so that the
111
- // query can still run. This can happen if the join condition contains `With` and
112
- // the common expression references columns from both join sides.
113
- // TODO: things can go wrong if the common expression is nondeterministic. We
114
- // don't fix it for now to match the old buggy behavior when certain
115
- // `RuntimeReplaceable` did not use the `With` expression.
116
- // TODO: we should calculate the ref count and also inline the common expression
117
- // if it's ref count is 1.
118
- refToExpr(id) = child
119
- } else {
120
- val aliasName = if (SQLConf .get.getConf(SQLConf .USE_COMMON_EXPR_ID_FOR_ALIAS )) {
121
- s " _common_expr_ ${id.id}"
122
- } else {
123
- s " _common_expr_ $index"
124
- }
125
- val alias = Alias (child, aliasName)()
126
- val fakeProj = Project (Seq (alias), inputPlans(childProjectionIndex))
127
- if (PlanHelper .specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
128
- // We have to inline the common expression if it cannot be put in a Project.
129
- refToExpr(id) = child
130
- } else {
131
- childProjections(childProjectionIndex) += alias
132
- refToExpr(id) = alias.toAttribute
133
- }
134
- }
135
- }
136
- }
137
-
138
- // Expression definitions apply to the entire With child contains internally nested with,
139
- // so need to generate the project first, then generate child's project
140
- for (i <- inputPlans.indices) {
141
- val projectList = childProjections(i)
142
- if (projectList.nonEmpty) {
143
- inputPlans(i) = Project (inputPlans(i).output ++ projectList, inputPlans(i))
144
- }
145
- }
102
+ val refToExpr = genProject(defs, inputPlans)
146
103
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
147
- // Internally nested with may also contain outer expression references
148
- for (i <- inputPlans.indices) {
149
- val projectList = childProjections(i)
150
- if (projectList.nonEmpty) {
151
- inputPlans(i) = inputPlans(i).mapExpressions(replaceRef(_, refToExpr))
152
- }
153
- }
154
- replaceRef(child, refToExpr)
104
+ replaceRef(child, inputPlans, refToExpr)
155
105
case c : ConditionalExpression =>
156
106
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
157
107
rewriteWithExprAndInputPlans(_, inputPlans))
@@ -171,9 +121,74 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
171
121
}
172
122
}
173
123
124
+ private def genProject (
125
+ defs : Seq [Expression ],
126
+ inputPlans : Array [LogicalPlan ]): mutable.Map [CommonExpressionId , Expression ] = {
127
+ val refToExpr = mutable.HashMap .empty[CommonExpressionId , Expression ]
128
+ val childProjections = Array .fill(inputPlans.length)(mutable.ArrayBuffer .empty[Alias ])
129
+ defs.zipWithIndex.foreach { case (CommonExpressionDef (child, id), index) =>
130
+ if (id.canonicalized) {
131
+ throw SparkException .internalError(
132
+ " Cannot rewrite canonicalized Common expression definitions" )
133
+ }
134
+
135
+ if (CollapseProject .isCheap(child)) {
136
+ refToExpr(id) = child
137
+ } else {
138
+ val childProjectionIndex = inputPlans.indexWhere(
139
+ c => child.references.subsetOf(c.outputSet)
140
+ )
141
+ if (childProjectionIndex == - 1 ) {
142
+ // When we cannot rewrite the common expressions, force to inline them so that the
143
+ // query can still run. This can happen if the join condition contains `With` and
144
+ // the common expression references columns from both join sides.
145
+ // TODO: things can go wrong if the common expression is nondeterministic. We
146
+ // don't fix it for now to match the old buggy behavior when certain
147
+ // `RuntimeReplaceable` did not use the `With` expression.
148
+ // TODO: we should calculate the ref count and also inline the common expression
149
+ // if it's ref count is 1.
150
+ refToExpr(id) = child
151
+ } else {
152
+ val aliasName = if (SQLConf .get.getConf(SQLConf .USE_COMMON_EXPR_ID_FOR_ALIAS )) {
153
+ s " _common_expr_ ${id.id}"
154
+ } else {
155
+ s " _common_expr_ $index"
156
+ }
157
+ val alias = Alias (child, aliasName)()
158
+ val fakeProj = Project (Seq (alias), inputPlans(childProjectionIndex))
159
+ if (PlanHelper .specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
160
+ // We have to inline the common expression if it cannot be put in a Project.
161
+ refToExpr(id) = child
162
+ } else {
163
+ childProjections(childProjectionIndex) += alias
164
+ refToExpr(id) = alias.toAttribute
165
+ }
166
+ }
167
+ }
168
+ }
169
+ for (i <- inputPlans.indices) {
170
+ val projectList = childProjections(i)
171
+ if (projectList.nonEmpty) {
172
+ inputPlans(i) = Project (inputPlans(i).output ++ projectList, inputPlans(i))
173
+ }
174
+ }
175
+ refToExpr
176
+ }
177
+
174
178
private def replaceRef (
175
179
expr : Expression ,
176
- refToExpr : mutable.Map [CommonExpressionId , Expression ]): Expression =
180
+ inputPlans : Array [LogicalPlan ],
181
+ refToExpr : mutable.Map [CommonExpressionId , Expression ]): Expression = {
182
+ // Internally nested `With` may also contain outer expression references
183
+ for (i <- inputPlans.indices) {
184
+ inputPlans(i) = inputPlans(i).mapExpressions(replaceRef(_, refToExpr))
185
+ }
186
+ replaceRef(expr, refToExpr)
187
+ }
188
+
189
+ private def replaceRef (
190
+ expr : Expression ,
191
+ refToExpr : mutable.Map [CommonExpressionId , Expression ]): Expression = {
177
192
expr.transformWithPruning(_.containsPattern(COMMON_EXPR_REF )) {
178
193
case ref : CommonExpressionRef if refToExpr.contains(ref.id) =>
179
194
if (ref.id.canonicalized) {
@@ -183,4 +198,5 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
183
198
refToExpr(ref.id)
184
199
}
185
200
}
201
+ }
186
202
}
0 commit comments