Skip to content

Commit

Permalink
Alias could be wrapped with backticks in RENAME and other commands (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
LantaoJin authored Feb 25, 2025
1 parent 7c87378 commit b9a0dc5
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,67 @@ class FlintSparkPPLRenameITSuite
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test multiple renamed fields with backticks alias") {
val frame = sql(s"""
| source = $testTable | rename name as `renamed_name`, country as `renamed_country` | fields `renamed_name`, `age`, `renamed_country`
| """.stripMargin)

val expectedResults: Array[Row] =
Array(
Row("Jake", 70, "USA"),
Row("Hello", 30, "USA"),
Row("John", 25, "Canada"),
Row("Jane", 20, "Canada"))
assertSameRows(expectedResults, frame)

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val fieldsProjectList = Seq(
UnresolvedAttribute("renamed_name"),
UnresolvedAttribute("age"),
UnresolvedAttribute("renamed_country"))
val renameProjectList =
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("name"), "renamed_name")(),
Alias(UnresolvedAttribute("country"), "renamed_country")())
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")),
innerProject)
val expectedPlan = Project(fieldsProjectList, planDropColumn)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test renamed field with backticks alias used in aggregation") {
val frame = sql(s"""
| source = $testTable | rename age as `user_age` | stats avg(`user_age`) by country
| """.stripMargin)

val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
assertSameRows(expectedResults, frame)

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val renameProjectList =
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")())
val aggregateExpressions =
Seq(
Alias(
UnresolvedFunction(
Seq("AVG"),
Seq(UnresolvedAttribute("user_age")),
isDistinct = false),
"avg(`user_age`)")(),
Alias(UnresolvedAttribute("country"), "country")())
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
val aggregatePlan = Aggregate(
Seq(Alias(UnresolvedAttribute("country"), "country")()),
aggregateExpressions,
planDropColumn)
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ joinType
;

sideAlias
: (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)?
: (LEFT EQUAL leftAlias = qualifiedName)? COMMA? (RIGHT EQUAL rightAlias = qualifiedName)?
;

joinCriteria
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct
joinType = Join.JoinType.CROSS;
}
Join.JoinHint joinHint = getJoinHint(ctx.joinHintList());
Optional<String> leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty();
Optional<String> leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()) : Optional.empty();
Optional<String> rightAlias = Optional.empty();
if (ctx.tableOrSubqueryClause().alias != null) {
rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText());
rightAlias = Optional.of(internalVisitExpression(ctx.tableOrSubqueryClause().alias).toString());
}
if (ctx.sideAlias().rightAlias != null) {
rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText());
rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString());
}

UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause());
Expand Down Expand Up @@ -248,7 +248,7 @@ public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContex
.map(
ct ->
new Alias(
ct.renamedField.getText(),
((Field) internalVisitExpression(ct.renamedField)).getField().toString(),
internalVisitExpression(ct.orignalField)))
.collect(Collectors.toList()));
}
Expand All @@ -262,7 +262,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
String name =
aggCtx.alias == null
? getTextInQuery(aggCtx)
: aggCtx.alias.getText();
: ((Field) internalVisitExpression(aggCtx.alias)).getField().toString();
Alias alias = new Alias(name, aggExpression);
aggListBuilder.add(alias);
}
Expand Down Expand Up @@ -442,7 +442,7 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse
throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1");
}
Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field);
String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText();
String alias = ctx.alias == null? dataField.getField().toString() + "_trendline" : internalVisitExpression(ctx.alias).toString();
String computationType = ctx.trendlineType().getText();
return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase()));
}
Expand Down Expand Up @@ -537,7 +537,7 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) {
if (ctx.subSearch() != null) {
return ctx.alias != null
? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch()))
? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), visitSubSearch(ctx.subSearch()))
: visitSubSearch(ctx.subSearch());
} else {
return visitTableSourceClause(ctx.tableSourceClause());
Expand All @@ -547,7 +547,7 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq
@Override
public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) {
Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()));
return ctx.alias != null ? new SubqueryAlias(ctx.alias.getText(), relation) : relation;
return ctx.alias != null ? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), relation) : relation;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1105,4 +1105,14 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test average price with backticks alias") {
val expectedPlan = planTransformer.visit(
plan(pplParser, "source = table | stats avg(price) as avg_price"),
new CatalystPlanContext)
val logPlan = planTransformer.visit(
plan(pplParser, "source = table | stats avg(`price`) as `avg_price`"),
new CatalystPlanContext)
comparePlans(expectedPlan, logPlan, false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,29 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
comparePlans(expectedPlan, logPlan, false)
}

test("Search multiple tables - with backticks table alias") {
val expectedPlan =
planTransformer.visit(
plan(
pplParser,
"""
| source=table1, table2, table3 as t
| | where t.name = 'Molly'
|""".stripMargin),
new CatalystPlanContext)
val logPlan =
planTransformer.visit(
plan(
pplParser,
"""
| source=table1, table2, table3 as `t`
| | where `t`.`name` = 'Molly'
|""".stripMargin),
new CatalystPlanContext)

comparePlans(expectedPlan, logPlan, false)
}

test("test fields + field list") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,4 +845,64 @@ class PPLLogicalPlanJoinTranslatorTestSuite
Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1)
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}

test("test multiple joins with table and subquery backticks alias") {
val originPlan = plan(
pplParser,
s"""
| source = table1 as t1
| | JOIN left = l right = r ON t1.id = t2.id
| [
| source = table2 as t2
| ]
| | JOIN left = l right = r ON t2.id = t3.id
| [
| source = table3 as t3
| ]
| | JOIN left = l right = r ON t3.id = t4.id
| [
| source = table4 as t4
| ]
| """.stripMargin)
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
val logPlan = plan(
pplParser,
s"""
| source = table1 as `t1`
| | JOIN left = `l` right = `r` ON `t1`.`id` = `t2`.`id`
| [
| source = table2 as `t2`
| ]
| | JOIN left = `l` right = `r` ON `t2`.`id` = `t3`.`id`
| [
| source = table3 as `t3`
| ]
| | JOIN left = `l` right = `r` ON `t3`.`id` = `t4`.`id`
| [
| source = table4 as `t4`
| ]
| """.stripMargin)
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}

test("test complex backticks subquery alias") {
val originPlan = plan(
pplParser,
s"""
| source = $testTable1
| | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt
| | fields t1.name, t2.name
| """.stripMargin)
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
val logPlan = plan(
pplParser,
s"""
| source = $testTable1
| | JOIN left = `t1` right = `t2` ON `t1`.`name` = `t2`.`name` [ source = $testTable2 as `ttt` ] as `tt`
| | fields `t1`.`name`, `t2`.`name`
| """.stripMargin)
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,16 @@ class PPLLogicalPlanRenameTranslatorTestSuite
Project(seq(UnresolvedAttribute("eval_rand")), planDropColumn)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test rename with backticks alias") {
val expectedPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c"),
new CatalystPlanContext)
val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename `a` as `r_a`, `b` as `r_b` | fields `c`"),
new CatalystPlanContext)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline with sort and backticks alias") {
val expectedPlan =
planTransformer.visit(
plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"),
new CatalystPlanContext)
val logPlan =
planTransformer.visit(
plan(pplParser, "source=relation | trendline sort - `age` sma(3, `age`) as `age_sma`"),
new CatalystPlanContext)
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline with multiple trendline sma commands") {
val context = new CatalystPlanContext
val logPlan =
Expand Down

0 comments on commit b9a0dc5

Please sign in to comment.