Skip to content

Commit 2cc64f2

Browse files
BIGOKevinyhZou
BIGO
authored andcommitted
optimize nested function calls
1 parent f898bc2 commit 2cc64f2

File tree

13 files changed

+509
-122
lines changed

13 files changed

+509
-122
lines changed

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala

+1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ object CHRuleApi {
119119
injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, CHBatch))
120120
injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))
121121
injector.injectPostTransform(c => AddPreProjectionForHashJoin.apply(c.session))
122+
injector.injectPostTransform(c => CollapseNestedExpressions.apply(c.session))
122123

123124
// Gluten columnar: Fallback policies.
124125
injector.injectFallbackPolicy(c => p => ExpandFallbackPolicy(c.caller.isAqe(), p))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.gluten.extension
18+
19+
import org.apache.gluten.execution.{FilterExecTransformer, ProjectExecTransformer}
20+
import org.apache.gluten.expression.{ExpressionMappings, UDFMappings}
21+
22+
import org.apache.spark.sql.SparkSession
23+
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
import org.apache.spark.sql.execution.SparkPlan
26+
import org.apache.spark.sql.types.{DataType, DataTypes}
27+
28+
case class CollapseNestedExpressions(spark: SparkSession) extends Rule[SparkPlan] {
29+
30+
override def apply(plan: SparkPlan): SparkPlan = {
31+
if (canBeOptimized(plan)) {
32+
visitPlan(plan)
33+
} else {
34+
plan
35+
}
36+
}
37+
38+
private def canBeOptimized(plan: SparkPlan): Boolean = plan match {
39+
case p: ProjectExecTransformer =>
40+
var res = p.projectList.exists(c => canBeOptimized(c))
41+
if (!res) {
42+
res = p.children.exists(c => canBeOptimized(c))
43+
}
44+
res
45+
case f: FilterExecTransformer =>
46+
var res = canBeOptimized(f.condition)
47+
if (!res) {
48+
res = canBeOptimized(f.child)
49+
}
50+
res
51+
case _ => plan.children.exists(c => canBeOptimized(c))
52+
}
53+
54+
private def canBeOptimized(expr: Expression): Boolean = {
55+
var exprCall = expr
56+
expr match {
57+
case a: Alias => exprCall = a.child
58+
case _ =>
59+
}
60+
val functionName = getExpressionName(exprCall)
61+
functionName match {
62+
case None =>
63+
exprCall match {
64+
case _: LeafExpression => false
65+
case _ => exprCall.children.exists(c => canBeOptimized(c))
66+
}
67+
case Some(f) =>
68+
UDFMappings.collapsedFunctionsMap.contains(f)
69+
}
70+
}
71+
72+
private def getExpressionName(expr: Expression): Option[String] = expr match {
73+
case _: GetStructField => ExpressionMappings.expressionsMap.get(classOf[GetStructField])
74+
case _: And => ExpressionMappings.expressionsMap.get(classOf[And])
75+
case _: Or => ExpressionMappings.expressionsMap.get(classOf[Or])
76+
case _: GetJsonObject => ExpressionMappings.expressionsMap.get(classOf[GetJsonObject])
77+
case _ => Option.empty[String]
78+
}
79+
80+
private def visitPlan(plan: SparkPlan): SparkPlan = plan match {
81+
case p: ProjectExecTransformer =>
82+
var newProjectList = Seq.empty[NamedExpression]
83+
p.projectList.foreach {
84+
case a: Alias =>
85+
val newAlias = Alias(optimize(a.child), a.name)(a.exprId)
86+
newProjectList :+= newAlias
87+
case p =>
88+
newProjectList :+= p
89+
}
90+
val newChild = visitPlan(p.child)
91+
ProjectExecTransformer(newProjectList, newChild)
92+
case f: FilterExecTransformer =>
93+
val newCondition = optimize(f.condition)
94+
val newChild = visitPlan(f.child)
95+
FilterExecTransformer(newCondition, newChild)
96+
case _ =>
97+
val newChildren = plan.children.map(p => visitPlan(p))
98+
plan.withNewChildren(newChildren)
99+
}
100+
101+
private def optimize(expr: Expression): Expression = {
102+
var resultExpr = expr
103+
var name = getExpressionName(expr)
104+
var children = Seq.empty[Expression]
105+
var nestedFunctions = 0
106+
var dataType = null.asInstanceOf[DataType]
107+
108+
def f(e: Expression, parent: Option[Expression] = Option.empty[Expression]): Unit = {
109+
parent match {
110+
case None =>
111+
name = getExpressionName(e)
112+
dataType = e.dataType
113+
case _ =>
114+
}
115+
e match {
116+
case g: GetStructField if canBeOptimized(g) =>
117+
parent match {
118+
case Some(_: GetStructField) | None =>
119+
children +:= Literal.apply(g.ordinal, DataTypes.IntegerType)
120+
f(g.child, parent = Option.apply(g))
121+
nestedFunctions += 1
122+
case _ =>
123+
}
124+
case a: And if canBeOptimized(a) =>
125+
parent match {
126+
case Some(_: And) | None =>
127+
f(a.left, Option.apply(a))
128+
f(a.right, Option.apply(a))
129+
nestedFunctions += 1
130+
case _ =>
131+
children +:= optimize(a)
132+
}
133+
case o: Or if canBeOptimized(o) =>
134+
parent match {
135+
case Some(_: Or) | None =>
136+
f(o.left, parent = Option.apply(o))
137+
f(o.right, parent = Option.apply(o))
138+
nestedFunctions += 1
139+
case _ =>
140+
children +:= optimize(o)
141+
}
142+
case g: GetJsonObject if canBeOptimized(g) =>
143+
parent match {
144+
case Some(_: GetJsonObject) | None =>
145+
g.path match {
146+
case l: Literal =>
147+
children +:= l
148+
f(g.json, parent = Option.apply(g))
149+
nestedFunctions += 1
150+
case _ =>
151+
}
152+
case _ =>
153+
val newG = optimize(g)
154+
children +:= newG
155+
}
156+
case _ =>
157+
if (parent.nonEmpty) {
158+
children +:= optimize(e)
159+
} else {
160+
nestedFunctions = 0
161+
children = Seq.empty[Expression]
162+
val exprNewChildren = e.children.map(p => optimize(p))
163+
resultExpr = e.withNewChildren(exprNewChildren)
164+
}
165+
}
166+
}
167+
f(expr)
168+
if ((nestedFunctions > 1 && name.isDefined) || scalaUDFExists(children)) {
169+
val func: Null = null
170+
ScalaUDF(func, dataType, children, udfName = name, nullable = expr.nullable)
171+
} else {
172+
resultExpr
173+
}
174+
}
175+
176+
private def scalaUDFExists(children: Seq[Expression]): Boolean = {
177+
var res = false
178+
children.foreach {
179+
case _: ScalaUDF if !res => res = true
180+
case c if !res => res = scalaUDFExists(c.children)
181+
case _ =>
182+
}
183+
res
184+
}
185+
}

