Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zml1206 committed Jan 9, 2025
1 parent d74dc60 commit d93a9f4
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ object CommonExpressionId {
/**
* A wrapper of common expression to carry the id.
*
* @param originAlias only used for push down predicates to make it idempotent. If it
* @param originalAttribute only used for push down predicates to make it idempotent. If it
* is not none, we should propagate this attribute
*/
case class CommonExpressionDef(
child: Expression,
id: CommonExpressionId = new CommonExpressionId(),
originAlias: Option[Alias] = None)
originalAttribute: Option[Attribute] = None)
extends UnaryExpression with Unevaluable {
override def dataType: DataType = child.dataType
override protected def withNewChildInternal(newChild: Expression): Expression =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1811,10 +1811,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key)
val batches = Seq(
Batch("RewriteWithExpression", fixedPoint, RewriteWithExpression),
Batch("Optimize after RewriteWithExpression", fixedPoint,
CollapseProject,
ColumnPruning,
RemoveNoopOperators)
// CollapseProject is needed to ensure idempotence
Batch("CollapseProject", fixedPoint, CollapseProject)
)
}

Expand Down Expand Up @@ -1863,7 +1861,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
// inheritance.
val newAggregateExpressions = aggregate.aggregateExpressions ++
getWithAlias(pushDown.reduce(And)).map(replaceAliasButKeepName(_, aliasMap))
val replaced = removeOriginAlias(rewriteCondition(pushDown.reduce(And), aliasMap))
val replaced = removeOriginAttribute(rewriteCondition(pushDown.reduce(And), aliasMap))
val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child),
aggregateExpressions = newAggregateExpressions)
// If there is no more filter to stay up, just eliminate the filter.
Expand Down Expand Up @@ -1977,6 +1975,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
case _: Sort => true
case _: BatchEvalPython => true
case _: ArrowEvalPython => true
case _: Expand => true
case _ => false
}

Expand Down Expand Up @@ -2021,9 +2020,9 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
}
}

private def removeOriginAlias(expr: Expression): Expression = {
private def removeOriginAttribute(expr: Expression): Expression = {
expr.transform {
case ced: CommonExpressionDef => ced.copy(originAlias = None)
case ced: CommonExpressionDef => ced.copy(originalAttribute = None)
}
}

Expand All @@ -2048,7 +2047,8 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
cond
} else {
val defsMap = AttributeMap(replaceWithMap.map(m =>
m._1 -> CommonExpressionDef(child = trimAliases(m._2), originAlias = Some(m._2))))
m._1 -> CommonExpressionDef(child = trimAliases(m._2),
originalAttribute = Some(m._2.toAttribute))))
val refsMap = AttributeMap(defsMap.map(m => m._1 -> new CommonExpressionRef(m._2)))
splitConjunctivePredicates(cond)
.map(rewriteByWith(_, defsMap, refsMap))
Expand All @@ -2074,13 +2074,13 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe

private def getWithAttributes(condition: Expression): Seq[Attribute] = {
condition.collect {
case CommonExpressionDef(_, _, Some(alias)) => alias.toAttribute
case CommonExpressionDef(_, _, Some(attr)) => attr
}.distinct
}

private def getWithAlias(condition: Expression): Seq[Alias] = {
condition.collect {
case CommonExpressionDef(_, _, Some(alias)) => alias
case CommonExpressionDef(child, _, Some(attr)) => Alias(child, attr.name)(attr.exprId)
}.distinct
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
// child, and we need to project away the extra columns to keep the plan schema unchanged.
assert(p.output.length <= newPlan.output.length)

def hasOriginAlias(expr: Expression): Boolean = expr match {
def hasOriginalAttribute(expr: Expression): Boolean = expr match {
case w: With =>
if (w.defs.exists(_.originAlias.nonEmpty)) true else false
case e => e.children.exists(hasOriginAlias)
if (w.defs.exists(_.originalAttribute.nonEmpty)) true else false
case e => e.children.exists(hasOriginalAttribute)
}
// If this iteration contains attribute that require propagate, the column cannot be pruning.
val needPropagate = p.expressions.exists(hasOriginAlias)
val needPropagate = p.expressions.exists(hasOriginalAttribute)
if (p.output.length < newPlan.output.length && !needPropagate) {
assert(p.outputSet.subsetOf(newPlan.outputSet))
Project(p.output, newPlan)
Expand All @@ -123,13 +123,13 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id, originAlias), index) =>
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id, originalAttr), index) =>
if (id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression definitions")
}

if (originAlias.isEmpty &&
if (originalAttr.isEmpty &&
(CollapseProject.isCheap(child) || !commonExprIdSet.contains(id))) {
refToExpr(id) = child
} else {
Expand All @@ -146,11 +146,11 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else if (originAlias.nonEmpty &&
inputPlans.head.output.contains(originAlias.get.toAttribute)) {
} else if (originalAttr.nonEmpty &&
inputPlans.head.output.contains(originalAttr.get.toAttribute)) {
// originAlias only exists in Project or Filter. If the child already contains this
// attribute, extend it.
refToExpr(id) = originAlias.get.toAttribute
refToExpr(id) = originalAttr.get.toAttribute
} else {
val commonExprs = commonExprsPerChild(childPlanIndex)
val existingCommonExpr = commonExprs.find(_._2 == id.id)
Expand All @@ -160,14 +160,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}
refToExpr(id) = existingCommonExpr.get._1.toAttribute
} else {
val alias = originAlias match {
val alias = originalAttr match {
case Some(a) =>
a.copy(child = child)(
exprId = a.exprId,
qualifier = a.qualifier,
explicitMetadata = Option(a.metadata),
nonInheritableMetadataKeys = a.nonInheritableMetadataKeys
)
Alias(child, a.name)(a.exprId)
case _ =>
val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
s"_common_expr_${id.id}"
Expand Down
Loading

0 comments on commit d93a9f4

Please sign in to comment.