Skip to content

Commit

Permalink
rewrite With expression immediately at the end of the filter pushdown…
Browse files Browse the repository at this point in the history
… rule
  • Loading branch information
zml1206 committed Jan 6, 2025
1 parent 8fd06da commit 57016f2
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,13 @@ object CommonExpressionId {

/**
* A wrapper of common expression to carry the id.
*
* @param originAlias only used for push down predicates to make it idempotent.
*/
case class CommonExpressionDef(child: Expression, id: CommonExpressionId = new CommonExpressionId())
case class CommonExpressionDef(
child: Expression,
id: CommonExpressionId = new CommonExpressionId(),
originAlias: Option[Alias] = None)
extends UnaryExpression with Unevaluable {
override def dataType: DataType = child.dataType
override protected def withNewChildInternal(newChild: Expression): Expression =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
ReorderJoin,
EliminateOuterJoin,
PushDownPredicates,
RewriteWithExpression,
CollapseProject,
PushDownLeftSemiAntiJoin,
PushLeftSemiLeftAntiThroughJoin,
OptimizeJoinCondition,
Expand Down Expand Up @@ -161,26 +163,22 @@ abstract class Optimizer(catalogManager: CatalogManager)
val operatorOptimizationBatch: Seq[Batch] = Seq(
Batch("Operator Optimization before Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject),
Batch("Infer Filters", Once,
InferFiltersFromGenerate,
InferFiltersFromConstraints),
Batch("Operator Optimization after Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*),
Batch("Push extra predicate through join", fixedPoint,
PushExtraPredicateThroughJoin,
PushDownPredicates),
Batch("Rewrite With expression", fixedPoint,
PushDownPredicates,
RewriteWithExpression,
CollapseProject))

val batches: Seq[Batch] = flattenBatches(Seq(
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis),
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", fixedPoint, RewriteWithExpression),
Batch("Rewrite With expression", Once, RewriteWithExpression),
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1999,34 +1997,25 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
cond: Expression,
aliasMap: AttributeMap[Alias]): Expression = {
if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) {
// SubqueryExpression can't contain common expression ref, replace alias for it first.
val newCond = replaceAliasForSubqueryExpression(cond, aliasMap)
val replaceWithMap = newCond.collect {case a: Attribute => a }
val replaceWithMap = cond.collect {case a: Attribute => a }
.groupBy(identity)
.transform((_, v) => v.size)
.filter(m => aliasMap.contains(m._1) && m._2 > 1)
.map(m => m._1 -> trimAliases(aliasMap.getOrElse(m._1, m._1)))
.map(m => m._1 -> aliasMap(m._1))
.filter(m => !CollapseProject.isCheap(m._2))
if (replaceWithMap.isEmpty) {
newCond
cond
} else {
val defsMap = AttributeMap(replaceWithMap.map(m => m._1 -> CommonExpressionDef(m._2)))
val defsMap = AttributeMap(replaceWithMap.map(m =>
m._1 -> CommonExpressionDef(child = trimAliases(m._2), originAlias = Some(m._2))))
val refsMap = AttributeMap(defsMap.map(m => m._1 -> new CommonExpressionRef(m._2)))
splitConjunctivePredicates(newCond)
splitConjunctivePredicates(cond)
.map(rewriteByWith(_, defsMap, refsMap))
.reduce(And)
}
} else cond
}

private def replaceAliasForSubqueryExpression(
expr: Expression,
aliasMap: AttributeMap[Alias]): Expression = {
expr.transform {
case s: SubqueryExpression => replaceAlias(s, aliasMap)
}
}

private def rewriteByWith(
expr: Expression,
defsMap: AttributeMap[CommonExpressionDef],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ import org.apache.spark.util.Utils
*/
object RewriteWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
var p = plan
while (p.containsPattern(WITH_EXPRESSION)) {
p = applyOnce(p)
}
p
}

private def applyOnce(plan: LogicalPlan): LogicalPlan = {
plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
// For aggregates, separate the computation of the aggregations themselves from the final
// result by moving the final result computation into a projection above it. This prevents
Expand Down Expand Up @@ -115,7 +123,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id, originAlias), index) =>
if (id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression definitions")
Expand Down Expand Up @@ -146,12 +154,22 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}
refToExpr(id) = existingCommonExpr.get._1.toAttribute
} else {
val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
s"_common_expr_${id.id}"
} else {
s"_common_expr_$index"
val alias = originAlias match {
case Some(a) =>
a.copy(child = child)(
exprId = a.exprId,
qualifier = a.qualifier,
explicitMetadata = Option(a.metadata),
nonInheritableMetadataKeys = a.nonInheritableMetadataKeys
)
case _ =>
val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
s"_common_expr_${id.id}"
} else {
s"_common_expr_$index"
}
Alias(child, aliasName)()
}
val alias = Alias(child, aliasName)()
val fakeProj = Project(Seq(alias), inputPlans(childPlanIndex))
if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
// We have to inline the common expression if it cannot be put in a Project.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand All @@ -43,11 +42,11 @@ class FilterPushdownSuite extends PlanTest {
PushPredicateThroughNonJoin,
BooleanSimplification,
PushPredicateThroughJoin,
RewriteWithExpression,
CollapseProject) ::
Batch("Push extra predicate through join", FixedPoint(10),
PushExtraPredicateThroughJoin,
PushDownPredicates) ::
Batch("Rewrite With expression", FixedPoint(10),
PushDownPredicates,
RewriteWithExpression,
CollapseProject) :: Nil
}
Expand Down Expand Up @@ -1545,35 +1544,33 @@ class FilterPushdownSuite extends PlanTest {
}

