[GLUTEN-8557][CH] Collapse nested function calls for And/Or for performance optimization #8558

Open · wants to merge 2 commits into base: main · Changes from all commits
@@ -947,4 +947,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
outputAttributes: Seq[Attribute],
child: Seq[SparkPlan]): ColumnarRangeBaseExec =
CHRangeExecTransformer(start, end, step, numSlices, numElements, outputAttributes, child)

override def genCollapseNestedExpressionsTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: Expression): ExpressionTransformer =
CHCollapseNestedExpressionsTransformer(substraitExprName, children, original)

override def expressionCollapseSupported(exprName: String): Boolean =
GlutenConfig.get.getSupportedCollapsedExpressions.split(",").exists(c => exprName.equals(c))

}
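
Note: the overrides above wire the CH backend into the collapse path. expressionCollapseSupported simply parses the comma-separated config value and tests membership; a minimal standalone sketch of that check (illustration only, not part of the diff):

    // Sketch: membership test over the comma-separated config value,
    // mirroring expressionCollapseSupported above.
    object CollapseSupportCheck {
      def supported(confValue: String, exprName: String): Boolean =
        confValue.split(",").exists(c => exprName.equals(c))

      def main(args: Array[String]): Unit = {
        assert(supported("and,or", "and"))  // collapse enabled for `and`
        assert(!supported("and,or", "not")) // other functions fall through
      }
    }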
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.expression

import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DataType

import java.util

/**
 * Collapse nested expressions to reduce the number of expression calls. Currently `and` and
 * `or` are supported, e.g. select ... and(and(a=1, b=2), c=3) => select ... and(a=1, b=2, c=3).
 */
case class CHCollapseNestedExpressionsTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: Expression)
extends ExpressionTransformer
with Logging {

override def doTransform(args: Object): ExpressionNode = {
if (canBeOptimized(original)) {
val functionMap = args.asInstanceOf[util.HashMap[String, java.lang.Long]]
val newExprNode = doTransform(original, children, functionMap)
logDebug("The new expression node: " + newExprNode.toProtobuf)
newExprNode
} else {
super.doTransform(args)
}
}

def getExpressionName(expr: Expression): Option[String] = expr match {
case _: And => ExpressionMappings.expressionsMap.get(classOf[And])
case _: Or => ExpressionMappings.expressionsMap.get(classOf[Or])
case _ => Option.empty[String]
}

private def canBeOptimized(expr: Expression): Boolean = {
var exprCall = expr
expr match {
case a: Alias => exprCall = a.child
case _ =>
}
val exprName = getExpressionName(exprCall)
exprName match {
case None =>
exprCall match {
case _: LeafExpression => false
case _ => exprCall.children.exists(c => canBeOptimized(c))
}
case Some(f) =>
GlutenConfig.get.getSupportedCollapsedExpressions.split(",").exists(c => c.equals(f))
}
}

private def doTransform0(
expr: Expression,
dataType: DataType,
childNodes: Seq[ExpressionNode],
childTypes: Seq[DataType],
functionMap: util.Map[String, java.lang.Long]): ExpressionNode = {
val funcName: String = ConverterUtils.makeFuncName(substraitExprName, childTypes)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, funcName)
val childNodeList = new util.ArrayList[ExpressionNode]()
childNodes.foreach(c => childNodeList.add(c))
val typeNode = ConverterUtils.getTypeNode(dataType, expr.nullable)
ExpressionBuilder.makeScalarFunction(functionId, childNodeList, typeNode)
}

private def doTransform(
expr: Expression,
transformers: Seq[ExpressionTransformer],
functionMap: util.Map[String, java.lang.Long]): ExpressionNode = {

var dataType = null.asInstanceOf[DataType]
var children = Seq.empty[ExpressionNode]
var childTypes = Seq.empty[DataType]

def f(
e: Expression,
ts: ExpressionTransformer = null,
parent: Option[Expression] = Option.empty): Unit = {
parent match {
case None =>
dataType = e.dataType
e match {
case a: And if canBeOptimized(a) =>
f(a.left, transformers.head, Option.apply(a))
f(a.right, transformers(1), Option.apply(a))
case o: Or if canBeOptimized(o) =>
f(o.left, transformers.head, Option.apply(o))
f(o.right, transformers(1), Option.apply(o))
case _ =>
}
case Some(_: And) =>
e match {
case a: And if canBeOptimized(a) =>
val childTransformers = ts.children
f(a.left, childTransformers.head, Option.apply(a))
f(a.right, childTransformers(1), Option.apply(a))
case _ =>
children +:= ts.doTransform(functionMap)
childTypes +:= e.dataType
}
case Some(_: Or) =>
e match {
case o: Or if canBeOptimized(o) =>
val childTransformers = ts.children
f(o.left, childTransformers.head, Option.apply(o))
f(o.right, childTransformers(1), Option.apply(o))
case _ =>
children +:= ts.doTransform(functionMap)
childTypes +:= e.dataType
}
case _ =>
}
}
f(expr)
if (children.size <= 2) {
super.doTransform(functionMap)
} else {
doTransform0(expr, dataType, children, childTypes, functionMap)
}
}
}
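
The recursive walk in doTransform flattens the left/right binary tree of And/Or into one n-ary scalar function call, falling back to the generic path when the collapse yields two or fewer children. A minimal standalone sketch of the same flattening idea at the Catalyst level (hypothetical helper, not present in the PR):

    // Hypothetical helper, illustration only: collect all conjuncts of a
    // nested And into one flat sequence, so and(and(a, b), c) => Seq(a, b, c).
    import org.apache.spark.sql.catalyst.expressions.{And, Expression}

    def collectConjuncts(e: Expression): Seq[Expression] = e match {
      case And(left, right) => collectConjuncts(left) ++ collectConjuncts(right)
      case other => Seq(other)
    }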
@@ -379,6 +379,21 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
}
}

test("GLUTEN-8557: Optimize nested and/or") {

runQueryAndCompare(
"SELECT count(1) from json_test where int_field1 = 5 and double_field1 > 1.0" +
" and string_field1 is not null") { _ => }

runQueryAndCompare(
"SELECT count(1) from json_test where int_field1 = 5 or double_field1 > 1.0" +
" or string_field1 is not null") { _ => }

runQueryAndCompare(
"SELECT count(1) from json_test where int_field1 = 5 and double_field1 > 1.0" +
" or double_field1 < 100 or string_field1 is not null") { _ => }
}
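
A note on these assertions: runQueryAndCompare executes each query through Gluten and compares the result against vanilla Spark, so even the empty { _ => } body validates the collapsed and/or for pure-AND, pure-OR, and mixed predicates. A hedged sketch of an additional plan-level inspection one could add (assuming the closure receives the DataFrame):

    runQueryAndCompare(
      "SELECT count(1) from json_test where int_field1 = 5 and double_field1 > 1.0" +
        " and string_field1 is not null") {
      df =>
        // Hypothetical extra check: print the executed plan to eyeball the
        // collapsed predicate shape.
        println(df.queryExecution.executedPlan.toString)
    }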

test("Test covar_samp") {
runQueryAndCompare("SELECT covar_samp(double_field1, int_field1) from json_test") { _ => }
}
6 changes: 0 additions & 6 deletions cpp-ch/local-engine/Functions/SparkFunctionTupleElement.cpp
@@ -68,14 +68,12 @@ class SparkFunctionTupleElement : public IFunction
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
const size_t number_of_arguments = arguments.size();

if (number_of_arguments < 2 || number_of_arguments > 3)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
getName(),
number_of_arguments);

std::vector<bool> arrays_is_nullable;
DataTypePtr input_type = arguments[0].type;
while (const DataTypeArray * array = checkAndGetDataType<DataTypeArray>(removeNullable(input_type).get()))
@@ -108,9 +106,6 @@ class SparkFunctionTupleElement : public IFunction
if (*it)
return_type = makeNullable(return_type);
}

// std::cout << "return_type:" << return_type->getName() << std::endl;

return return_type;
}
else
@@ -163,7 +158,6 @@ class SparkFunctionTupleElement : public IFunction
return arguments[2].column;

ColumnPtr res = input_col_as_tuple->getColumns()[index.value()];

/// Wrap into Nullable if needed
if (input_col_as_nullable_tuple)
{
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Parser/FunctionParser.cpp
@@ -98,7 +98,6 @@ std::pair<DataTypePtr, Field> FunctionParser::parseLiteral(const substrait::Expr
ActionsDAG::NodeRawConstPtrs
FunctionParser::parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const
{
ActionsDAG::NodeRawConstPtrs parsed_args;
return expression_parser->parseFunctionArguments(actions_dag, substrait_func);
}

@@ -711,4 +711,12 @@ trait SparkPlanExecApi {
numElements: BigInt,
outputAttributes: Seq[Attribute],
child: Seq[SparkPlan]): ColumnarRangeBaseExec

def genCollapseNestedExpressionsTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: Expression): ExpressionTransformer =
GenericExpressionTransformer(substraitExprName, children, original)

def expressionCollapseSupported(exprName: String): Boolean = false
}
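
These defaults make the feature strictly opt-in: a backend that overrides neither hook keeps GenericExpressionTransformer and never collapses, since the support check returns false. A simplified miniature of that contract (hypothetical trait, for illustration only):

    // Hypothetical miniature of the two hooks: collapsing happens only where
    // a backend opts in by overriding the support check.
    trait CollapseHooks {
      def expressionCollapseSupported(exprName: String): Boolean = false
    }
    object DefaultBackend extends CollapseHooks // inherits false: no collapsing
    object CHLikeBackend extends CollapseHooks {
      override def expressionCollapseSupported(exprName: String): Boolean =
        Set("and", "or").contains(exprName)
    }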
@@ -742,6 +742,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
substraitExprName,
expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
j)
case expr
if BackendsApiManager.getSparkPlanExecApiInstance.expressionCollapseSupported(
ExpressionMappings.expressionsMap.getOrElse(expr.getClass, "")) =>
BackendsApiManager.getSparkPlanExecApiInstance.genCollapseNestedExpressionsTransformer(
substraitExprName,
expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
expr
)
case expr =>
GenericExpressionTransformer(
substraitExprName,
@@ -123,6 +123,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def scanFileSchemeValidationEnabled: Boolean =
getConf(VELOX_SCAN_FILE_SCHEME_VALIDATION_ENABLED)

def getSupportedCollapsedExpressions: String = getConf(GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS)

// Whether to use GlutenShuffleManager (experimental).
def isUseGlutenShuffleManager: Boolean =
conf
@@ -689,6 +691,13 @@
.stringConf
.createWithDefault("")

val GLUTEN_SUPPORTED_COLLAPSED_FUNCTIONS =
buildConf("spark.gluten.sql.supported.collapseNestedFunctions")
.internal()
.doc("Collapse nested functions as one for optimization.")
.stringConf
.createWithDefault("and,or");
Review comment (Contributor):

@KevinyhZou, do you find any unsuitable corner cases? E.g., wrong result, performance degradation. If not, can we always enable this optimization without introducing a config?

KevinyhZou (Contributor, Author) replied on Feb 24, 2025:

No unsuitable cases were found in CI testing or in my own testing. But I'm afraid there may be unsuitable cases in our online SQLs that the CI testing does not cover, so I think it's better to keep the config.


val GLUTEN_SOFT_AFFINITY_ENABLED =
buildConf("spark.gluten.soft-affinity.enabled")
.doc("Whether to enable Soft Affinity scheduling.")