Skip to content

Commit 2c6d6f6

Browse files
BIGOKevinyhZou
BIGO
authored and committed
optimize nested function calls
1 parent 5170be1 commit 2c6d6f6

File tree

8 files changed

+188
-7
lines changed

8 files changed

+188
-7
lines changed

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala

+10
Original file line numberDiff line numberDiff line change
@@ -947,4 +947,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
947947
outputAttributes: Seq[Attribute],
948948
child: Seq[SparkPlan]): ColumnarRangeBaseExec =
949949
CHRangeExecTransformer(start, end, step, numSlices, numElements, outputAttributes, child)
950+
951+
  /**
   * Builds the CH-specific transformer that collapses nested calls of the same function
   * (currently and/or chains) into one multi-child call.
   *
   * @param substraitExprName substrait name of the function being collapsed
   * @param children already-converted transformers for `original`'s children
   * @param original the original Catalyst expression
   */
  override def genCollapseNestedExpressionsTransformer(
      substraitExprName: String,
      children: Seq[ExpressionTransformer],
      original: Expression): ExpressionTransformer =
    CHCollapseNestedExpressionsTransformer(substraitExprName, children, original)
956+
957+
override def expressionCollapseSupported(exprName: String): Boolean =
958+
GlutenConfig.get.getSupportedCollapsedExpressions.split(",").exists(c => exprName.equals(c))
959+
950960
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.gluten.expression
18+
19+
import org.apache.gluten.config.GlutenConfig
20+
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
21+
22+
import org.apache.spark.internal.Logging
23+
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.types.DataType
25+
26+
import java.util
27+
28+
/**
29+
* Collapse nested expressions for optimization, to reduce expression calls. Now support `and`,
30+
* `or`. e.g. select ... and(and(a=1, b=2), c=3) => select ... and(a=1, b=2, c=3).
31+
*/
32+
/**
 * Collapse nested expressions for optimization, to reduce expression calls. Now support `and`,
 * `or`. e.g. select ... and(and(a=1, b=2), c=3) => select ... and(a=1, b=2, c=3).
 */