backends-clickhouse/src/test/resources/text-data/empty_as_default/data.txt

Whitespace-only changes.

backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala

+52-13
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
271271
}
272272
}
273273

274-
test("GLUTEN-8304: Optimize nested get_json_object") {
274+
test("GLUTEN-8304: Optimize nested functions") {
275275
def checkExpression(expr: Expression, path: String): Boolean = {
276276
expr match {
277277
case g: GetJsonObject
@@ -298,55 +298,94 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
298298
plan.children.exists(c => checkPlan(c, path))
299299
}
300300
}
301-
def checkGetJsonObjectPath(df: DataFrame, path: String): Boolean = {
302-
checkPlan(df.queryExecution.analyzed, path)
301+
def checkGetJsonObjectPath(
302+
df: DataFrame,
303+
path: String,
304+
collapsedGetJsonObjectEnabled: Boolean): Boolean = {
305+
if (collapsedGetJsonObjectEnabled) {
306+
checkPlan(df.queryExecution.analyzed, path)
307+
} else {
308+
true
309+
}
303310
}
304-
withSQLConf(("spark.gluten.sql.collapseGetJsonObject.enabled", "true")) {
311+
312+
def runCheck(collapseGetJsonObjectEnabled: Boolean): Unit = {
305313
runQueryAndCompare(
306314
"select get_json_object(get_json_object(string_field1, '$.a'), '$.y') " +
307315
" from json_test where int_field1 = 6") {
308-
x => assert(checkGetJsonObjectPath(x, "$.a.y"))
316+
x => assert(checkGetJsonObjectPath(x, "$.a.y", collapseGetJsonObjectEnabled))
309317
}
310318
runQueryAndCompare(
311319
"select get_json_object(get_json_object(string_field1, '$[a]'), '$[y]') " +
312320
" from json_test where int_field1 = 6") {
313-
x => assert(checkGetJsonObjectPath(x, "$[a][y]"))
321+
x => assert(checkGetJsonObjectPath(x, "$[a][y]", collapseGetJsonObjectEnabled))
314322
}
315323
runQueryAndCompare(
316324
"select get_json_object(get_json_object(get_json_object(string_field1, " +
317325
"'$.a'), '$.y'), '$.z') from json_test where int_field1 = 6") {
318-
x => assert(checkGetJsonObjectPath(x, "$.a.y.z"))
326+
x => assert(checkGetJsonObjectPath(x, "$.a.y.z", collapseGetJsonObjectEnabled))
319327
}
320328
runQueryAndCompare(
321329
"select get_json_object(get_json_object(get_json_object(string_field1, '$.a')," +
322330
" string_field1), '$.z') from json_test where int_field1 = 6",
323331
noFallBack = false
324-
)(x => assert(checkGetJsonObjectPath(x, "$.a") && checkGetJsonObjectPath(x, "$.z")))
332+
)(
333+
x =>
334+
assert(
335+
checkGetJsonObjectPath(
336+
x,
337+
"$.a",
338+
collapseGetJsonObjectEnabled) && checkGetJsonObjectPath(
339+
x,
340+
"$.z",
341+
collapseGetJsonObjectEnabled)))
325342
runQueryAndCompare(
326343
"select get_json_object(get_json_object(get_json_object(string_field1, " +
327344
" string_field1), '$.a'), '$.z') from json_test where int_field1 = 6",
328345
noFallBack = false
329-
)(x => assert(checkGetJsonObjectPath(x, "$.a.z")))
346+
)(x => assert(checkGetJsonObjectPath(x, "$.a.z", collapseGetJsonObjectEnabled)))
330347
runQueryAndCompare(
331348
"select get_json_object(get_json_object(get_json_object(" +
332349
" substring(string_field1, 10), '$.a'), '$.z'), string_field1) " +
333350
" from json_test where int_field1 = 6",
334351
noFallBack = false
335-
)(x => assert(checkGetJsonObjectPath(x, "$.a.z")))
352+
)(x => assert(checkGetJsonObjectPath(x, "$.a.z", collapseGetJsonObjectEnabled)))
336353
runQueryAndCompare(
337354
"select get_json_object(get_json_object(string_field1, '$.a[0]'), '$.y') " +
338355
" from json_test where int_field1 = 7") {
339-
x => assert(checkGetJsonObjectPath(x, "$.a[0].y"))
356+
x => assert(checkGetJsonObjectPath(x, "$.a[0].y", collapseGetJsonObjectEnabled))
340357
}
341358
runQueryAndCompare(
342359
"select get_json_object(get_json_object(get_json_object(string_field1, " +
343360
" '$.a[1]'), '$.z[1]'), '$.n') from json_test where int_field1 = 7") {
344-
x => assert(checkGetJsonObjectPath(x, "$.a[1].z[1].n"))
361+
x => assert(checkGetJsonObjectPath(x, "$.a[1].z[1].n", collapseGetJsonObjectEnabled))
345362
}
346363
runQueryAndCompare(
347364
"select * from json_test where " +
348365
" get_json_object(get_json_object(get_json_object(string_field1, '$.a'), " +
349-
"'$.y'), '$.z') != null")(x => assert(checkGetJsonObjectPath(x, "$.a.y.z")))
366+
"'$.y'), '$.z') != null")(
367+
x => assert(checkGetJsonObjectPath(x, "$.a.y.z", collapseGetJsonObjectEnabled)))
368+
runQueryAndCompare(
369+
"select get_json_object(get_json_object(get_json_object(string_field1, " +
370+
" '$.a[1]'), '$.z[1]'), '$.n') from json_test where int_field1 = 7 or int_field1 = 5") {
371+
_ =>
372+
}
373+
374+
runQueryAndCompare(
375+
"select get_json_object(get_json_object(get_json_object(string_field1, " +
376+
" '$.a[1]'), '$.z[1]'), '$.n') from json_test where int_field1 > 3 and int_field1 != 5 " +
377+
" or int_field1 < 2") { _ => }
378+
}
379+
withSQLConf(
380+
("spark.gluten.sql.collapseGetJsonObject.enabled", "true"),
381+
("spark.gluten.sql.supported.collapseNestedFunctions", "")) {
382+
runCheck(true)
383+
}
384+
withSQLConf(
385+
(
386+
"spark.gluten.sql.supported.collapseNestedFunctions",
387+
"get_json_object,get_struct_field,and,or")) {
388+
runCheck(false)
350389
}
351390
}
352391

0 commit comments

Comments
 (0)