Skip to content
This repository was archived by the owner on Jun 14, 2024. It is now read-only.

Commit 19d26f8

Browse files
Integrate review feedback (1)
1 parent ef2b45e commit 19d26f8

File tree

5 files changed

+275
-163
lines changed

5 files changed

+275
-163
lines changed

src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala

+4-99
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,11 @@ object ExtractFilterNode {
171171
val projectColumnNames = CleanupAliases(project)
172172
.asInstanceOf[Project]
173173
.projectList
174-
.map(extractNamesFromExpression)
174+
.map(PlanUtils.extractNamesFromExpression)
175175
.flatMap(_.toSeq)
176-
val filterColumnNames = extractNamesFromExpression(condition).toSeq
176+
val filterColumnNames = PlanUtils
177+
.extractNamesFromExpression(condition)
178+
.toSeq
177179
.sortBy(-_.length)
178180
.foldLeft(Seq.empty[String]) { (acc, e) =>
179181
if (!acc.exists(i => i.startsWith(e))) {
@@ -194,103 +196,6 @@ object ExtractFilterNode {
194196

195197
case _ => None // plan does not match with any of filter index rule patterns
196198
}
197-
198-
def extractNamesFromExpression(exp: Expression): Set[String] = {
199-
exp match {
200-
case AttributeReference(name, _, _, _) =>
201-
Set(s"$name")
202-
case otherExp =>
203-
otherExp.containsChild.flatMap {
204-
case g: GetStructField =>
205-
Set(s"${getChildNameFromStruct(g)}")
206-
case e: Expression =>
207-
extractNamesFromExpression(e).filter(_.nonEmpty)
208-
case _ => Set.empty[String]
209-
}
210-
}
211-
}
212-
213-
def getChildNameFromStruct(field: GetStructField): String = {
214-
field.child match {
215-
case f: GetStructField =>
216-
s"${getChildNameFromStruct(f)}.${field.name.get}"
217-
case a: AttributeReference =>
218-
s"${a.name}.${field.name.get}"
219-
case _ =>
220-
s"${field.name.get}"
221-
}
222-
}
223-
224-
def replaceInSearchQuery(
225-
parent: Expression,
226-
needle: Expression,
227-
repl: Expression): Expression = {
228-
parent.mapChildren { c =>
229-
if (c == needle) {
230-
repl
231-
} else {
232-
c
233-
}
234-
}
235-
}
236-
237-
def extractAttributeRef(exp: Expression, name: String): AttributeReference = {
238-
val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
239-
val elem = exp.find {
240-
case a: AttributeReference if splits.contains(a.name) => true
241-
case _ => false
242-
}
243-
elem.get.asInstanceOf[AttributeReference]
244-
}
245-
246-
def extractTypeFromExpression(exp: Expression, name: String): DataType = {
247-
val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
248-
val elem = exp.flatMap {
249-
case attrRef: AttributeReference =>
250-
if (splits.forall(s => attrRef.name == s)) {
251-
Some((name, attrRef.dataType))
252-
} else {
253-
Try({
254-
val h :: t = splits.toList
255-
if (attrRef.name == h && attrRef.dataType.isInstanceOf[StructType]) {
256-
val currentDataType = attrRef.dataType.asInstanceOf[StructType]
257-
var localDT = currentDataType
258-
val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) =>
259-
val collected = localDT.collect {
260-
case dt if dt.name == i =>
261-
dt.dataType match {
262-
case st: StructType =>
263-
localDT = st
264-
case _ =>
265-
}
266-
(i, dt.dataType)
267-
}
268-
acc ++ collected
269-
}
270-
Some(foldedFields.last)
271-
} else {
272-
None
273-
}
274-
}).getOrElse(None)
275-
}
276-
case f: GetStructField if splits.forall(s => f.toString().contains(s)) =>
277-
Some((name, f.dataType))
278-
case _ => None
279-
}
280-
elem.find(e => e._1 == name || e._1 == splits.last).get._2
281-
}
282-
283-
def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] = {
284-
plan
285-
.collect {
286-
case Project(projectList, _) =>
287-
projectList.collect {
288-
case a @ Alias(child, name) =>
289-
(name, a.toAttribute, child)
290-
}
291-
}
292-
.flatten
293-
}
294199
}
295200