test("SPARK-50589: avoid extra expression duplication when push filter") {
withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
// through project
val originalQuery1 = testRelation
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
val optimized1 = Optimize.execute(originalQuery1.analyze)
val correctAnswer1 = testRelation
.select($"a", $"b", $"c", $"a" + $"b" as "_common_expr_0")
.where($"_common_expr_0" < 10 &&
$"_common_expr_0" + $"_common_expr_0" > 10 &&
$"a" - $"b" > 0)
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.analyze
comparePlans(optimized1, correctAnswer1)

// through aggregate
val originalQuery2 = testRelation
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
val optimized2 = Optimize.execute(originalQuery2.analyze)
val correctAnswer2 = testRelation
.select($"a", $"b", $"c", $"a" + $"a" as "_common_expr_0")
.where($"_common_expr_0" < 10 &&
$"_common_expr_0" + $"_common_expr_0" > 10 &&
abs($"a") > 5)
.select($"a", $"b", $"c")
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.analyze
comparePlans(optimized2, correctAnswer2)
}
// through project
val originalQuery1 = testRelation
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
val optimized1 = Optimize.execute(originalQuery1.analyze)
val correctAnswer1 = testRelation
.select($"a", $"b", $"c", $"a" + $"b" as "add")
.where($"add" < 10 &&
$"add" + $"add" > 10 &&
$"a" - $"b" > 0)
.select($"a" + $"b" as "add", $"a" - $"b" as "sub")
.analyze
comparePlans(optimized1, correctAnswer1)

// through aggregate
val originalQuery2 = testRelation
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
val optimized2 = Optimize.execute(originalQuery2.analyze)
val correctAnswer2 = testRelation
.select($"a", $"b", $"c", $"a" + $"a" as "add")
.where($"add" < 10 &&
$"add" + $"add" > 10 &&
abs($"a") > 5)
.select($"a", $"b", $"c")
.groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
.analyze
comparePlans(optimized2, correctAnswer2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val batches =
Batch("PushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushPredicateThroughNonJoin) ::
Batch("Rewrite With expression", FixedPoint(10),
PushPredicateThroughNonJoin,
RewriteWithExpression,
CollapseProject) ::
Batch("InferFilters", Once,
Expand All @@ -44,8 +43,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
PruneFilters) ::
Batch("PushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushPredicateThroughNonJoin) ::
Batch("Rewrite With expression", FixedPoint(10),
PushPredicateThroughNonJoin,
RewriteWithExpression,
CollapseProject) :: Nil
}
Expand Down Expand Up @@ -154,24 +152,22 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
}

test("inner join with alias: alias contains multiple attributes") {
withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
val t1 = testRelation.subquery("t1")
val t2 = testRelation.subquery("t2")

val originalQuery = t1.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.select($"a", $"b", $"c", Coalesce(Seq($"a", $"b")) as "_common_expr_0")
.where(IsNotNull($"a") && IsNotNull($"_common_expr_0") &&
$"a" === $"_common_expr_0")
.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2.where(IsNotNull($"a")), Inner,
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
val t1 = testRelation.subquery("t1")
val t2 = testRelation.subquery("t2")

val originalQuery = t1.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.select($"a", $"b", $"c", Coalesce(Seq($"a", $"b")) as "int_col")
.where(IsNotNull($"a") && IsNotNull($"int_col") &&
$"a" === $"int_col")
.select($"a", Coalesce(Seq($"a", $"b")).as("int_col")).as("t")
.join(t2.where(IsNotNull($"a")), Inner,
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("inner join with alias: alias contains single attributes") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
class RewriteWithExpressionSuite extends PlanTest {

object Optimizer extends RuleExecutor[LogicalPlan] {
val batches = Batch("Rewrite With expression", FixedPoint(5),
val batches = Batch("Rewrite With expression", Once,
PullOutGroupingExpressions,
RewriteWithExpression) :: Nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ class SparkOptimizer(
MergeScalarSubqueries,
RewriteDistinctAggregates),
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
PushDownPredicates),
Batch("Rewrite With expression", fixedPoint,
PushDownPredicates,
RewriteWithExpression,
CollapseProject),
Batch("Cleanup filters that cannot be pushed down", Once,
Expand All @@ -90,11 +89,10 @@ class SparkOptimizer(
ColumnPruning,
LimitPushDown,
PushPredicateThroughNonJoin,
RewriteWithExpression,
CollapseProject,
PushProjectionThroughLimit,
RemoveNoopOperators),
Batch("Rewrite With expression", fixedPoint,
RewriteWithExpression,
CollapseProject),
Batch("Infer window group limit", Once,
InferWindowGroupLimit,
LimitPushDown,
Expand Down

0 comments on commit 57016f2

Please sign in to comment.