Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-8966][VL] Propagate HashAggregate's ignoreNullKeys when possible #8967

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ object VeloxRuleApi {
injector.injectPostTransform(_ => EliminateLocalSort)
injector.injectPostTransform(_ => CollapseProjectExecTransformer)
injector.injectPostTransform(c => FlushableHashAggregateRule.apply(c.session))
injector.injectPostTransform(c => HashAggregateIgnoreNullKeysRule.apply(c.session))
injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, VeloxBatch))

// Gluten columnar: Fallback policies.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {

def veloxOrcScanEnabled: Boolean =
getConf(VELOX_ORC_SCAN_ENABLED)

// Whether propagation of the ignoreNullKeys flag to hash aggregates is enabled.
// Reads the VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED entry from the current conf.
def enablePropagateIgnoreNullKeys: Boolean =
getConf(VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED)
}

object VeloxConfig {
Expand Down Expand Up @@ -520,4 +523,13 @@ object VeloxConfig {
.internal()
.stringConf
.createWithDefault("")

// Boolean switch (default: true) for marking ignoreNullKeys on aggregations whose grouping
// keys feed an inner join. HashAggregateIgnoreNullKeysRule returns the plan unchanged when
// this is false.
val VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED =
buildConf("spark.gluten.sql.columnar.backend.velox.propagateIgnoreNullKeys")
.doc(
"If enabled, we will identify aggregation followed by an inner join " +
"on the grouping keys, and mark the ignoreNullKeys flag to true to " +
"avoid unnecessary aggregation on null keys.")
.booleanConf
.createWithDefault(true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,18 @@ abstract class HashAggregateExecTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean)
extends HashAggregateExecBaseTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
child) {
child,
ignoreNullKeys
) {

override def output: Seq[Attribute] = {
// TODO: We should have a check to make sure the returned schema actually matches the output
Expand Down Expand Up @@ -192,7 +195,8 @@ abstract class HashAggregateExecTransformer(
// Builds the newline-separated "key=value" option string for this aggregate.
// Each boolean flag is rendered as "1" (true) or "0" (false).
private def formatExtOptimizationString(isStreaming: Boolean): String = {
val isStreamingStr = if (isStreaming) "1" else "0"
val allowFlushStr = if (allowFlush) "1" else "0"
// NOTE(review): this expression's value is discarded; it looks like the pre-change result
// line carried over from the diff — confirm it is absent from the final file.
s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\n"
val ignoreNullKeysStr = if (ignoreNullKeys) "1" else "0"
// The returned string carries all three flags, including ignoreNullKeys.
s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\nignoreNullKeys=$ignoreNullKeysStr\n"
}

// Create aggregate function node and add to list.
Expand Down Expand Up @@ -705,15 +709,18 @@ case class RegularHashAggregateExecTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean = false)
extends HashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
child) {
child,
ignoreNullKeys
) {

override protected def allowFlush: Boolean = false

Expand All @@ -737,15 +744,18 @@ case class FlushableHashAggregateExecTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean = false)
extends HashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
child) {
child,
ignoreNullKeys
) {

override protected def allowFlush: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.config.VeloxConfig
import org.apache.gluten.execution._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.BaseJoinExec

/**
* Identifies aggregates whose group-by keys are used as inner join keys. In this case, we can set
* ignoreNullKeys to true when converting to Velox's AggregateNode.
*/
case class HashAggregateIgnoreNullKeysRule(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
if (
!session.conf.get(VeloxConfig.VELOX_PROPAGATE_IGNORE_NULL_KEYS_ENABLED.key, "true").toBoolean
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use VeloxConfig.get.enablePropagateIgnoreNullKeys

) {
return plan
}
plan.transformUp {
case join: BaseJoinExec if join.joinType == Inner =>
val newLeftChild = setIgnoreKeysIfAggregateOnJoinKeys(join.left, join.leftKeys)
val newRightChild = setIgnoreKeysIfAggregateOnJoinKeys(join.right, join.rightKeys)
if (!newLeftChild.fastEquals(join.left) || !newRightChild.fastEquals(join.right)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: change it to if (newLeftChild.fastEquals(join.left) && newRightChild.fastEquals(join.right)) for better understanding.

join.withNewChildren(Seq(newLeftChild, newRightChild))
} else {
join
}
case p => p
}
}

private def setIgnoreKeysIfAggregateOnJoinKeys(
plan: SparkPlan,
joinKeys: Seq[Expression]): SparkPlan = {
def transformDown: SparkPlan => SparkPlan = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a function here? Can we directly use setIgnoreKeysIfAggregateOnJoinKeys?

case agg: FlushableHashAggregateExecTransformer =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use HashAggregateExecTransformer

val newChild = transformDown(agg.child)
val canIgnoreNullKeysRule = semanticEquals(agg.groupingExpressions, joinKeys)
agg.copy(ignoreNullKeys = canIgnoreNullKeysRule, child = newChild)
case agg: RegularHashAggregateExecTransformer =>
val newChild = transformDown(agg.child)
val canIgnoreNullKeysRule = semanticEquals(agg.groupingExpressions, joinKeys)
agg.copy(ignoreNullKeys = canIgnoreNullKeysRule, child = newChild)
case s: ShuffleQueryStageExec => s.copy(plan = transformDown(s.plan))
case p if !canPropagate(p) => p
case other => other.withNewChildren(other.children.map(transformDown))
}
val out = transformDown(plan)
out
}

/**
 * Tells the caller whether it may keep descending through `plan` into its children.
 *
 * @param plan
 *   the physical operator being inspected
 * @return
 *   true only when `plan` is one of the operator types listed below; false for anything else
 */
private def canPropagate(plan: SparkPlan): Boolean = plan match {
  case _: ProjectExecTransformer | _: WholeStageTransformer | _: VeloxResizeBatchesExec |
      _: ShuffleExchangeLike | _: VeloxColumnarToRowExec =>
    true
  case _ =>
    false
}

/**
 * Position-wise semantic comparison of an aggregate's grouping expressions with join keys.
 *
 * @param aggExpression
 *   the grouping expressions of the aggregate
 * @param joinKeys
 *   the join keys on the same side of the join
 * @return
 *   true only when both sequences have the same size and each grouping expression is
 *   semantically equal to the join key at the same position
 */
private def semanticEquals(aggExpression: Seq[Expression], joinKeys: Seq[Expression]): Boolean = {
  val sameArity = aggExpression.size == joinKeys.size
  // Pairs are only inspected once the sizes are known to match.
  sameArity && aggExpression.iterator.zip(joinKeys.iterator).forall {
    case (groupingExpr, joinKey) => groupingExpr.semanticEquals(joinKey)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.config.VeloxConfig
import org.apache.gluten.config.{GlutenConfig, VeloxConfig}
import org.apache.gluten.extension.columnar.validator.FallbackInjects

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -1193,6 +1193,28 @@ class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
}
}
}