296201
object ExtractRelation extends ActiveSparkSession {

src/main/scala/com/microsoft/hyperspace/index/rules/JoinIndexRule.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -451,13 +451,13 @@ object JoinIndexRule
451451
val fields = conditionFieldsToRelationFields(project.projectList).values
452452
fields.flatMap {
453453
case g: GetStructField =>
454-
Seq(ExtractFilterNode.getChildNameFromStruct(g))
454+
Seq(PlanUtils.getChildNameFromStruct(g))
455455
case otherFieldType =>
456-
ExtractFilterNode.extractNamesFromExpression(otherFieldType).toSeq
456+
PlanUtils.extractNamesFromExpression(otherFieldType).toSeq
457457
}
458458
case filter: Filter =>
459459
var acc = Seq.empty[String]
460-
val fls = ExtractFilterNode
460+
val fls = PlanUtils
461461
.extractNamesFromExpression(filter.condition)
462462
.toSeq
463463
.distinct
@@ -481,7 +481,7 @@ object JoinIndexRule
481481
plan.outputSet.map { i =>
482482
val attr = extractFieldFromProjection(i, projectionFields)
483483
val opt = attr.map { e =>
484-
ExtractFilterNode.getChildNameFromStruct(e.asInstanceOf[GetStructField])
484+
PlanUtils.getChildNameFromStruct(e.asInstanceOf[GetStructField])
485485
}
486486
opt.getOrElse(i.name)
487487
}
@@ -517,15 +517,15 @@ object JoinIndexRule
517517
val attrLeftName = if (lp.nonEmpty) {
518518
Try {
519519
val attrLeft = extractFieldFromProjection(attr1, lp).get
520-
ExtractFilterNode.getChildNameFromStruct(attrLeft.asInstanceOf[GetStructField])
520+
PlanUtils.getChildNameFromStruct(attrLeft.asInstanceOf[GetStructField])
521521
}.getOrElse(attr1.name)
522522
} else {
523523
attr1.name
524524
}
525525
val attrRightName = if (rp.nonEmpty) {
526526
Try {
527527
val attrRight = extractFieldFromProjection(attr2, rp).get
528-
ExtractFilterNode.getChildNameFromStruct(attrRight.asInstanceOf[GetStructField])
528+
PlanUtils.getChildNameFromStruct(attrRight.asInstanceOf[GetStructField])
529529
}.getOrElse(attr2.name)
530530
} else {
531531
attr2.name
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
* Copyright (2020) The Hyperspace Project Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.microsoft.hyperspace.index.rules
18+
19+
import scala.util.Try
20+
21+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GetStructField}
22+
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
23+
import org.apache.spark.sql.types.{DataType, StructType}
24+
25+
import com.microsoft.hyperspace.util.SchemaUtils
26+
27+
object PlanUtils {
28+
29+
/**
30+
* Returns true if the given project is a supported project, i.e. `SchemaUtils.hasNestedFields`
31+
* reports nested fields in its project list or in the condition of a Filter in its child plan.
32+
*
33+
* @param project Project to check if it's supported.
34+
* @return True if the given project is a supported project.
35+
*/
36+
def isSupportedProject(project: Project): Boolean = {
37+
val containsNestedFields =
38+
SchemaUtils.hasNestedFields(project.projectList.flatMap(extractNamesFromExpression))
39+
var containsNestedChildren = false
40+
project.child.foreach {
41+
case f: Filter =>
42+
containsNestedChildren = containsNestedChildren || {
43+
SchemaUtils.hasNestedFields(
44+
SchemaUtils.unescapeFieldNames(extractNamesFromExpression(f.condition).toSeq))
45+
}
46+
case _ =>
47+
}
48+
containsNestedFields || containsNestedChildren
49+
}
50+
51+
/**
52+
* Returns true if the given filter is a supported filter, i.e. `SchemaUtils.hasNestedFields`
53+
* reports nested fields among the field names extracted from its condition.
54+
*
55+
* @param filter Filter to check if it's supported.
56+
* @return True if the given filter is a supported filter.
57+
*/
58+
def isSupportedFilter(filter: Filter): Boolean = {
59+
val containsNestedFields =
60+
SchemaUtils.hasNestedFields(extractNamesFromExpression(filter.condition).toSeq)
61+
containsNestedFields
62+
}
63+
64+
/**
65+
* Given an expression it extracts all the field names from it.
66+
*
67+
* @param exp Expression to extract field names from
68+
* @return A set of distinct strings representing the field names
69+
* (ie: `Set(nested.field.id, nested.field.other)`)
70+
*/
71+
def extractNamesFromExpression(exp: Expression): Set[String] = {
72+
exp match {
73+
case AttributeReference(name, _, _, _) =>
74+
Set(s"$name")
75+
case otherExp =>
76+
otherExp.containsChild.flatMap {
77+
case g: GetStructField =>
78+
Set(s"${getChildNameFromStruct(g)}")
79+
case e: Expression =>
80+
extractNamesFromExpression(e).filter(_.nonEmpty)
81+
case _ => Set.empty[String]
82+
}
83+
}
84+
}
85+
86+
/**
87+
* Given a nested field this method extracts the full name out of it.
88+
*
89+
* @param field The field from which to get the name
90+
* @return The name of the field (ie: `nested.field.id`)
91+
*/
92+
def getChildNameFromStruct(field: GetStructField): String = {
93+
field.child match {
94+
case f: GetStructField =>
95+
s"${getChildNameFromStruct(f)}.${field.name.get}"
96+
case a: AttributeReference =>
97+
s"${a.name}.${field.name.get}"
98+
case _ =>
99+
s"${field.name.get}"
100+
}
101+
}
102+
103+
/**
104+
* Given an expression it extracts the attribute reference by field name.
105+
*
106+
* @param exp The expression in which to look for the attribute reference
107+
* @param name The name of the field to look for
108+
* @return The attribute reference for that field name
109+
*/
110+
def extractAttributeRef(exp: Expression, name: String): AttributeReference = {
111+
val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
112+
val elem = exp.find {
113+
case a: AttributeReference if splits.contains(a.name) => true
114+
case _ => false
115+
}
116+
elem.get.asInstanceOf[AttributeReference]
117+
}
118+
119+
/**
120+
* Given an expression, it extracts the type of the field by field name.
121+
*
122+
* @param exp The expression from which to extract the type
123+
* @param name The name of the field to look for
124+
* @return The type of the field as [[DataType]]
125+
*/
126+
def extractTypeFromExpression(exp: Expression, name: String): DataType = {
127+
val splits = name.split(SchemaUtils.NESTED_FIELD_NEEDLE_REGEX)
128+
val elem = exp.flatMap {
129+
case attrRef: AttributeReference =>
130+
if (splits.forall(s => attrRef.name == s)) {
131+
Some((name, attrRef.dataType))
132+
} else {
133+
Try({
134+
val h :: t = splits.toList
135+
if (attrRef.name == h && attrRef.dataType.isInstanceOf[StructType]) {
136+
val currentDataType = attrRef.dataType.asInstanceOf[StructType]
137+
var localDT = currentDataType
138+
val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) =>
139+
val collected = localDT.collect {
140+
case dt if dt.name == i =>
141+
dt.dataType match {
142+
case st: StructType =>
143+
localDT = st
144+
case _ =>
145+
}
146+
(i, dt.dataType)
147+
}
148+
acc ++ collected
149+
}
150+
Some(foldedFields.last)
151+
} else {
152+
None
153+
}
154+
}).getOrElse(None)
155+
}
156+
case f: GetStructField if splits.forall(s => f.toString().contains(s)) =>
157+
Some((name, f.dataType))
158+
case _ => None
159+
}
160+
elem.find(e => e._1 == name || e._1 == splits.last).get._2
161+
}
162+
163+
/**
164+
* Given a logical plan the method collects all aliases in the plan.
165+
* For example, given this projection
166+
* `Project [nested#548.leaf.cnt AS cnt#659, Date#543, nested#548.leaf.id AS id#660]`
167+
* the result will be:
168+
* {{{
169+
* Seq(
170+
* ("cnt", cnt#659, nested#548.leaf.cnt),
171+
* ("id", id#660, nested#548.leaf.id)
172+
* )
173+
* }}}
174+
*
175+
* @param plan The plan from which to collect the aliases
176+
* @return A collection of:
177+
* - a string representing the alias name
178+
* - the attribute the alias transforms to
179+
* - the expression from which this alias comes
180+
*/
181+
def collectAliases(plan: LogicalPlan): Seq[(String, Attribute, Expression)] = {
182+
plan.collect {
183+
case Project(projectList, _) =>
184+
projectList.collect {
185+
case a @ Alias(child, name) =>
186+
(name, a.toAttribute, child)
187+
}
188+
}.flatten
189+
}
190+
}

0 commit comments

Comments
 (0)