case class CHCollapseNestedExpressionsTransformer(
    substraitExprName: String,
    children: Seq[ExpressionTransformer],
    original: Expression)
  extends ExpressionTransformer
  with Logging {

  /**
   * Builds the substrait node for `original`. When the tree is a collapsible nested and/or
   * chain, all leaves are flattened into a single multi-child scalar function; otherwise
   * falls back to the default transformation.
   */
  override def doTransform(args: Object): ExpressionNode = {
    if (canBeOptimized(original)) {
      val functionMap = args.asInstanceOf[util.HashMap[String, java.lang.Long]]
      val newExprNode = doTransform(original, children, functionMap)
      logDebug("The new expression node: " + newExprNode.toProtobuf)
      newExprNode
    } else {
      super.doTransform(args)
    }
  }

  /** Substrait name registered for `And`/`Or`; `None` for any other expression. */
  def getExpressionName(expr: Expression): Option[String] = expr match {
    case _: And => ExpressionMappings.expressionsMap.get(classOf[And])
    case _: Or => ExpressionMappings.expressionsMap.get(classOf[Or])
    case _ => Option.empty[String]
  }

  /**
   * An expression can be collapsed when it (or, for non-leaf nodes, any descendant) is an
   * `And`/`Or` whose substrait name is listed in the collapse-supported config.
   */
  private def canBeOptimized(expr: Expression): Boolean = {
    // Aliases are transparent for this purpose: inspect the aliased child instead.
    val exprCall = expr match {
      case a: Alias => a.child
      case other => other
    }
    getExpressionName(exprCall) match {
      case None =>
        exprCall match {
          case _: LeafExpression => false
          case _ => exprCall.children.exists(canBeOptimized)
        }
      case Some(f) =>
        GlutenConfig.get.getSupportedCollapsedExpressions.split(",").contains(f)
    }
  }

  /** Builds the collapsed scalar function node from the flattened child nodes/types. */
  private def doTransform0(
      expr: Expression,
      dataType: DataType,
      childNodes: Seq[ExpressionNode],
      childTypes: Seq[DataType],
      functionMap: util.Map[String, java.lang.Long]): ExpressionNode = {
    val funcName: String = ConverterUtils.makeFuncName(substraitExprName, childTypes)
    val functionId = ExpressionBuilder.newScalarFunction(functionMap, funcName)
    val childNodeList = new util.ArrayList[ExpressionNode]()
    childNodes.foreach(c => childNodeList.add(c))
    val typeNode = ConverterUtils.getTypeNode(dataType, expr.nullable)
    ExpressionBuilder.makeScalarFunction(functionId, childNodeList, typeNode)
  }

  /**
   * Walks the nested and/or tree in parallel with its transformers, collecting every
   * non-collapsible leaf as a direct child of the single collapsed function call.
   */
  private def doTransform(
      expr: Expression,
      transformers: Seq[ExpressionTransformer],
      functionMap: util.Map[String, java.lang.Long]): ExpressionNode = {

    // Renamed from `children` to avoid shadowing the case-class field of the same name.
    var childNodes = Seq.empty[ExpressionNode]
    var childTypes = Seq.empty[DataType]

    def f(
        e: Expression,
        ts: ExpressionTransformer = null,
        parent: Option[Expression] = Option.empty): Unit = {
      parent match {
        case None =>
          e match {
            case a: And if canBeOptimized(a) =>
              f(a.left, transformers.head, Option.apply(a))
              f(a.right, transformers(1), Option.apply(a))
            case o: Or if canBeOptimized(o) =>
              f(o.left, transformers.head, Option.apply(o))
              f(o.right, transformers(1), Option.apply(o))
            case _ =>
          }
        case Some(_: And) =>
          e match {
            case a: And if canBeOptimized(a) =>
              val childTransformers = ts.children
              f(a.left, childTransformers.head, Option.apply(a))
              f(a.right, childTransformers(1), Option.apply(a))
            case _ =>
              // BUGFIX: append (`:+=`) instead of prepend (`+:=`), so the collapsed
              // children keep the original left-to-right order documented above:
              // and(and(a, b), c) => and(a, b, c), not and(c, b, a).
              childNodes :+= ts.doTransform(functionMap)
              childTypes :+= e.dataType
          }
        case Some(_: Or) =>
          e match {
            case o: Or if canBeOptimized(o) =>
              val childTransformers = ts.children
              f(o.left, childTransformers.head, Option.apply(o))
              f(o.right, childTransformers(1), Option.apply(o))
            case _ =>
              // Same append-order fix as the And branch.
              childNodes :+= ts.doTransform(functionMap)
              childTypes :+= e.dataType
          }
        case _ =>
      }
    }
    f(expr)
    // The root always takes the `parent == None` branch, which used the root's own type,
    // so the collapsed function's result type is simply `expr.dataType` (no null var needed).
    doTransform0(expr, expr.dataType, childNodes, childTypes, functionMap)
  }
}

backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala

+15
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,21 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
379379
}
380380
}
381381

382+
test("GLUTEN-8557: Optimize nested and/or") {
383+
384+
runQueryAndCompare(
385+
"SELECT count(1) from json_test where int_field1 = 5 and double_field1 > 1.0" +
386+
" and string_field1 is not null") { _ => }
387+
388+
runQueryAndCompare(
389+
"SELECT count(1) from json_test where int_field1 = 5 or double_field1 > 1.0" +
390+
" or string_field1 is not null") { _ => }
391+
392+
runQueryAndCompare(
393+
"SELECT count(1) from json_test where int_field1 = 5 and double_field1 > 1.0" +
394+
" or double_field1 < 100 or string_field1 is not null") { _ => }
395+
}
396+
382397
test("Test covar_samp") {
383398
runQueryAndCompare("SELECT covar_samp(double_field1, int_field1) from json_test") { _ => }
384399
}

