Commit 3298ddc (parent 88407e3)

    fix some test

22 files changed, +1167 −1197 lines

mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala (+26 −23)

@@ -22,6 +22,7 @@ import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamTest

 class VectorSizeHintSuite
@@ -55,29 +56,31 @@ class VectorSizeHintSuite
   }

   test("Adding size to column of vectors.") {
-    val size = 3
-    val vectorColName = "vector"
-    val denseVector = Vectors.dense(1, 2, 3)
-    val sparseVector = Vectors.sparse(size, Array(), Array())
-
-    val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
-    val dataFrame = data.toDF(vectorColName)
-    assert(
-      AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
-      s"This test requires that column '$vectorColName' not have size metadata.")
-
-    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
-      val transformer = new VectorSizeHint()
-        .setInputCol(vectorColName)
-        .setSize(size)
-        .setHandleInvalid(handleInvalid)
-      testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) {
-        rows => {
-          assert(
-            AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size,
-            "Transformer did not add expected size data.")
-          val numRows = rows.length
-          assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+    withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> "true") {
+      val size = 3
+      val vectorColName = "vector"
+      val denseVector = Vectors.dense(1, 2, 3)
+      val sparseVector = Vectors.sparse(size, Array(), Array())
+
+      val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
+      val dataFrame = data.toDF(vectorColName)
+      assert(
+        AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
+        s"This test requires that column '$vectorColName' not have size metadata.")
+
+      for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+        val transformer = new VectorSizeHint()
+          .setInputCol(vectorColName)
+          .setSize(size)
+          .setHandleInvalid(handleInvalid)
+        testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) {
+          rows => {
+            assert(
+              AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size,
+              "Transformer did not add expected size data.")
+            val numRows = rows.length
+            assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+          }
         }
       }
     }
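Note on the new conf usage: withSQLConf (from Spark's SQLHelper test trait) sets the given SQLConf entries only for the duration of the block and restores the previous values afterwards, so pinning ALWAYS_INLINE_COMMON_EXPR here cannot leak into other tests. A minimal sketch of that scoping pattern, simplified rather than the verbatim helper:

    import org.apache.spark.sql.internal.SQLConf

    def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
      val conf = SQLConf.get
      // Remember each key's previous value (None if it was unset) for restoration.
      val previous = pairs.map { case (k, _) =>
        k -> (if (conf.contains(k)) Some(conf.getConfString(k)) else None)
      }
      pairs.foreach { case (k, v) => conf.setConfString(k, v) }
      try f finally previous.foreach {
        case (k, Some(v)) => conf.setConfString(k, v)
        case (k, None) => conf.unsetConf(k)
      }
    }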

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (+11 −10)

@@ -1281,8 +1281,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
       } else {
         true
       }
-    // Alias, ExtractValue and CreateNamedStruct are very cheap.
-    case _: Alias | _: ExtractValue | _: CreateNamedStruct => e.children.forall(isCheap)
+    // Alias and ExtractValue are very cheap.
+    case _: Alias | _: ExtractValue => e.children.forall(isCheap)
     case _ => false
   }
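Context for the isCheap change: CollapseProject uses isCheap to decide whether an aliased expression may be inlined at every place that references it, and the filter-pushdown code below reuses the same check. With CreateNamedStruct no longer treated as cheap, struct construction is not duplicated per reference. A hypothetical illustration in the Catalyst test DSL (testRelation is defined here just for the example):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, GetStructField, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val testRelation = LocalRelation($"a".int, $"b".int)

    // "s" is referenced twice below. When CreateNamedStruct counted as cheap,
    // collapsing the two projects rebuilt the struct at both references; now
    // isCheap returns false and the lower project is kept.
    val query = testRelation
      .select(CreateNamedStruct(Seq(Literal("x"), $"a" + $"b")) as "s")
      .select(GetStructField($"s", 0) as "x1", (GetStructField($"s", 0) + 1) as "x2")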

@@ -1855,13 +1855,13 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
     }

     if (pushDown.nonEmpty) {
-      // Different from Project, Aggregate is not suitable for using With to push down, because
-      // propagate the attributes directly need add the groupingExpressions may cause regression.
-      // So Aggregate only need inline common expression from parent for original project
-      // inheritance.
+      // Unlike Project, Aggregate cannot use With to push down by propagating the attributes
+      // directly, since adding them to the groupingExpressions may cause regressions. So for
+      // Aggregate we only inline common expressions inherited from the parent Project and
+      // rewrite the originalAttribute of the pushed-down With.
       val newAggregateExpressions = aggregate.aggregateExpressions ++
         getWithAlias(pushDown.reduce(And)).map(replaceAliasButKeepName(_, aliasMap))
-      val replaced = removeOriginAttribute(rewriteCondition(pushDown.reduce(And), aliasMap))
+      val replaced = rewriteOriginalAttribute(rewriteCondition(pushDown.reduce(And), aliasMap))
       val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child),
         aggregateExpressions = newAggregateExpressions)
       // If there is no more filter to stay up, just eliminate the filter.
@@ -2020,9 +2020,10 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
     }
   }

-  private def removeOriginAttribute(expr: Expression): Expression = {
+  private def rewriteOriginalAttribute(expr: Expression): Expression = {
     expr.transform {
-      case ced: CommonExpressionDef => ced.copy(originalAttribute = None)
+      case ced @ CommonExpressionDef(_, _, Some(a)) =>
+        ced.copy(originalAttribute = Some(a.withExprId(NamedExpression.newExprId)))
     }
   }

@@ -2042,7 +2043,7 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
     val replaceWithMap = cond.references.toSeq
       .filter(attr => aliasMap.contains(attr))
      .map(attr => attr -> aliasMap(attr))
-      .filter(m => !CollapseProject.isCheap(m._2))
+      .filterNot(m => CollapseProject.isCheap(m._2))
     if (replaceWithMap.isEmpty) {
       cond
     } else {
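Note on the rename: removeOriginAttribute erased originalAttribute whenever a With was pushed below an Aggregate; rewriteOriginalAttribute keeps it but under a fresh ExprId, so the pushed-down definition cannot collide with the same-named attribute still produced above the Aggregate. A hypothetical before/after in plan notation (add#N stands for an attribute named "add" with ExprId N; this patch's three-field CommonExpressionDef is assumed):

    // Before: the original attribute was dropped, losing the user-facing name.
    //   CommonExpressionDef(a + a, id = 0, originalAttribute = None)
    // After: the name survives under a new ExprId (add#1 above, add#7 below),
    // so RewriteWithExpression can still emit "add" below the Aggregate without
    // clashing with the add#1 produced by the Aggregate itself.
    //   CommonExpressionDef(a + a, id = 0, originalAttribute = Some(add#7))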

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala (+1 −1)

@@ -148,7 +148,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
       refToExpr(id) = child
     } else if (originalAttr.nonEmpty &&
         inputPlans.head.output.contains(originalAttr.get.toAttribute)) {
-      // originAlias only exists in Project or Filter. If the child already contains this
+      // originAttr only exists in Project or Filter. If the child already contains this
       // attribute, extend it.
       refToExpr(id) = originalAttr.get.toAttribute
     } else {
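Illustration of the "extend it" branch, in plan notation (hypothetical, with add#7 as the original attribute): if a pushed-down definition carries originalAttribute = add#7 and the child plan already outputs add#7, the With reference is rewritten to that attribute directly instead of synthesizing a new _common_expr_N column:

    // Filter(With(ref0 > 10, def0 = (a + a) [originalAttribute = add#7]),
    //        Project(a, (a + a) AS add#7, child))
    // ==> Filter(add#7 > 10, Project(a, (a + a) AS add#7, child))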

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala (+7 −1)

@@ -210,7 +210,13 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
     case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _, _)) =>
       val newJoinType = buildNewJoinType(f, j)
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
-
+    case f @ Filter(condition,
+        p @ Project(_, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _, _))) =>
+      val aliasMap = getAliasMap(p)
+      val newFilter = f.copy(condition = replaceAlias(condition, aliasMap))
+      val newJoinType = buildNewJoinType(newFilter, j)
+      if (j.joinType == newJoinType) f
+      else Filter(condition, p.copy(child = j.copy(joinType = newJoinType)))
     case a @ Aggregate(_, _, Join(left, _, LeftOuter, _, _), _)
         if a.references.subsetOf(left.outputSet) && allDuplicateAgnostic(a) =>
       a.copy(child = left)
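The new case lets EliminateOuterJoin look through an intervening Project: replaceAlias first substitutes the filter's aliases back to the join's own attributes, so buildNewJoinType can test null-intolerance against the join sides. A hypothetical example in the Catalyst test DSL:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.LeftOuter
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val t1 = LocalRelation($"id".int, $"a".int)
    val t2 = LocalRelation($"id2".int, $"b".int)

    // After replaceAlias, $"bb" > 0 becomes $"b" > 0, which is null-intolerant on
    // t2's side, so the LeftOuter join can safely be rewritten to an Inner join.
    val query = t1
      .join(t2, LeftOuter, Some($"id" === $"id2"))
      .select($"a", $"b" as "bb")
      .where($"bb" > 0)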

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownOnePassSuite.scala (−4)

@@ -148,10 +148,6 @@ class FilterPushdownOnePassSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer =
       x.where($"b" > 0)
-        .select(($"a" + 1) as "a1", $"b")
-        .select(($"a1" + 1) as "a2", $"b")
-        .select(($"a2" + 1) as "a3", $"b")
-        .select(($"a3" + 1) as "a4", $"b")
         .select($"b").analyze

     comparePlans(optimized, correctAnswer)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala (+45 −48)

@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval

@@ -762,8 +761,9 @@ class FilterPushdownSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery.analyze)

     val correctAnswer = testRelation
-      .where($"a" + 1 < 3)
-      .select($"a", $"b")
+      .select($"a", $"b", $"c", ($"a" + 1) as "aa")
+      .where($"aa" < 3)
+      .select($"a", $"b", $"aa")
       .groupBy($"a")(($"a" + 1) as "aa", count($"b") as "c")
       .where($"c" === 2L || $"aa" > 4)
       .analyze
@@ -1524,50 +1524,47 @@
   }

   test("SPARK-50589: avoid extra expression duplication when push filter") {
-    withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
-      // through project
-      val originalQuery1 = testRelation
-        .select($"a" + $"b" as "add", $"a" - $"b" as "sub")
-        .where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
-      val correctAnswer1 = testRelation
-        .select($"a", $"b", $"c", $"a" + $"b" as "add", $"a" - $"b" as "sub")
-        .where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
-        .select($"add", $"sub")
-        .analyze
-      val optimized1 = Optimize.execute(originalQuery1.analyze)
-      comparePlans(optimized1, correctAnswer1)
-
-      // through aggregate
-      val originalQuery2 = testRelation
-        .groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
-        .where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
-      val optimized2 = Optimize.execute(originalQuery2.analyze)
-      val correctAnswer2 = testRelation
-        .select($"a", $"b", $"c", $"a" + $"a" as "_common_expr_0")
-        .where($"_common_expr_0" < 10 &&
-          $"_common_expr_0" + $"_common_expr_0" > 10 &&
-          abs($"a") > 5)
-        .select($"a", $"b", $"c")
-        .groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
-        .analyze
-      comparePlans(optimized2, correctAnswer2)
-    }
-    withSQLConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key -> "false") {
-      // partial push down
-      val originalQuery3 = testRelation
-        .groupBy($"a")($"a", count(1) as "ct")
-        .select($"a" + $"a" as "add", $"ct")
-        .where($"add" + $"add" > 10 && $"add" > $"ct")
-      val optimized3 = Optimize.execute(originalQuery3.analyze)
-      val correctAnswer3 = testRelation
-        .select($"a", $"b", $"c", $"a" + $"a" as "_common_expr_0")
-        .where($"_common_expr_0" + $"_common_expr_0" > 10)
-        .select($"a", $"b", $"c")
-        .groupBy($"a")($"a", count(1) as "ct", $"a" + $"a" as "add")
-        .where($"add" > $"ct")
-        .select($"add", $"ct")
-        .analyze
-      comparePlans(optimized3, correctAnswer3)
-    }
+    // through project
+    val originalQuery1 = testRelation
+      .select($"a" + $"b" as "add", $"a" - $"b" as "sub")
+      .where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
+    val correctAnswer1 = testRelation
+      .select($"a", $"b", $"c", $"a" + $"b" as "add", $"a" - $"b" as "sub")
+      .where($"add" < 10 && $"add" + $"add" > 10 && $"sub" > 0)
+      .select($"add", $"sub")
+      .analyze
+    val optimized1 = Optimize.execute(originalQuery1.analyze)
+    comparePlans(optimized1, correctAnswer1)
+
+    // through aggregate
+    val originalQuery2 = testRelation
+      .groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
+      .where($"add" < 10 && $"add" + $"add" > 10 && $"abs" > 5)
+    val optimized2 = Optimize.execute(originalQuery2.analyze)
+    val correctAnswer2 = testRelation
+      .select($"a", $"b", $"c", $"a" + $"a" as "add", abs($"a") as "abs")
+      .where($"add" < 10 &&
+        $"add" + $"add" > 10 &&
+        $"abs" > 5)
+      .select($"a", $"b", $"c", $"add", $"abs")
+      .groupBy($"a")($"a", $"a" + $"a" as "add", abs($"a") as "abs", count(1) as "ct")
+      .analyze
+    comparePlans(optimized2, correctAnswer2)
+
+    // partial push down
+    val originalQuery3 = testRelation
+      .groupBy($"a")($"a", count(1) as "ct")
+      .select($"a" + $"a" as "add", $"ct")
+      .where($"add" + $"add" > 10 && $"add" > $"ct")
+    val optimized3 = Optimize.execute(originalQuery3.analyze)
+    val correctAnswer3 = testRelation
+      .select($"a", $"b", $"c", $"a" + $"a" as "add")
+      .where($"add" + $"add" > 10)
+      .select($"a", $"b", $"c", $"add")
+      .groupBy($"a")($"a", count(1) as "ct", $"a" + $"a" as "add")
+      .where($"add" > $"ct")
+      .select($"add", $"ct")
+      .analyze
+    comparePlans(optimized3, correctAnswer3)
   }
 }
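Reviewer context for the new expectations: with the USE_COMMON_EXPR_ID_FOR_ALIAS override removed, the test now exercises the default path, where the pushed-down common expression keeps the query's own alias names (add, abs) instead of a synthesized _common_expr_0. In the partial-pushdown case, only the conjunct that can be evaluated below the Aggregate ($"add" + $"add" > 10, rewritten against the pushed-down definition of add) moves down, while $"add" > $"ct" stays above because ct is an aggregate result. A minimal sketch of that split, with replaceableRefs as a hypothetical stand-in for the attributes available below the Aggregate:

    // Deterministic conjuncts whose references are all available below the
    // Aggregate may be pushed; anything touching aggregate results stays up.
    val (pushDown, stayUp) = splitConjunctivePredicates(condition)
      .partition(c => c.deterministic && c.references.subsetOf(replaceableRefs))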
