Skip to content

Commit 44b21c7

Browse files
BIGO authored and BIGO committed
support collapsed functions
1 parent 1184913 commit 44b21c7

File tree

4 files changed: +130 −46 lines changed

backends-clickhouse/src/main/scala/org/apache/gluten/extension/CollapseNestedExpressions.scala

+118-35
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,67 @@
1616
*/
1717
package org.apache.gluten.extension
1818

19-
import org.apache.gluten.config.GlutenConfig
2019
import org.apache.gluten.execution.{FilterExecTransformer, ProjectExecTransformer}
21-
import org.apache.gluten.expression.ExpressionMappings
20+
import org.apache.gluten.expression.{ExpressionMappings, UDFMappings}
2221

2322
import org.apache.spark.sql.SparkSession
24-
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Expression, GetStructField, Literal, NamedExpression, ScalaUDF}
23+
import org.apache.spark.sql.catalyst.expressions._
2524
import org.apache.spark.sql.catalyst.rules.Rule
2625
import org.apache.spark.sql.execution.SparkPlan
2726
import org.apache.spark.sql.types.{DataType, DataTypes}
2827

2928
case class CollapseNestedExpressions(spark: SparkSession) extends Rule[SparkPlan] {
3029

3130
override def apply(plan: SparkPlan): SparkPlan = {
32-
if (GlutenConfig.get.enableCollapseNestedFunctions) {
33-
val p = visitPlan(plan)
34-
p
31+
if (canBeOptimized(plan)) {
32+
visitPlan(plan)
3533
} else {
3634
plan
3735
}
3836
}
3937

38+
private def canBeOptimized(plan: SparkPlan): Boolean = plan match {
39+
case p: ProjectExecTransformer =>
40+
var res = p.projectList.exists(c => canBeOptimized(c))
41+
if (!res) {
42+
res = p.children.exists(c => canBeOptimized(c))
43+
}
44+
res
45+
case f: FilterExecTransformer =>
46+
var res = canBeOptimized(f.condition)
47+
if (!res) {
48+
res = canBeOptimized(f.child)
49+
}
50+
res
51+
case _ => plan.children.exists(c => canBeOptimized(c))
52+
}
53+
54+
private def canBeOptimized(expr: Expression): Boolean = {
55+
var exprCall = expr
56+
expr match {
57+
case a: Alias => exprCall = a.child
58+
case _ =>
59+
}
60+
val functionName = getExpressionName(exprCall)
61+
functionName match {
62+
case None =>
63+
exprCall match {
64+
case _: LeafExpression => false
65+
case _ => exprCall.children.exists(c => canBeOptimized(c))
66+
}
67+
case Some(f) =>
68+
UDFMappings.collapsedFunctionsMap.contains(f)
69+
}
70+
}
71+
72+
private def getExpressionName(expr: Expression): Option[String] = expr match {
73+
case _: GetStructField => ExpressionMappings.expressionsMap.get(classOf[GetStructField])
74+
case _: And => ExpressionMappings.expressionsMap.get(classOf[And])
75+
case _: Or => ExpressionMappings.expressionsMap.get(classOf[Or])
76+
case _: GetJsonObject => ExpressionMappings.expressionsMap.get(classOf[GetJsonObject])
77+
case _ => Option.empty[String]
78+
}
79+
4080
private def visitPlan(plan: SparkPlan): SparkPlan = plan match {
4181
case p: ProjectExecTransformer =>
4282
var newProjectList = Seq.empty[NamedExpression]
@@ -59,44 +99,87 @@ case class CollapseNestedExpressions(spark: SparkSession) extends Rule[SparkPlan
5999
}
60100

61101
private def optimize(expr: Expression): Expression = {
62-
var resultExpr = null.asInstanceOf[Expression]
63-
var name = Option.empty[String]
102+
var resultExpr = expr
103+
var name = getExpressionName(expr)
64104
var children = Seq.empty[Expression]
65105
var nestedFunctions = 0
66106
var dataType = null.asInstanceOf[DataType]
67-
def f(e: Expression, nested: Boolean = false): Unit = e match {
68-
case g: GetStructField =>
69-
if (!nested) {
70-
name = ExpressionMappings.expressionsMap.get(classOf[GetStructField])
71-
dataType = g.dataType
72-
}
73-
children +:= Literal.apply(g.ordinal, DataTypes.IntegerType)
74-
f(g.child, nested = true)
75-
nestedFunctions += 1
76-
case a: And =>
77-
if (!nested) {
78-
name = ExpressionMappings.expressionsMap.get(classOf[And])
79-
dataType = a.dataType
80-
}
81-
f(a.left, nested = true)
82-
f(a.right, nested = true)
83-
nestedFunctions += 1
84-
case _ =>
85-
if (nested) {
86-
children +:= e
87-
} else {
88-
nestedFunctions = 0
89-
children = Seq.empty[Expression]
90-
val exprNewChildren = e.children.map(p => optimize(p))
91-
resultExpr = e.withNewChildren(exprNewChildren)
92-
}
107+
108+
def f(e: Expression, parent: Option[Expression] = Option.empty[Expression]): Unit = {
109+
parent match {
110+
case None =>
111+
name = getExpressionName(e)
112+
dataType = e.dataType
113+
case _ =>
114+
}
115+
e match {
116+
case g: GetStructField if canBeOptimized(g) =>
117+
parent match {
118+
case Some(_: GetStructField) | None =>
119+
children +:= Literal.apply(g.ordinal, DataTypes.IntegerType)
120+
f(g.child, parent = Option.apply(g))
121+
nestedFunctions += 1
122+
case _ =>
123+
}
124+
case a: And if canBeOptimized(a) =>
125+
parent match {
126+
case Some(_: And) | None =>
127+
f(a.left, Option.apply(a))
128+
f(a.right, Option.apply(a))
129+
nestedFunctions += 1
130+
case _ =>
131+
children +:= optimize(a)
132+
}
133+
case o: Or if canBeOptimized(o) =>
134+
parent match {
135+
case Some(_: Or) | None =>
136+
f(o.left, parent = Option.apply(o))
137+
f(o.right, parent = Option.apply(o))
138+
nestedFunctions += 1
139+
case _ =>
140+
children +:= optimize(o)
141+
}
142+
case g: GetJsonObject if canBeOptimized(g) =>
143+
parent match {
144+
case Some(_: GetJsonObject) | None =>
145+
g.path match {
146+
case l: Literal =>
147+
children +:= l
148+
f(g.json, parent = Option.apply(g))
149+
nestedFunctions += 1
150+
case _ =>
151+
}
152+
case _ =>
153+
val newG = optimize(g)
154+
children +:= newG
155+
}
156+
case _ =>
157+
if (parent.nonEmpty) {
158+
children +:= optimize(e)
159+
} else {
160+
nestedFunctions = 0
161+
children = Seq.empty[Expression]
162+
val exprNewChildren = e.children.map(p => optimize(p))
163+
resultExpr = e.withNewChildren(exprNewChildren)
164+
}
165+
}
93166
}
94167
f(expr)
95-
if (nestedFunctions > 1 && name.isDefined && dataType != null) {
168+
if ((nestedFunctions > 1 && name.isDefined) || scalaUDFExists(children)) {
96169
val func: Null = null
97170
ScalaUDF(func, dataType, children, udfName = name, nullable = expr.nullable)
98171
} else {
99172
resultExpr
100173
}
101174
}
175+
176+
private def scalaUDFExists(children: Seq[Expression]): Boolean = {
177+
var res = false
178+
children.foreach {
179+
case _: ScalaUDF if !res => res = true
180+
case c if !res => res = scalaUDFExists(c.children)
181+
case _ =>
182+
}
183+
res
184+
}
102185
}

gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ object ExpressionConverter extends SQLConfHelper with Logging {
9797
if (udf.udfName.isEmpty) {
9898
throw new GlutenNotSupportException("UDF name is not found!")
9999
}
100-
val substraitExprName = UDFMappings.scalaUDFMap.get(udf.udfName.get)
100+
var substraitExprName = UDFMappings.scalaUDFMap.get(udf.udfName.get)
101+
if (substraitExprName.isEmpty) {
102+
substraitExprName = UDFMappings.collapsedFunctionsMap.get(udf.udfName.get)
103+
}
101104
substraitExprName match {
102105
case Some(name) =>
103106
GenericExpressionTransformer(

gluten-substrait/src/main/scala/org/apache/gluten/expression/UDFMappings.scala

+7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ object UDFMappings extends Logging {
3131
val hiveUDFMap: Map[String, String] = Map()
3232
val pythonUDFMap: Map[String, String] = Map()
3333
val scalaUDFMap: Map[String, String] = Map()
34+
val collapsedFunctionsMap: Map[String, String] = Map()
3435

3536
private def appendKVToMap(key: String, value: String, res: Map[String, String]): Unit = {
3637
if (key.isEmpty || value.isEmpty()) {
@@ -75,5 +76,11 @@ object UDFMappings extends Logging {
7576
parseStringToMap(strScalaUDFs, scalaUDFMap)
7677
logDebug(s"loaded scala udf mappings:${scalaUDFMap.toString}")
7778
}
79+
80+
val strCollapsedFunctions = conf.get(GlutenConfig.GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS, "")
81+
if (!StringUtils.isBlank(strCollapsedFunctions)) {
82+
parseStringToMap(strCollapsedFunctions, collapsedFunctionsMap)
83+
logDebug(s"loaded collapsed function mappings: ${collapsedFunctionsMap.toString}")
84+
}
7885
}
7986
}

shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala

+1-10
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ class GlutenConfig(conf: SQLConf) extends Logging {
113113
def enableCollapseNestedGetJsonObject: Boolean =
114114
getConf(ENABLE_COLLAPSE_GET_JSON_OBJECT)
115115

116-
def enableCollapseNestedFunctions: Boolean =
117-
getConf(ENABLE_COLLAPSE_NESTED_FUNCTIONS)
118-
119116
def enableCHRewriteDateConversion: Boolean =
120117
getConf(ENABLE_CH_REWRITE_DATE_CONVERSION)
121118

@@ -667,6 +664,7 @@ object GlutenConfig {
667664
val GLUTEN_SUPPORTED_HIVE_UDFS = "spark.gluten.supported.hive.udfs"
668665
val GLUTEN_SUPPORTED_PYTHON_UDFS = "spark.gluten.supported.python.udfs"
669666
val GLUTEN_SUPPORTED_SCALA_UDFS = "spark.gluten.supported.scala.udfs"
667+
val GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS = "spark.gluten.supported.collapsed.functions"
670668

671669
// FIXME: This only works with CH backend.
672670
val GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF =
@@ -1947,13 +1945,6 @@ object GlutenConfig {
19471945
.booleanConf
19481946
.createWithDefault(false)
19491947

1950-
val ENABLE_COLLAPSE_NESTED_FUNCTIONS =
1951-
buildConf("spark.gluten.sql.nestedFunctionsCollapsed.enabled")
1952-
.internal()
1953-
.doc("Collapse nested functions as one for optimization.")
1954-
.booleanConf
1955-
.createWithDefault(true)
1956-
19571948
val ENABLE_CH_REWRITE_DATE_CONVERSION =
19581949
buildConf("spark.gluten.sql.columnar.backend.ch.rewrite.dateConversion")
19591950
.internal()

0 commit comments

Comments (0)