Commit 72857c0

KevinyhZou (BIGO) authored and committed
optimize nested function calls
1 parent ca2ab6a commit 72857c0

File tree

13 files changed: +334 -15 lines

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala

+1
@@ -122,6 +122,7 @@ object CHRuleApi {
     injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))
     injector.injectPostTransform(c => AddPreProjectionForHashJoin.apply(c.session))
     injector.injectPostTransform(c => ReplaceSubStringComparison.apply(c.session))
+    injector.injectPostTransform(c => CollapseNestedExpressions.apply(c.session))

     // Gluten columnar: Fallback policies.
     injector.injectFallbackPolicy(c => p => ExpandFallbackPolicy(c.caller.isAqe(), p))
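
Note: the new rule is appended after the existing ClickHouse post-transform rules shown here (e.g. ReplaceSubStringComparison), so it runs over the already-transformed columnar plan, just before fallback policies are applied.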

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala

+16 -1
@@ -581,7 +581,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
     List(
       Sig[CollectList](ExpressionNames.COLLECT_LIST),
       Sig[CollectSet](ExpressionNames.COLLECT_SET),
-      Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID)
+      Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID),
+      CHCollapsedExpression.signature
     ) ++
       ExpressionExtensionTrait.expressionExtensionTransformer.expressionSigList ++
       SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
@@ -947,4 +948,18 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       outputAttributes: Seq[Attribute],
       child: Seq[SparkPlan]): ColumnarRangeBaseExec =
     CHRangeExecTransformer(start, end, step, numSlices, numElements, outputAttributes, child)
+
+  override def expressionCollapseSupported(expr: Expression): Boolean = expr match {
+    case ce: CHCollapsedExpression => CHCollapsedExpression.supported(ce.name)
+    case _ => false
+  }
+
+  override def genCollapsedExpressionTransformer(
+      substraitName: String,
+      children: Seq[ExpressionTransformer],
+      expr: Expression): ExpressionTransformer = expr match {
+    case ce: CHCollapsedExpression =>
+      GenericExpressionTransformer(ce.name, children, ce)
+    case _ => super.genCollapsedExpressionTransformer(substraitName, children, expr)
+  }
 }
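
These two overrides are the ClickHouse backend's hook for the new rule: `expressionCollapseSupported` reports which collapsed operators the backend accepts, and `genCollapsedExpressionTransformer` turns a `CHCollapsedExpression` into a single `GenericExpressionTransformer` carrying all collapsed children, so one n-ary `and`/`or` scalar function call appears in the Substrait plan instead of a chain of binary calls.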

backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHFilterExecTransformer.scala

+12 -1
@@ -16,7 +16,9 @@
  */
 package org.apache.gluten.execution

-import org.apache.spark.sql.catalyst.expressions.{And, Expression}
+import org.apache.gluten.expression.CHCollapsedExpression
+
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExprId, IsNotNull}
 import org.apache.spark.sql.execution.SparkPlan

 case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
@@ -48,4 +50,13 @@ case class FilterExecTransformer(condition: Expression, child: SparkPlan)
   override protected def getRemainingCondition: Expression = condition
   override protected def withNewChildInternal(newChild: SparkPlan): FilterExecTransformer =
     copy(child = newChild)
+  override protected val notNullAttributes: Seq[ExprId] = condition match {
+    case s: CHCollapsedExpression =>
+      val (notNullPreds, _) = s.children.partition {
+        case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+        case _ => false
+      }
+      notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+    case _ => notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+  }
 }
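
Why this override matters: once the rule collapses an `And` chain into one `CHCollapsedExpression`, the base transformer can no longer see the individual `IsNotNull` conjuncts, so `FilterExecTransformer` re-derives the not-null attributes from the collapsed node's children. A minimal, self-contained sketch of that partition-and-collect step on a toy predicate type (the `Pred` ADT below is illustrative, not Catalyst):

    object NotNullSketch extends App {
      sealed trait Pred
      case class IsNotNull(attr: String) extends Pred
      case class Other(desc: String) extends Pred

      // Mirrors the partition in FilterExecTransformer.notNullAttributes:
      // split the collapsed children into IsNotNull guards and everything
      // else, then collect the attributes those guards protect.
      def notNullAttrs(children: Seq[Pred]): Seq[String] = {
        val (notNullPreds, _) = children.partition {
          case IsNotNull(_) => true
          case _            => false
        }
        notNullPreds.collect { case IsNotNull(a) => a }.distinct
      }

      assert(notNullAttrs(Seq(IsNotNull("a"), Other("a = 1"), IsNotNull("b"))) == Seq("a", "b"))
    }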
backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHCollapsedExpression.scala (new file)

+53
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.config.GlutenConfig
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+case class CHCollapsedExpression(
+    dataType: DataType,
+    children: Seq[Expression],
+    name: String,
+    nullable: Boolean = true,
+    original: Expression)
+  extends Expression {
+
+  override def toString: String = s"$name(${children.mkString(", ")})"
+
+  override def eval(input: InternalRow): Any = original.eval(input)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = null
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
+
+}
+
+object CHCollapsedExpression {
+
+  def signature: Sig = Sig[CHCollapsedExpression]("CHCollapsedExpression")
+
+  def supported(name: String): Boolean = {
+    GlutenConfig.get.getSupportedCollapsedExpressions.split(",").exists(p => p.equals(name))
+  }
+
+}
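
The `supported` check gates collapsing per operator through a comma-separated config value. A standalone sketch of that check, with the configured list passed in explicitly (in the patch it comes from `GlutenConfig.get.getSupportedCollapsedExpressions`; the default `"and,or"` below is an assumption for illustration only):

    object SupportedSketch extends App {
      // Mirror of CHCollapsedExpression.supported with the config inlined.
      def supported(name: String, configured: String = "and,or"): Boolean =
        configured.split(",").exists(p => p.equals(name))

      assert(supported("and"))
      assert(supported("or"))
      assert(!supported("concat")) // operators outside the list keep their binary form
    }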
backends-clickhouse/src/main/scala/org/apache/gluten/extension/CollapseNestedExpressions.scala (new file)

+164
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.execution.{FilterExecTransformer, ProjectExecTransformer}
+import org.apache.gluten.expression.{CHCollapsedExpression, ExpressionMappings}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Collapse nested expressions for optimization, to reduce expression calls. Now support `and`,
+ * `or`. e.g. select ... and(and(a=1, b=2), c=3) => select ... and(a=1, b=2, c=3).
+ */
+case class CollapseNestedExpressions(spark: SparkSession) extends Rule[SparkPlan] {
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (canBeOptimized(plan)) {
+      visitPlan(plan)
+    } else {
+      plan
+    }
+  }
+
+  private def canBeOptimized(plan: SparkPlan): Boolean = plan match {
+    case p: ProjectExecTransformer =>
+      var res = p.projectList.exists(c => c.isInstanceOf[And] || c.isInstanceOf[Or])
+      if (res) {
+        return false
+      }
+      res = p.projectList.exists(c => canBeOptimized(c))
+      if (!res) {
+        res = p.children.exists(c => canBeOptimized(c))
+      }
+      res
+    case f: FilterExecTransformer =>
+      var res = canBeOptimized(f.condition)
+      if (!res) {
+        res = canBeOptimized(f.child)
+      }
+      res
+    case _ => plan.children.exists(c => canBeOptimized(c))
+  }
+
+  private def canBeOptimized(expr: Expression): Boolean = {
+    var exprCall = expr
+    expr match {
+      case a: Alias => exprCall = a.child
+      case _ =>
+    }
+    val exprName = getExpressionName(exprCall)
+    exprName match {
+      case None =>
+        exprCall match {
+          case _: LeafExpression => false
+          case _ => exprCall.children.exists(c => canBeOptimized(c))
+        }
+      case Some(f) =>
+        CHCollapsedExpression.supported(f)
+    }
+  }
+
+  private def getExpressionName(expr: Expression): Option[String] = expr match {
+    case _: And => ExpressionMappings.expressionsMap.get(classOf[And])
+    case _: Or => ExpressionMappings.expressionsMap.get(classOf[Or])
+    case _ => Option.empty[String]
+  }
+
+  private def visitPlan(plan: SparkPlan): SparkPlan = plan match {
+    case p: ProjectExecTransformer =>
+      var newProjectList = Seq.empty[NamedExpression]
+      p.projectList.foreach {
+        case a: Alias =>
+          val newAlias = Alias(optimize(a.child), a.name)(a.exprId)
+          newProjectList :+= newAlias
+        case p =>
+          newProjectList :+= p
+      }
+      val newChild = visitPlan(p.child)
+      ProjectExecTransformer(newProjectList, newChild)
+    case f: FilterExecTransformer =>
+      val newCondition = optimize(f.condition)
+      val newChild = visitPlan(f.child)
+      FilterExecTransformer(newCondition, newChild)
+    case _ =>
+      val newChildren = plan.children.map(p => visitPlan(p))
+      plan.withNewChildren(newChildren)
+  }
+
+  private def optimize(expr: Expression): Expression = {
+    var resultExpr = expr
+    var name = getExpressionName(expr)
+    var children = Seq.empty[Expression]
+    var dataType = null.asInstanceOf[DataType]
+
+    def f(e: Expression, parent: Option[Expression] = Option.empty[Expression]): Unit = {
+      parent match {
+        case None =>
+          name = getExpressionName(e)
+          dataType = e.dataType
+        case _ =>
+      }
+      e match {
+        case a: And if canBeOptimized(a) =>
+          parent match {
+            case Some(_: And) | None =>
+              f(a.left, Option.apply(a))
+              f(a.right, Option.apply(a))
+            case _ =>
+              children +:= optimize(a)
+          }
+        case o: Or if canBeOptimized(o) =>
+          parent match {
+            case Some(_: Or) | None =>
+              f(o.left, parent = Option.apply(o))
+              f(o.right, parent = Option.apply(o))
+            case _ =>
+              children +:= optimize(o)
+          }
+        case _ =>
+          if (parent.nonEmpty) {
+            children +:= optimize(e)
+          } else {
+            children = Seq.empty[Expression]
+            val exprNewChildren = e.children.map(p => optimize(p))
+            resultExpr = e.withNewChildren(exprNewChildren)
+          }
+      }
+    }
+    f(expr)
+    if (name.isDefined || collapsedExpressionExists(children)) {
+      CHCollapsedExpression(dataType, children, name.getOrElse(""), expr.nullable, expr)
+    } else {
+      resultExpr
+    }
+  }
+
+  private def collapsedExpressionExists(children: Seq[Expression]): Boolean = {
+    var res = false
+    children.foreach {
+      case _: CHCollapsedExpression if !res => res = true
+      case c if !res => res = collapsedExpressionExists(c.children)
+      case _ =>
+    }
+    res
+  }
+}
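
For intuition, here is a minimal, self-contained sketch of the flattening this rule implements, on a toy expression ADT rather than Catalyst's (`Expr`, `collapse`, and `NaryCall`, which stands in for `CHCollapsedExpression`, are all illustrative, not part of the patch): a chain of same-operator binary nodes collapses into one n-ary call, while a sub-tree with a different operator ends the chain and collapses independently, mirroring the `parent` check in `optimize`.

    object FlattenSketch extends App {
      // Toy expression ADT; NaryCall stands in for CHCollapsedExpression.
      sealed trait Expr
      case class Leaf(s: String) extends Expr
      case class And(left: Expr, right: Expr) extends Expr
      case class Or(left: Expr, right: Expr) extends Expr
      case class NaryCall(op: String, args: Seq[Expr]) extends Expr

      def collapse(e: Expr): Expr = {
        // Walk a same-operator chain, collecting its arguments; any sub-tree
        // with a different operator collapses on its own.
        def gather(x: Expr, op: String): Seq[Expr] = (x, op) match {
          case (And(l, r), "and") => gather(l, "and") ++ gather(r, "and")
          case (Or(l, r), "or")   => gather(l, "or") ++ gather(r, "or")
          case _                  => Seq(collapse(x))
        }
        e match {
          case And(_, _) => NaryCall("and", gather(e, "and"))
          case Or(_, _)  => NaryCall("or", gather(e, "or"))
          case other     => other
        }
      }

      // and(and(a=1, b=2), c=3)  ==>  and(a=1, b=2, c=3)
      val nested = And(And(Leaf("a=1"), Leaf("b=2")), Leaf("c=3"))
      assert(collapse(nested) == NaryCall("and", Seq(Leaf("a=1"), Leaf("b=2"), Leaf("c=3"))))

      // A mixed operator stops the flattening: the inner `and` collapses on its own.
      val mixed = Or(Leaf("a=1"), And(Leaf("b=2"), Leaf("c=3")))
      assert(collapse(mixed) ==
        NaryCall("or", Seq(Leaf("a=1"), NaryCall("and", Seq(Leaf("b=2"), Leaf("c=3"))))))
    }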

backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala

+30
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.execution

+import org.apache.gluten.expression.CHCollapsedExpression
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
 import org.apache.spark.sql.catalyst.expressions._
@@ -379,6 +381,34 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
     }
   }

+  test("GLUTEN-8557: Optimize nested and/or") {
+    def checkCollapsedFunctions(plan: SparkPlan, functionName: String, argNum: Int): Boolean = {
+
+      def checkExpression(expr: Expression, functionName: String, argNum: Int): Boolean =
+        expr match {
+          case s: CHCollapsedExpression
+              if s.name.equals(functionName) && s.children.size == argNum =>
+            true
+          case _ => expr.children.exists(c => checkExpression(c, functionName, argNum))
+        }
+      plan match {
+        case f: FilterExecTransformer => return checkExpression(f.condition, functionName, argNum)
+        case _ => return plan.children.exists(c => checkCollapsedFunctions(c, functionName, argNum))
+      }
+      false
+    }
+    runQueryAndCompare(
+      "SELECT count(1) from json_test where int_field1 = 5 and double_field1 > 1.0" +
+        " and string_field1 is not null") {
+      x => assert(checkCollapsedFunctions(x.queryExecution.executedPlan, "and", 5))
+    }
+    runQueryAndCompare(
+      "SELECT count(1) from json_test where int_field1 = 5 or double_field1 > 1.0" +
+        " or string_field1 is not null") {
+      x => assert(checkCollapsedFunctions(x.queryExecution.executedPlan, "or", 3))
+    }
+  }
+
   test("Test covar_samp") {
     runQueryAndCompare("SELECT covar_samp(double_field1, int_field1) from json_test") { _ => }
   }
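
On the expected argument counts: the `and` query asserts five collapsed children, presumably because Spark's optimizer infers `isnotnull` guards for the null-intolerant `=` and `>` predicates (two inferred guards plus the three written predicates), while no such inference applies to `or`, which keeps its three written predicates.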

cpp-ch/local-engine/Functions/SparkFunctionTupleElement.cpp

-6
@@ -68,14 +68,12 @@ class SparkFunctionTupleElement : public IFunction
     DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
     {
         const size_t number_of_arguments = arguments.size();
-
         if (number_of_arguments < 2 || number_of_arguments > 3)
             throw Exception(
                 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                 "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
                 getName(),
                 number_of_arguments);
-
         std::vector<bool> arrays_is_nullable;
         DataTypePtr input_type = arguments[0].type;
         while (const DataTypeArray * array = checkAndGetDataType<DataTypeArray>(removeNullable(input_type).get()))
@@ -108,9 +106,6 @@ class SparkFunctionTupleElement : public IFunction
                 if (*it)
                     return_type = makeNullable(return_type);
             }
-
-            // std::cout << "return_type:" << return_type->getName() << std::endl;
-
             return return_type;
         }
         else
@@ -163,7 +158,6 @@ class SparkFunctionTupleElement : public IFunction
             return arguments[2].column;

         ColumnPtr res = input_col_as_tuple->getColumns()[index.value()];
-
         /// Wrap into Nullable if needed
         if (input_col_as_nullable_tuple)
         {

cpp-ch/local-engine/Parser/FunctionParser.cpp

-1
@@ -98,7 +98,6 @@ std::pair<DataTypePtr, Field> FunctionParser::parseLiteral(const substrait::Expr
 ActionsDAG::NodeRawConstPtrs
 FunctionParser::parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const
 {
-    ActionsDAG::NodeRawConstPtrs parsed_args;
     return expression_parser->parseFunctionArguments(actions_dag, substrait_func);
 }
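
The deleted `parsed_args` was an unused local; `parseFunctionArguments` simply delegates to `expression_parser`, so this is dead-code cleanup alongside the main change.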

gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ScalarFunctionNode.java

+12
@@ -36,6 +36,18 @@ public class ScalarFunctionNode implements ExpressionNode, Serializable {
     this.typeNode = typeNode;
   }

+  public Long getFunctionId() {
+    return functionId;
+  }
+
+  public List<ExpressionNode> getExpressionNodes() {
+    return expressionNodes;
+  }
+
+  public TypeNode getTypeNode() {
+    return typeNode;
+  }
+
   @Override
   public Expression toProtobuf() {
     Expression.ScalarFunction.Builder scalarBuilder = Expression.ScalarFunction.newBuilder();
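
These getters expose the components of an already-built ScalarFunctionNode (function id, argument nodes, output type); presumably they let the collapse machinery inspect and rebuild scalar function nodes when folding nested calls into one.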
