This repository was archived by the owner on Jun 14, 2024. It is now read-only.

Commit 6aa3e13 (parent: 93a6efd)

Support filter with indexes on nested fields

File tree: 5 files changed, +379 −50 lines

src/main/scala/com/microsoft/hyperspace/index/rules/FilterIndexRule.scala (+44 −20)
@@ -17,18 +17,19 @@
 package com.microsoft.hyperspace.index.rules

 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.analysis.CleanupAliases
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, Resolver, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule

 import com.microsoft.hyperspace.{ActiveSparkSession, Hyperspace}
 import com.microsoft.hyperspace.actions.Constants
 import com.microsoft.hyperspace.index.IndexLogEntry
 import com.microsoft.hyperspace.index.rankers.FilterIndexRanker
+import com.microsoft.hyperspace.index.rules.PlanUtils._
 import com.microsoft.hyperspace.index.sources.FileBasedRelation
 import com.microsoft.hyperspace.telemetry.{AppInfo, HyperspaceEventLogging, HyperspaceIndexUsageEvent}
-import com.microsoft.hyperspace.util.{HyperspaceConf, ResolverUtils}
+import com.microsoft.hyperspace.util.{HyperspaceConf, ResolverUtils, SchemaUtils}

 /**
  * FilterIndex rule looks for opportunities in a logical plan to replace
@@ -53,7 +54,7 @@ object FilterIndexRule
       case ExtractFilterNode(originalPlan, filter, outputColumns, filterColumns) =>
         try {
           val candidateIndexes =
-            findCoveringIndexes(filter, outputColumns, filterColumns)
+            findCoveringIndexes(filter, outputColumns, filterColumns, plan)
           FilterIndexRanker.rank(spark, filter, candidateIndexes) match {
             case Some(index) =>
               // As FilterIndexRule is not intended to support bucketed scan, we set
@@ -99,7 +100,8 @@ object FilterIndexRule
   private def findCoveringIndexes(
       filter: Filter,
       outputColumns: Seq[String],
-      filterColumns: Seq[String]): Seq[IndexLogEntry] = {
+      filterColumns: Seq[String],
+      plan: LogicalPlan): Seq[IndexLogEntry] = {
     RuleUtils.getRelation(spark, filter) match {
       case Some(r) =>
         val indexManager = Hyperspace
@@ -111,20 +113,35 @@
         // See https://github.com/microsoft/hyperspace/issues/65
         val allIndexes = indexManager.getIndexes(Seq(Constants.States.ACTIVE))

-        val candidateIndexes = allIndexes.filter { index =>
-          indexCoversPlan(
-            outputColumns,
-            filterColumns,
-            index.indexedColumns,
-            index.includedColumns)
+        def resolveWithChildren(fieldName: String, plan: LogicalPlan, resolver: Resolver) = {
+          plan.resolveChildren(UnresolvedAttribute.parseAttributeName(fieldName), resolver)
         }

-        // Get candidate via file-level metadata validation. This is performed after pruning
-        // by column schema, as this might be expensive when there are numerous files in the
-        // relation or many indexes to be checked.
-        RuleUtils.getCandidateIndexes(spark, candidateIndexes, r)
-
-      case None => Nil // There is zero or more than one supported relations in Filter's sub-plan.
+        // Resolve output columns with default resolver method.
+        val resolvedOutputColumnsOpt =
+          ResolverUtils.resolve(spark, outputColumns, plan, resolveWithChildren, force = false)
+        // Resolve filter columns the same way.
+        val resolvedFilterColumnsOpt =
+          ResolverUtils.resolve(spark, filterColumns, plan, resolveWithChildren, force = false)
+
+        (resolvedOutputColumnsOpt, resolvedFilterColumnsOpt) match {
+          case (Some(resolvedOutputColumns), Some(resolvedFilterColumns)) =>
+            val candidateIndexes = allIndexes.filter { index =>
+              indexCoversPlan(
+                SchemaUtils.prefixNestedFieldNames(resolvedOutputColumns),
+                SchemaUtils.prefixNestedFieldNames(resolvedFilterColumns),
+                index.indexedColumns,
+                index.includedColumns)
+            }
+
+            // Get candidates via file-level metadata validation. This is performed after pruning
+            // by column schema, as this might be expensive when there are numerous files in the
+            // relation or many indexes to be checked.
+            RuleUtils.getCandidateIndexes(spark, candidateIndexes, r)
+
+          case _ => Nil
+        }
+      case _ => Nil // There is zero or more than one supported relation in Filter's sub-plan.
     }
   }
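The resolution step above normalizes nested field names from the query into the flattened, prefixed form stored in index metadata, so coverage can be checked with plain string comparison against index.indexedColumns and index.includedColumns. A minimal sketch of the idea, assuming SchemaUtils.prefixNestedFieldNames marks dotted names with a reserved prefix such as "__hs_nested." (the actual constant, SchemaUtils.NESTED_FIELD_PREFIX, is defined outside this diff):

// Hypothetical stand-in for SchemaUtils.prefixNestedFieldNames, for
// illustration only: dotted (nested) names get a reserved prefix,
// top-level names pass through unchanged.
def prefixNestedFieldNames(names: Seq[String]): Seq[String] =
  names.map(n => if (n.contains(".")) s"__hs_nested.$n" else n)

prefixNestedFieldNames(Seq("date", "nested.leaf.cnt"))
// => Seq("date", "__hs_nested.nested.leaf.cnt"), directly comparable to the
//    flattened column names an index records for nested fields.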
@@ -136,7 +153,6 @@ object FilterIndexRule
    * @param filterColumns List of columns in filter predicate.
    * @param indexedColumns List of indexed columns (e.g. from an index being checked)
    * @param includedColumns List of included columns (e.g. from an index being checked)
-   * @param fileFormat FileFormat for input relation in original logical plan.
    * @return 'true' if
    *         1. Index fully covers output and filter columns, and
    *         2. Filter predicate contains first column in index's 'indexed' columns.
@@ -168,9 +184,17 @@ object ExtractFilterNode
       val projectColumnNames = CleanupAliases(project)
         .asInstanceOf[Project]
         .projectList
-        .map(_.references.map(_.asInstanceOf[AttributeReference].name))
+        .map(extractNamesFromExpression)
         .flatMap(_.toSeq)
-      val filterColumnNames = condition.references.map(_.name).toSeq
+      val filterColumnNames = extractNamesFromExpression(condition).toSeq
+        .sortBy(-_.length)
+        .foldLeft(Seq.empty[String]) { (acc, e) =>
+          if (!acc.exists(i => i.startsWith(e))) {
+            acc :+ e
+          } else {
+            acc
+          }
+        }

       Some(project, filter, projectColumnNames, filterColumnNames)
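The fold above de-duplicates the extracted filter names: names are visited longest first, and a name is dropped when an already-kept longer name starts with it, so a bare struct root does not shadow its nested fields. A quick self-contained check (hypothetical names):

val names = Set("nested.leaf.cnt", "nested", "date").toSeq
  .sortBy(-_.length)
  .foldLeft(Seq.empty[String]) { (acc, e) =>
    if (!acc.exists(i => i.startsWith(e))) acc :+ e else acc
  }
// names == Seq("nested.leaf.cnt", "date"); the parent "nested" is subsumed
// by the longer nested field name.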
src/main/scala/com/microsoft/hyperspace/index/rules/PlanUtils.scala (new file, +165)

@@ -0,0 +1,165 @@
+/*
+ * Copyright (2020) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.rules
+
+import scala.util.Try
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetStructField}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+object PlanUtils {
+
+  /**
+   * Extracts field names from a Spark Catalyst [[Expression]].
+   *
+   * @param exp The Spark Catalyst expression from which to extract names.
+   * @return A set of distinct field names.
+   */
+  def extractNamesFromExpression(exp: Expression): Set[String] = {
+    exp match {
+      case AttributeReference(name, _, _, _) =>
+        Set(name)
+      case Alias(child, _) =>
+        extractNamesFromExpression(child)
+      case otherExp =>
+        otherExp.containsChild.map {
+          case g: GetStructField =>
+            getChildNameFromStruct(g)
+          case e: Expression =>
+            extractNamesFromExpression(e).filter(_.nonEmpty).mkString(".")
+          case _ => ""
+        }
+    }
+  }
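For a bare column reference the method returns the column's name, and aliases are unwrapped to their child; a hand-built check (hypothetical column name):

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.IntegerType

val cnt = AttributeReference("cnt", IntegerType)()
PlanUtils.extractNamesFromExpression(cnt)               // Set("cnt")
PlanUtils.extractNamesFromExpression(Alias(cnt, "c")()) // Set("cnt"), alias unwrapped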
+
+  /**
+   * Given a [[GetStructField]] expression for a nested field (i.e. a struct),
+   * extracts the full `.` (dot) separated field name.
+   *
+   * @param field The [[GetStructField]] field from which we want to extract
+   *              the name.
+   * @return A field name, `.` (dot) separated if nested.
+   */
+  def getChildNameFromStruct(field: GetStructField): String = {
+    field.child match {
+      case f: GetStructField =>
+        s"${getChildNameFromStruct(f)}.${field.name.get}"
+      case a: AttributeReference =>
+        s"${a.name}.${field.name.get}"
+      case _ =>
+        field.name.get
+    }
+  }
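Worked example: for a field nested two structs deep, the recursion rebuilds the dotted path from the innermost [[GetStructField]] back to the root attribute (hand-built, hypothetical schema nested: STRUCT<leaf: STRUCT<cnt: INT>>):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
import org.apache.spark.sql.types._

val leafType = StructType(Seq(StructField("cnt", IntegerType)))
val nested = AttributeReference("nested", StructType(Seq(StructField("leaf", leafType))))()
val cnt = GetStructField(GetStructField(nested, 0, Some("leaf")), 0, Some("cnt"))

PlanUtils.getChildNameFromStruct(cnt) // "nested.leaf.cnt"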
+
+  /**
+   * Given a Spark Catalyst [[Expression]] and a field name, extracts the
+   * parent expression and the expression that contains the field name.
+   *
+   * @param exp The Spark Catalyst [[Expression]] to extract from.
+   * @param name The field name to search for.
+   * @return A tuple with the parent expression and the leaf expression that
+   *         contains the given name.
+   */
+  def extractSearchQuery(exp: Expression, name: String): (Expression, Expression) = {
+    // Escape the dot: String.split takes a regex, and an unescaped "." would
+    // match every character and yield an empty array.
+    val splits = name.split("\\.")
+    val expFound = exp.find {
+      case a: AttributeReference if splits.forall(s => a.name.contains(s)) => true
+      case f: GetStructField if splits.forall(s => f.toString().contains(s)) => true
+      case _ => false
+    }.get
+    val parent = exp.find {
+      case e: Expression if e.containsChild.contains(expFound) => true
+      case _ => false
+    }.get
+    (parent, expFound)
+  }
+
+  /**
+   * Given a parent Spark Catalyst [[Expression]], a needle [[Expression]] and
+   * a replacement [[Expression]], replaces the needle with the replacement
+   * inside the parent expression.
+   *
+   * @param parent The parent Spark Catalyst [[Expression]] in which to replace.
+   * @param needle The Spark Catalyst [[Expression]] needle to search for.
+   * @param repl The replacement Spark Catalyst [[Expression]].
+   * @return A new Spark Catalyst [[Expression]].
+   */
+  def replaceInSearchQuery(
+      parent: Expression,
+      needle: Expression,
+      repl: Expression): Expression = {
+    parent.mapChildren { c =>
+      if (c == needle) {
+        repl
+      } else {
+        c
+      }
+    }
+  }
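A hand-built check of the two helpers above, again assuming a schema nested: STRUCT<leaf: STRUCT<cnt: INT>> and a hypothetical flattened index column name:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val leafType = StructType(Seq(StructField("cnt", IntegerType)))
val nested = AttributeReference("nested", StructType(Seq(StructField("leaf", leafType))))()
val cnt = GetStructField(GetStructField(nested, 0, Some("leaf")), 0, Some("cnt"))
val pred = EqualTo(cnt, Literal(1)) // nested.leaf.cnt = 1

val (parent, found) = PlanUtils.extractSearchQuery(pred, "nested.leaf.cnt")
// parent == pred, found == cnt

val flat = AttributeReference("__hs_nested.nested.leaf.cnt", IntegerType)()
PlanUtils.replaceInSearchQuery(parent, found, flat)
// => EqualTo(flat, Literal(1)): the predicate now targets the flattened
//    index column instead of the struct access.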
+
+  /**
+   * Given a Spark Catalyst [[Expression]] and a field name, extracts the
+   * [[AttributeReference]] for that field name.
+   *
+   * @param exp The Spark Catalyst [[Expression]] to extract from.
+   * @param name The field name for which to extract the attribute reference.
+   * @return A Spark Catalyst [[AttributeReference]] pointing to the field name.
+   */
+  def extractAttributeRef(exp: Expression, name: String): AttributeReference = {
+    val splits = name.split("\\.") // escaped: split on a literal dot
+    val elem = exp.find {
+      case a: AttributeReference if splits.contains(a.name) => true
+      case _ => false
+    }
+    elem.get.asInstanceOf[AttributeReference]
+  }
+
+  /**
+   * Given a Spark Catalyst [[Expression]] and a field name, extracts the
+   * type of the field as a Spark SQL [[DataType]].
+   *
+   * @param exp The Spark Catalyst [[Expression]] from which to extract the type.
+   * @param name The field name for which we need to get the type.
+   * @return The Spark SQL [[DataType]] of the given field name.
+   */
+  def extractTypeFromExpression(exp: Expression, name: String): DataType = {
+    val splits = name.split("\\.") // escaped: split on a literal dot
+    val elem = exp.flatMap {
+      case a: AttributeReference =>
+        if (splits.forall(s => a.name == s)) {
+          Some((name, a.dataType))
+        } else {
+          Try({
+            val h :: t = splits.toList
+            if (a.name == h && a.dataType.isInstanceOf[StructType]) {
+              val currentDataType = a.dataType.asInstanceOf[StructType]
+              val foldedFields = t.foldLeft(Seq.empty[(String, DataType)]) { (acc, i) =>
+                val idx = currentDataType.indexWhere(_.name.equalsIgnoreCase(i))
+                acc :+ (i, currentDataType(idx).dataType)
+              }
+              Some(foldedFields.last)
+            } else {
+              None
+            }
+          }).getOrElse(None)
+        }
+      case f: GetStructField if splits.forall(s => f.toString().contains(s)) =>
+        Some((name, f.dataType))
+      case _ => None
+    }
+    elem.find(e => e._1 == name || e._1 == splits.last).get._2
+  }
+}
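Worked example for the type extraction, reusing a hand-built nested.leaf.cnt expression (hypothetical schema); the [[GetStructField]] branch resolves the leaf type directly:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
import org.apache.spark.sql.types._

val leafType = StructType(Seq(StructField("cnt", IntegerType)))
val nested = AttributeReference("nested", StructType(Seq(StructField("leaf", leafType))))()
val cnt = GetStructField(GetStructField(nested, 0, Some("leaf")), 0, Some("cnt"))

PlanUtils.extractTypeFromExpression(cnt, "nested.leaf.cnt") // IntegerType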

src/main/scala/com/microsoft/hyperspace/index/rules/RuleUtils.scala (+78 −5)
@@ -21,7 +21,7 @@ import scala.collection.mutable
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, In, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExprId, GetStructField, In, Literal, Not}
 import org.apache.spark.sql.catalyst.optimizer.OptimizeIn
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.datasources._
@@ -32,8 +32,9 @@ import com.microsoft.hyperspace.Hyperspace
 import com.microsoft.hyperspace.index._
 import com.microsoft.hyperspace.index.IndexLogEntryTags.{HYBRIDSCAN_RELATED_CONFIGS, IS_HYBRIDSCAN_CANDIDATE}
 import com.microsoft.hyperspace.index.plans.logical.{BucketUnion, IndexHadoopFsRelation}
+import com.microsoft.hyperspace.index.rules.PlanUtils._
 import com.microsoft.hyperspace.index.sources.FileBasedRelation
-import com.microsoft.hyperspace.util.HyperspaceConf
+import com.microsoft.hyperspace.util.{HyperspaceConf, ResolverUtils, SchemaUtils}

 object RuleUtils {
@@ -278,10 +279,30 @@
           new ParquetFileFormat,
           Map(IndexConstants.INDEX_RELATION_IDENTIFIER))(spark, index)

-        val updatedOutput = relation.plan.output
-          .filter(attr => indexFsRelation.schema.fieldNames.contains(attr.name))
-          .map(_.asInstanceOf[AttributeReference])
+        val resolvedFields =
+          ResolverUtils.resolve(spark, index.indexedColumns ++ index.includedColumns, relation.plan)
+        val updatedOutput =
+          if (resolvedFields.isDefined && resolvedFields.get.exists(_._2)) {
+            indexFsRelation.schema.flatMap { s =>
+              val exprId = getFieldPosition(index, s.name)
+              relation.plan.output.find(a => s.name.contains(a.name)).map { a =>
+                AttributeReference(s.name, s.dataType, a.nullable, a.metadata)(
+                  ExprId(exprId),
+                  a.qualifier)
+              }
+            }
+          } else {
+            relation.plan.output
+              .filter(attr => indexFsRelation.schema.fieldNames.contains(attr.name))
+              .map(_.asInstanceOf[AttributeReference])
+          }
         relation.createLogicalRelation(indexFsRelation, updatedOutput)
+
+      case p: Project if provider.isSupportedProject(p) =>
+        transformProject(p, index)
+
+      case f: Filter if provider.isSupportedFilter(f) =>
+        transformFilter(f, index)
     }
   }
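The nested-field branch above rebuilds the relation output from the index schema: each flattened index field is matched to the source attribute whose name it contains, and the field's position in the index schema is reused as a deterministic ExprId. The matching itself is plain string containment, sketched here with hypothetical names:

val indexSchemaFields = Seq("date", "__hs_nested.nested.leaf.cnt")
val relationOutput = Seq("date", "nested")
val matched = indexSchemaFields.zipWithIndex.flatMap { case (s, pos) =>
  relationOutput.find(a => s.contains(a)).map(a => (s, a, pos))
}
// matched == Seq(("date", "date", 0), ("__hs_nested.nested.leaf.cnt", "nested", 1))
// pos becomes the ExprId of the rebuilt AttributeReference.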
@@ -568,4 +589,56 @@ object RuleUtils {
     assert(shuffleInjected)
     shuffled
   }
+
+  private def transformProject(project: Project, index: IndexLogEntry): Project = {
+    val projectedFields = project.projectList.map { exp =>
+      val fieldName = extractNamesFromExpression(exp).head
+      val escapedFieldName = SchemaUtils.prefixNestedFieldName(fieldName)
+      val attr = extractAttributeRef(exp, fieldName)
+      val fieldType = extractTypeFromExpression(exp, fieldName)
+      val exprId = getFieldPosition(index, escapedFieldName)
+      attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)(
+        ExprId(exprId),
+        attr.qualifier)
+    }
+    project.copy(projectList = projectedFields)
+  }
+
+  private def transformFilter(filter: Filter, index: IndexLogEntry): Filter = {
+    val fieldNames = extractNamesFromExpression(filter.condition)
+    var mutableFilter = filter
+    fieldNames.foreach { fieldName =>
+      val escapedFieldName = SchemaUtils.prefixNestedFieldName(fieldName)
+      val nestedFields = getNestedFields(index)
+      if (nestedFields.nonEmpty &&
+          nestedFields.exists(i => i.equalsIgnoreCase(escapedFieldName))) {
+        val (parentExpression, exp) =
+          extractSearchQuery(filter.condition, fieldName)
+        val fieldType = extractTypeFromExpression(exp, fieldName)
+        val attr = extractAttributeRef(exp, fieldName)
+        val exprId = getFieldPosition(index, escapedFieldName)
+        val newAttr = attr.copy(escapedFieldName, fieldType, attr.nullable, attr.metadata)(
+          ExprId(exprId),
+          attr.qualifier)
+        val newExp = exp match {
+          case _: GetStructField => newAttr
+          case other: Expression => other
+        }
+        val newParentExpression =
+          replaceInSearchQuery(parentExpression, exp, newExp)
+        mutableFilter = filter.copy(condition = newParentExpression)
+      } else {
+        filter
+      }
+    }
+    mutableFilter
+  }
+
+  private def getNestedFields(index: IndexLogEntry): Seq[String] = {
+    index.schema.fieldNames.filter(_.startsWith(SchemaUtils.NESTED_FIELD_PREFIX))
+  }
+
+  private def getFieldPosition(index: IndexLogEntry, fieldName: String): Int = {
+    index.schema.fieldNames.indexWhere(_.equalsIgnoreCase(fieldName))
+  }
 }

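Taken together, the changes let a filter over a nested field be answered from an index. An end-to-end sketch using the public Hyperspace API (path and schema are hypothetical; the nested-column support in the rules is what this commit adds):

import com.microsoft.hyperspace._
import com.microsoft.hyperspace.index.IndexConfig

val hs = new Hyperspace(spark)
// Schema: date STRING, nested STRUCT<leaf: STRUCT<cnt: INT>>
val df = spark.read.parquet("/data/table")

// Index a nested field; it is stored flattened under the nested-field prefix.
hs.createIndex(df, IndexConfig("filterIdx", Seq("nested.leaf.cnt"), Seq("date")))

spark.enableHyperspace()
val q = df.filter(df("nested.leaf.cnt") === 1).select("date")
// FilterIndexRule can now rewrite q's scan to read "filterIdx" instead of
// the original files.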