16
16
*/
17
17
package org .apache .gluten .extension
18
18
19
- import org .apache .gluten .config .GlutenConfig
20
19
import org .apache .gluten .execution .{FilterExecTransformer , ProjectExecTransformer }
21
- import org .apache .gluten .expression .ExpressionMappings
20
+ import org .apache .gluten .expression .{ ExpressionMappings , UDFMappings }
22
21
23
22
import org .apache .spark .sql .SparkSession
24
- import org .apache .spark .sql .catalyst .expressions .{ Alias , And , Expression , GetStructField , Literal , NamedExpression , ScalaUDF }
23
+ import org .apache .spark .sql .catalyst .expressions ._
25
24
import org .apache .spark .sql .catalyst .rules .Rule
26
25
import org .apache .spark .sql .execution .SparkPlan
27
26
import org .apache .spark .sql .types .{DataType , DataTypes }
28
27
29
28
case class CollapseNestedExpressions (spark : SparkSession ) extends Rule [SparkPlan ] {
30
29
31
30
/**
 * Entry point of the rule: rewrites `plan` only when it contains at least one
 * expression pattern eligible for collapsing; otherwise returns it untouched.
 */
override def apply(plan: SparkPlan): SparkPlan =
  if (canBeOptimized(plan)) visitPlan(plan) else plan
39
37
38
/**
 * Returns true when `plan` — or any plan in its subtree — contains an expression
 * that [[canBeOptimized(expr:*]] considers collapsible.
 *
 * For projects we inspect the project list before descending; for filters the
 * condition before the child. The original `var res` / `if (!res)` chains are
 * replaced by short-circuiting `||`, which evaluates the same operands in the
 * same order and stops at the first `true` — behavior is unchanged, the code is
 * just expression-oriented.
 */
private def canBeOptimized(plan: SparkPlan): Boolean = plan match {
  case p: ProjectExecTransformer =>
    p.projectList.exists(c => canBeOptimized(c)) ||
      p.children.exists(c => canBeOptimized(c))
  case f: FilterExecTransformer =>
    canBeOptimized(f.condition) || canBeOptimized(f.child)
  case _ =>
    plan.children.exists(c => canBeOptimized(c))
}
53
+
54
/**
 * Returns true when `expr` (after stripping a top-level [[Alias]]) maps to a
 * function name registered in `UDFMappings.collapsedFunctionsMap`, or when any
 * of its children does. Leaf expressions without a mapped name are never
 * collapsible.
 *
 * The mutable `var exprCall` of the original is replaced by a `val` bound via a
 * match — same target expression, no reassignment.
 */
private def canBeOptimized(expr: Expression): Boolean = {
  // Aliases are transparent: the aliased child decides collapsibility.
  val target = expr match {
    case a: Alias => a.child
    case other => other
  }
  getExpressionName(target) match {
    case Some(functionName) =>
      // Only functions explicitly registered for collapsing qualify.
      UDFMappings.collapsedFunctionsMap.contains(functionName)
    case None =>
      target match {
        case _: LeafExpression => false
        case _ => target.children.exists(c => canBeOptimized(c))
      }
  }
}
71
+
72
/**
 * Looks up the registered substrait/native name for the expression kinds this
 * rule knows how to collapse (struct field access, AND/OR conjunctions, JSON
 * path extraction). Any other expression type yields `None`.
 */
private def getExpressionName(expr: Expression): Option[String] =
  expr match {
    case _: And => ExpressionMappings.expressionsMap.get(classOf[And])
    case _: Or => ExpressionMappings.expressionsMap.get(classOf[Or])
    case _: GetStructField => ExpressionMappings.expressionsMap.get(classOf[GetStructField])
    case _: GetJsonObject => ExpressionMappings.expressionsMap.get(classOf[GetJsonObject])
    case _ => None
  }
79
+
40
80
private def visitPlan (plan : SparkPlan ): SparkPlan = plan match {
41
81
case p : ProjectExecTransformer =>
42
82
var newProjectList = Seq .empty[NamedExpression ]
@@ -59,44 +99,87 @@ case class CollapseNestedExpressions(spark: SparkSession) extends Rule[SparkPlan
59
99
}
60
100
61
101
/**
 * Collapses a chain of nested same-kind collapsible expressions (e.g.
 * `GetStructField(GetStructField(x, i), j)` or nested And/Or/GetJsonObject)
 * into a single placeholder [[ScalaUDF]] whose `udfName` carries the mapped
 * function name and whose `children` are the flattened operands. When nothing
 * collapses, children are recursively optimized and the expression is rebuilt.
 */
private def optimize(expr: Expression): Expression = {
  // Fallback result when no collapse is performed.
  var resultExpr = expr
  // Mapped name of the outermost collapsible function (set again on f's root call).
  var name = getExpressionName(expr)
  // Flattened operand list, built by prepending (+:=) during the walk.
  var children = Seq.empty[Expression]
  // Number of same-kind nested nodes folded together; >1 means a real chain.
  var nestedFunctions = 0
  var dataType = null.asInstanceOf[DataType]

  // Recursive walker. `parent` is None only at the root; otherwise it is the
  // immediately enclosing collapsible node, used to decide whether `e`
  // continues a same-kind chain or starts a fresh, independently-optimized one.
  def f(e: Expression, parent: Option[Expression] = Option.empty[Expression]): Unit = {
    parent match {
      case None =>
        // Root call: the outermost node defines the collapsed UDF's name/type.
        name = getExpressionName(e)
        dataType = e.dataType
      case _ =>
    }
    e match {
      case g: GetStructField if canBeOptimized(g) =>
        parent match {
          case Some(_: GetStructField) | None =>
            // Record the field ordinal as a literal operand, then descend.
            children +:= Literal.apply(g.ordinal, DataTypes.IntegerType)
            f(g.child, parent = Option.apply(g))
            nestedFunctions += 1
          case _ =>
          // NOTE(review): this branch adds nothing to `children`, unlike the
          // And/Or/GetJsonObject defaults below which append `optimize(e)`.
          // Verify a collapsible GetStructField under a different-kind parent
          // (e.g. a boolean struct field inside And) is not silently dropped.
        }
      case a: And if canBeOptimized(a) =>
        parent match {
          case Some(_: And) | None =>
            // Same-kind chain: flatten both sides into the operand list.
            f(a.left, Option.apply(a))
            f(a.right, Option.apply(a))
            nestedFunctions += 1
          case _ =>
            // Different-kind parent: treat as an independent subtree.
            children +:= optimize(a)
        }
      case o: Or if canBeOptimized(o) =>
        parent match {
          case Some(_: Or) | None =>
            f(o.left, parent = Option.apply(o))
            f(o.right, parent = Option.apply(o))
            nestedFunctions += 1
          case _ =>
            children +:= optimize(o)
        }
      case g: GetJsonObject if canBeOptimized(g) =>
        parent match {
          case Some(_: GetJsonObject) | None =>
            // Only literal JSON paths can be folded into the collapsed call.
            g.path match {
              case l: Literal =>
                children +:= l
                f(g.json, parent = Option.apply(g))
                nestedFunctions += 1
              case _ =>
            }
          case _ =>
            val newG = optimize(g)
            children +:= newG
        }
      case _ =>
        if (parent.nonEmpty) {
          // Non-collapsible leaf of a chain: keep it as an (optimized) operand.
          children +:= optimize(e)
        } else {
          // Root is not collapsible at all: reset state and rebuild the node
          // with recursively optimized children.
          nestedFunctions = 0
          children = Seq.empty[Expression]
          val exprNewChildren = e.children.map(p => optimize(p))
          resultExpr = e.withNewChildren(exprNewChildren)
        }
    }
  }
  f(expr)
  // Emit the placeholder UDF when a real chain (>1 nested nodes) was folded,
  // or when a collapsed ScalaUDF already appears among the gathered operands.
  if ((nestedFunctions > 1 && name.isDefined) || scalaUDFExists(children)) {
    // `func` is a deliberate null placeholder: the native backend resolves the
    // implementation by `udfName`, the Scala function is never invoked.
    val func: Null = null
    ScalaUDF(func, dataType, children, udfName = name, nullable = expr.nullable)
  } else {
    resultExpr
  }
}
175
+
176
/**
 * Returns true when a [[ScalaUDF]] occurs anywhere in the subtrees of
 * `children` (including the top level).
 *
 * The original `var res` plus guarded `foreach` is exactly a short-circuiting
 * depth-first `exists`: it stops at the first ScalaUDF found, scanning the
 * sequence left to right and each element's subtree before the next element —
 * same traversal order, same result.
 */
private def scalaUDFExists(children: Seq[Expression]): Boolean =
  children.exists {
    case _: ScalaUDF => true
    case c => scalaUDFExists(c.children)
  }
102
185
}
0 commit comments