cpp-ch/local-engine/Functions/SparkFunctionTupleElement.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,12 @@ class SparkFunctionTupleElement : public IFunction
6868
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
6969
{
7070
const size_t number_of_arguments = arguments.size();
71-
7271
if (number_of_arguments < 2 || number_of_arguments > 3)
7372
throw Exception(
7473
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
7574
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
7675
getName(),
7776
number_of_arguments);
78-
7977
std::vector<bool> arrays_is_nullable;
8078
DataTypePtr input_type = arguments[0].type;
8179
while (const DataTypeArray * array = checkAndGetDataType<DataTypeArray>(removeNullable(input_type).get()))
@@ -108,9 +106,6 @@ class SparkFunctionTupleElement : public IFunction
108106
if (*it)
109107
return_type = makeNullable(return_type);
110108
}
111-
112-
// std::cout << "return_type:" << return_type->getName() << std::endl;
113-
114109
return return_type;
115110
}
116111
else
@@ -163,7 +158,6 @@ class SparkFunctionTupleElement : public IFunction
163158
return arguments[2].column;
164159

165160
ColumnPtr res = input_col_as_tuple->getColumns()[index.value()];
166-
167161
/// Wrap into Nullable if needed
168162
if (input_col_as_nullable_tuple)
169163
{

cpp-ch/local-engine/Parser/FunctionParser.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ std::pair<DataTypePtr, Field> FunctionParser::parseLiteral(const substrait::Expr
9898
/// Thin delegation: argument parsing is implemented by ExpressionParser.
/// NOTE(review): the delegate takes (actions_dag, substrait_func) — parameter order is
/// intentionally swapped relative to this wrapper's signature.
ActionsDAG::NodeRawConstPtrs
FunctionParser::parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const
{
    return expression_parser->parseFunctionArguments(actions_dag, substrait_func);
}
104103

gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala

+7
Original file line numberDiff line numberDiff line change
@@ -711,4 +711,11 @@ trait SparkPlanExecApi {
711711
numElements: BigInt,
712712
outputAttributes: Seq[Attribute],
713713
child: Seq[SparkPlan]): ColumnarRangeBaseExec
714+
715+
  /**
   * Builds a backend-specific transformer that collapses nested calls of the same function
   * (e.g. and/or chains) into a single multi-child call.
   *
   * @param substraitExprName substrait name of the function being collapsed
   * @param children already-converted transformers for `original`'s children
   * @param original the original Catalyst expression
   */
  def genCollapseNestedExpressionsTransformer(
      substraitExprName: String,
      children: Seq[ExpressionTransformer],
      original: Expression): ExpressionTransformer
719+
720+
  /** Whether nested calls of `exprName` may be collapsed; backends opt in by overriding (default: false). */
  def expressionCollapseSupported(exprName: String): Boolean = false
714721
}

gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala

+8
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
742742
substraitExprName,
743743
expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
744744
j)
745+
case expr
746+
if BackendsApiManager.getSparkPlanExecApiInstance.expressionCollapseSupported(
747+
ExpressionMappings.expressionsMap.getOrElse(expr.getClass, "")) =>
748+
BackendsApiManager.getSparkPlanExecApiInstance.genCollapseNestedExpressionsTransformer(
749+
substraitExprName,
750+
expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
751+
expr
752+
)
745753
case expr =>
746754
GenericExpressionTransformer(
747755
substraitExprName,

shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala

+9
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
123123
def scanFileSchemeValidationEnabled: Boolean =
124124
getConf(VELOX_SCAN_FILE_SCHEME_VALIDATION_ENABLED)
125125

126+
  // Comma-separated function names (default "and,or") whose nested calls may be collapsed
  // into one call; see GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS.
  def getSupportedCollapsedExpressions: String = getConf(GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS)
127+
126128
// Whether to use GlutenShuffleManager (experimental).
127129
def isUseGlutenShuffleManager: Boolean =
128130
conf
@@ -689,6 +691,13 @@ object GlutenConfig {
689691
.stringConf
690692
.createWithDefault("")
691693

694+
val GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS =
695+
buildConf("spark.gluten.sql.supported.collapseNestedFunctions")
696+
.internal()
697+
.doc("Collapse nested functions as one for optimization.")
698+
.stringConf
699+
.createWithDefault("and,or");
700+
692701
val GLUTEN_SOFT_AFFINITY_ENABLED =
693702
buildConf("spark.gluten.soft-affinity.enabled")
694703
.doc("Whether to enable Soft Affinity scheduling.")

0 commit comments

Comments
 (0)