diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 3b15fa2263de..566cffd1bf18 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -96,6 +96,7 @@ object VeloxRuleApi {
     injector.injectPostTransform(_ => PushDownInputFileExpression.PostOffload)
     injector.injectPostTransform(_ => EnsureLocalSortRequirements)
     injector.injectPostTransform(_ => EliminateLocalSort)
+    injector.injectPostTransform(_ => PullOutDuplicateProject)
     injector.injectPostTransform(_ => CollapseProjectExecTransformer)
     injector.injectPostTransform(c => FlushableHashAggregateRule.apply(c.session))
     injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, VeloxBatch))
@@ -179,6 +180,7 @@ object VeloxRuleApi {
     injector.injectPostTransform(_ => PushDownInputFileExpression.PostOffload)
     injector.injectPostTransform(_ => EnsureLocalSortRequirements)
     injector.injectPostTransform(_ => EliminateLocalSort)
+    injector.injectPostTransform(_ => PullOutDuplicateProject)
     injector.injectPostTransform(_ => CollapseProjectExecTransformer)
     injector.injectPostTransform(c => FlushableHashAggregateRule.apply(c.session))
     injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, VeloxBatch))
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
index 2023ad97ebab..92e232316fbe 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala
@@ -195,4 +195,45 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
       }
     })
   }
+
+  test("pull out duplicate projections") {
+    withTable("t1", "t2") {
+      Seq((1, 1), (2, 2)).toDF("c1", "c2").write.saveAsTable("t1")
+      Seq(1, 2, 3).toDF("c1").write.saveAsTable("t2")
+      val query =
+        """
+          |select t3.* from
+          |(select c1, c2 as a,c2 as b from t1) t3
+          |left join t2
+          |on t3.c1 = t2.c1
+          |""".stripMargin
+      runQueryAndCompare(query) {
+        df =>
+          {
+            val executedPlan = getExecutedPlan(df)
+            val bhjs = executedPlan.collect { case p: BroadcastHashJoinExecTransformer => p }
+            val projects = executedPlan.collect { case p: ProjectExecTransformer => p }
+            assert(bhjs.size == 1)
+            // The pulled out project and the outermost project are collapsed.
+            assert(projects.size == 2)
+          }
+      }
+    }
+  }
+
+  test("pull out duplicate projections keeps aliases referenced by the join") {
+    withTable("t1", "t2") {
+      Seq((1, 1), (2, 2)).toDF("c1", "c2").write.saveAsTable("t1")
+      Seq(1, 2, 3).toDF("c1").write.saveAsTable("t2")
+      // Alias `a` is a join key, so it must stay below the join even though `c2` is duplicated.
+      val query =
+        """
+          |select t3.* from
+          |(select c1, c2 as a, c2 as b, c2 as c from t1) t3
+          |left join t2
+          |on t3.a = t2.c1
+          |""".stripMargin
+      runQueryAndCompare(query) { _ => }
+    }
+  }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PullOutDuplicateProject.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PullOutDuplicateProject.scala
new file mode 100644
index 000000000000..0ce23cf67e1f
--- /dev/null
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PullOutDuplicateProject.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.execution.{BroadcastHashJoinExecTransformerBase, LimitExecTransformer, ProjectExecTransformer}
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeSet, PredicateHelper}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Velox does not allow duplicate projections in hash probe. This rule pulls duplicate
+ * projections out of the streamed side of a broadcast hash join and re-applies them in a
+ * project on top of the join, so the output of the join is unchanged.
+ */
+object PullOutDuplicateProject extends Rule[SparkPlan] with PredicateHelper {
+  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case bhj: BroadcastHashJoinExecTransformerBase =>
+      // Collects every alias that rewriteStreamPlan removes from the streamed side.
+      val pullOutAliases = new ArrayBuffer[Alias]()
+      val streamedPlan = rewriteStreamPlan(bhj.streamedPlan, bhj.references, pullOutAliases)
+      if (pullOutAliases.isEmpty) {
+        bhj
+      } else {
+        // Re-create the removed aliases above the join so that the overall output is preserved.
+        val aliasMap = AttributeMap(pullOutAliases.map(a => a.toAttribute -> a))
+        val newProjectList = bhj.output.map(attr => aliasMap.getOrElse(attr, attr))
+        val (newLeft, newRight) = bhj.joinBuildSide match {
+          case BuildLeft => (bhj.left, streamedPlan)
+          case BuildRight => (streamedPlan, bhj.right)
+        }
+
+        val newBhj =
+          BackendsApiManager.getSparkPlanExecApiInstance.genBroadcastHashJoinExecTransformer(
+            bhj.leftKeys,
+            bhj.rightKeys,
+            bhj.hashJoinType,
+            bhj.joinBuildSide,
+            bhj.condition,
+            newLeft,
+            newRight,
+            bhj.genJoinParametersInternal()._2 == 1)
+        ProjectExecTransformer(newProjectList, newBhj)
+      }
+  }
+
+  /**
+   * Rewrites the streamed side of a join so that attributes projected more than once are emitted
+   * only once. Only a project, optionally wrapped in limits, is rewritten; any other plan is
+   * returned as is.
+   *
+   * @param plan
+   *   the streamed-side plan to rewrite
+   * @param references
+   *   attributes referenced by the join; expressions producing them are never touched
+   * @param pullOutAliases
+   *   buffer that receives the aliases removed from the project list
+   * @return
+   *   the rewritten plan; callers use it only when `pullOutAliases` is non-empty
+   */
+  def rewriteStreamPlan(
+      plan: SparkPlan,
+      references: AttributeSet,
+      pullOutAliases: ArrayBuffer[Alias]): SparkPlan = plan match {
+    case l @ LimitExecTransformer(child, _, _) =>
+      val newChild = rewriteStreamPlan(child, references, pullOutAliases)
+      if (pullOutAliases.isEmpty) {
+        l
+      } else {
+        l.copy(child = newChild)
+      }
+    case p @ ProjectExecTransformer(projectList, _) =>
+      // Attributes that are projected more than once and are not needed by the join itself.
+      val duplicates =
+        projectList
+          .collect {
+            case attr: Attribute if !references.contains(attr) => attr
+            case a @ Alias(attr: Attribute, _)
+                if !references.contains(a) && !references.contains(attr) =>
+              attr
+          }
+          .groupBy(identity)
+          .collect { case (attr, occurrences) if occurrences.size > 1 => attr }
+          .toSet
+      if (duplicates.nonEmpty) {
+        val newProjectList = projectList.map {
+          // An alias referenced by the join must stay below it, otherwise the join could no
+          // longer resolve that attribute.
+          case a @ Alias(attr: Attribute, _)
+              if duplicates.contains(attr) && !references.contains(a) =>
+            pullOutAliases.append(a)
+            attr
+          case ne => ne
+        }.distinct
+        p.copy(projectList = newProjectList)
+      } else {
+        p
+      }
+    case _ => plan
+  }
+}