test("aggregate on join keys can set ignoreNullKeys") {
val s =
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test data created by createTPCHNotNullTables() does not contain nulls, so it cannot test this feature well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test case here checks that the plan correctly marks ignoreNullKeys as true in the AggregateTransformer, so I think the data doesn't matter.

|select count(1) from
| (select l_orderkey, max(l_partkey) from lineitem group by l_orderkey) a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the agg is offloaded but the join falls back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't matter. If the pattern is

 agg(k1)       t
    \               /
     \            /
    join(k1 = k2) 

then we can safely ignore the null values in k1 when doing aggregation, no matter whether the join is offloaded or not.

|inner join
| (select l_orderkey from lineitem) b
|on a.l_orderkey = b.l_orderkey
|""".stripMargin
withSQLConf(GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "true") {
runQueryAndCompare(s) {
df =>
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists {
case a: RegularHashAggregateExecTransformer if a.ignoreNullKeys => true
case a: FlushableHashAggregateExecTransformer if a.ignoreNullKeys => true
case _ => false
})
}
}
}
}

class VeloxAggregateFunctionsFlushSuite extends VeloxAggregateFunctionsSuite {
Expand Down
5 changes: 5 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
preGroupingExprs.insert(preGroupingExprs.begin(), veloxGroupingExprs.begin(), veloxGroupingExprs.end());
}

if (aggRel.has_advanced_extension() &&
SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "ignoreNullKeys=")) {
ignoreNullKeys = true;
}

// Get the output names of Aggregation.
std::vector<std::string> aggOutNames;
aggOutNames.reserve(aggRel.measures().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ abstract class HashAggregateExecBaseTransformer(
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
child: SparkPlan,
ignoreNullKeys: Boolean)
extends BaseAggregateExec
with UnaryTransformSupport {

Expand Down Expand Up @@ -87,11 +88,13 @@ abstract class HashAggregateExecBaseTransformer(
s"HashAggregateTransformer(keys=$keyString, " +
s"functions=$functionString, " +
s"isStreamingAgg=$isCapableForStreamingAggregation, " +
s"ignoreNullKeys=$ignoreNullKeys, " +
s"output=$outputString)"
} else {
s"HashAggregateTransformer(keys=$keyString, " +
s"functions=$functionString, " +
s"isStreamingAgg=$isCapableForStreamingAggregation)"
s"isStreamingAgg=$isCapableForStreamingAggregation, " +
s"ignoreNullKeys=$ignoreNullKeys)"
}
}

Expand Down
Loading