Skip to content

Commit

Permalink
check and set ignoreNullKeys
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin committed Mar 11, 2025
1 parent 26f025f commit 455f62a
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ object VeloxRuleApi {
injector.injectPostTransform(_ => EliminateLocalSort)
injector.injectPostTransform(_ => CollapseProjectExecTransformer)
injector.injectPostTransform(c => FlushableHashAggregateRule.apply(c.session))
injector.injectPostTransform(c => HashAggregateIgnoreNullKeysRule.apply(c.session))
injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, VeloxBatch))

// Gluten columnar: Fallback policies.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {

/** Returns the current value of the `VELOX_ORC_SCAN_ENABLED` configuration entry. */
def veloxOrcScanEnabled: Boolean = getConf(VELOX_ORC_SCAN_ENABLED)

/**
 * Returns the current value of the `VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED` configuration
 * entry.
 */
def enablePropagateIgnoreNullKeys: Boolean = getConf(VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED)
}

object VeloxConfig {
Expand Down Expand Up @@ -520,4 +523,13 @@ object VeloxConfig {
.internal()
.stringConf
.createWithDefault("")

// Boolean switch (default: true) for propagating ignoreNullKeys to aggregates that feed an
// inner join on their grouping keys. Exposed through VeloxConfig.enablePropagateIgnoreNullKeys.
val VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED =
buildConf("spark.gluten.sql.columnar.backend.velox.propagateIgnoreNullKeys")
.doc(
"If enabled, we will identify aggregation followed by an inner join " +
"on the grouping keys, and mark the ignoreNullKeys flag to true to " +
"avoid unnecessary aggregation on null keys.")
.booleanConf
.createWithDefault(true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,18 @@ abstract class HashAggregateExecTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean)
extends HashAggregateExecBaseTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
child) {
child,
ignoreNullKeys
) {

override def output: Seq[Attribute] = {
// TODO: We should have a check to make sure the returned schema actually matches the output
Expand Down Expand Up @@ -192,7 +195,8 @@ abstract class HashAggregateExecTransformer(
/**
 * Renders the aggregation hints as newline-terminated `key=value` entries, in the order
 * `isStreaming`, `allowFlush`, `ignoreNullKeys`. A true flag is written as "1", a false flag
 * as "0".
 *
 * @param isStreaming
 *   value written for the `isStreaming` entry
 * @return
 *   the three entries, each followed by a newline
 */
private def formatExtOptimizationString(isStreaming: Boolean): String = {
  // Every hint uses the same "1"/"0" encoding, so convert in one place.
  def flag(enabled: Boolean): String = if (enabled) "1" else "0"

  Seq(
    s"isStreaming=${flag(isStreaming)}",
    s"allowFlush=${flag(allowFlush)}",
    s"ignoreNullKeys=${flag(ignoreNullKeys)}"
  ).mkString("", "\n", "\n")
}

// Create aggregate function node and add to list.
Expand Down Expand Up @@ -705,15 +709,18 @@ case class RegularHashAggregateExecTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean = false)
extends HashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
child) {
child,
ignoreNullKeys
) {

override protected def allowFlush: Boolean = false

Expand All @@ -737,15 +744,18 @@ case class FlushableHashAggregateExecTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean = false)
extends HashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
child) {
child,
ignoreNullKeys
) {

override protected def allowFlush: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.config.VeloxConfig
import org.apache.gluten.execution._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.BaseJoinExec

/**
 * Identifies hash aggregates whose grouping keys are exactly the keys of an enclosing inner join
 * and sets their `ignoreNullKeys` flag, so that the flag can be forwarded when the aggregate is
 * converted to Velox's AggregateNode.
 *
 * The rule is a no-op when `VeloxConfig.VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED` is set to false
 * in the session configuration.
 */
case class HashAggregateIgnoreNullKeysRule(session: SparkSession) extends Rule[SparkPlan] {

  /**
   * Rewrites both sides of every inner join in `plan`, bottom-up.
   *
   * @param plan
   *   the physical plan to inspect
   * @return
   *   `plan` itself when the feature is disabled; otherwise the plan with the aggregates below
   *   inner joins re-created with an updated `ignoreNullKeys` flag
   */
  override def apply(plan: SparkPlan): SparkPlan = {
    if (!isEnabled) {
      plan
    } else {
      plan.transformUp {
        case join: BaseJoinExec if join.joinType == Inner =>
          val newLeft = setIgnoreKeysIfAggregateOnJoinKeys(join.left, join.leftKeys)
          val newRight = setIgnoreKeysIfAggregateOnJoinKeys(join.right, join.rightKeys)
          // Only rebuild the join node when one of its children actually changed.
          if (newLeft.fastEquals(join.left) && newRight.fastEquals(join.right)) {
            join
          } else {
            join.withNewChildren(Seq(newLeft, newRight))
          }
      }
    }
  }

  /** Reads the feature switch from the session configuration; a missing value means enabled. */
  private def isEnabled: Boolean =
    session.conf.get(VeloxConfig.VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED.key, "true").toBoolean

  /**
   * Walks down from one join child and re-creates every hash aggregate that is reached, setting
   * `ignoreNullKeys` to whether its grouping expressions match `joinKeys`.
   *
   * The walk continues through aggregates, shuffle query stages and the operators accepted by
   * [[canPropagate]]; any other operator is returned untouched together with its subtree.
   *
   * @param plan
   *   the join child to start from
   * @param joinKeys
   *   the join keys belonging to that side of the join
   */
  private def setIgnoreKeysIfAggregateOnJoinKeys(
      plan: SparkPlan,
      joinKeys: Seq[Expression]): SparkPlan = {
    def rewrite(node: SparkPlan): SparkPlan = node match {
      case agg: FlushableHashAggregateExecTransformer =>
        val newChild = rewrite(agg.child)
        agg.copy(
          ignoreNullKeys = semanticEquals(agg.groupingExpressions, joinKeys),
          child = newChild)
      case agg: RegularHashAggregateExecTransformer =>
        val newChild = rewrite(agg.child)
        agg.copy(
          ignoreNullKeys = semanticEquals(agg.groupingExpressions, joinKeys),
          child = newChild)
      // NOTE(review): this copies the query stage with a rewritten plan; confirm that this is
      // safe for stages that have already been materialized.
      case stage: ShuffleQueryStageExec => stage.copy(plan = rewrite(stage.plan))
      case blocked if !canPropagate(blocked) => blocked
      case passThrough => passThrough.withNewChildren(passThrough.children.map(rewrite))
    }

    rewrite(plan)
  }

  /** Operators the search for aggregates may descend through; everything else stops it. */
  private def canPropagate(plan: SparkPlan): Boolean = plan match {
    case _: ProjectExecTransformer => true
    case _: WholeStageTransformer => true
    case _: VeloxResizeBatchesExec => true
    case _: ShuffleExchangeLike => true
    case _: VeloxColumnarToRowExec => true
    case _ => false
  }

  /**
   * Returns true when the grouping expressions match the join keys position by position.
   *
   * An empty `joinKeys` never matches: without the explicit check, an aggregate with no grouping
   * keys would vacuously "match" a join that has no equi-join keys and be flagged by mistake.
   */
  private def semanticEquals(aggExpression: Seq[Expression], joinKeys: Seq[Expression]): Boolean = {
    joinKeys.nonEmpty &&
    aggExpression.size == joinKeys.size &&
    aggExpression.zip(joinKeys).forall {
      case (groupingKey, joinKey) => groupingKey.semanticEquals(joinKey)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.config.VeloxConfig
import org.apache.gluten.config.{GlutenConfig, VeloxConfig}
import org.apache.gluten.extension.columnar.validator.FallbackInjects

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -1193,6 +1193,28 @@ class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
}
}
}

test("aggregate on join keys can set ignoreNullKeys") {
  // Subquery `a` groups by l_orderkey, which is also the key of the inner join with `b`.
  val query =
    """
      |select count(1) from
      | (select l_orderkey, max(l_partkey) from lineitem group by l_orderkey) a
      |inner join
      | (select l_orderkey from lineitem) b
      |on a.l_orderkey = b.l_orderkey
      |""".stripMargin
  val forceShuffledHashJoin =
    GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "true"

  withSQLConf(forceShuffledHashJoin) {
    runQueryAndCompare(query) {
      df =>
        // At least one aggregate transformer in the executed plan must carry the flag.
        val flaggedAggregateFound = getExecutedPlan(df).exists {
          case regular: RegularHashAggregateExecTransformer => regular.ignoreNullKeys
          case flushable: FlushableHashAggregateExecTransformer => flushable.ignoreNullKeys
          case _ => false
        }
        assert(flaggedAggregateFound)
    }
  }
}
}

class VeloxAggregateFunctionsFlushSuite extends VeloxAggregateFunctionsSuite {
Expand Down
5 changes: 5 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
preGroupingExprs.insert(preGroupingExprs.begin(), veloxGroupingExprs.begin(), veloxGroupingExprs.end());
}

if (aggRel.has_advanced_extension() &&
SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "ignoreNullKeys=")) {
ignoreNullKeys = true;
}

// Get the output names of Aggregation.
std::vector<std::string> aggOutNames;
aggOutNames.reserve(aggRel.measures().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ abstract class HashAggregateExecBaseTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean)
extends BaseAggregateExec
with UnaryTransformSupport {

Expand Down Expand Up @@ -87,11 +88,13 @@ abstract class HashAggregateExecBaseTransformer(
s"HashAggregateTransformer(keys=$keyString, " +
s"functions=$functionString, " +
s"isStreamingAgg=$isCapableForStreamingAggregation, " +
s"ignoreNullKeys=$ignoreNullKeys, " +
s"output=$outputString)"
} else {
s"HashAggregateTransformer(keys=$keyString, " +
s"functions=$functionString, " +
s"isStreamingAgg=$isCapableForStreamingAggregation)"
s"isStreamingAgg=$isCapableForStreamingAggregation, " +
s"ignoreNullKeys=$ignoreNullKeys)"
}
}

Expand Down

0 comments on commit 455f62a

Please sign in to comment.