Skip to content

Commit 7ae939a

Browse files
yaooqinnulysses-you
authored andcommitted
[SPARK-48168][SQL] Add bitwise shifting operators support
### What changes were proposed in this pull request? This PR introduces three bitwise shifting operators as aliases for existing shifting functions. ### Why are the changes needed? The bit shifting functions named in alphabet form vary from one platform to anthor. Take our shiftleft as an example, - Hive, shiftleft (where we copied it from) - MsSQL Server LEFT_SHIFT - MySQL, N/A - PostgreSQL, N/A - Presto, bitwise_left_shift The [bit shifting operators](https://en.wikipedia.org/wiki/Bitwise_operations_in_C) share a much more common and consistent way for users to port their queries. For self-consistent with existing bit operators in Spark, `AND &`, `OR |`, `XOR ^` and `NOT ~`, we now add `<<`, `>>` and `>>>`. For other systems that we can refer to: https://learn.microsoft.com/en-us/sql/t-sql/functions/left-shift-transact-sql?view=sql-server-ver16 https://www.postgresql.org/docs/9.4/functions-bitstring.html https://dev.mysql.com/doc/refman/8.0/en/bit-functions.html ### Does this PR introduce _any_ user-facing change? Yes, new operators were added but no behavior change ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes apache#46440 from yaooqinn/SPARK-48168. Authored-by: Kent Yao <yao@apache.org> Signed-off-by: youxiduo <youxiduo@corp.netease.com>
1 parent bd95040 commit 7ae939a

File tree

48 files changed

+469
-153
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+469
-153
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [shiftleft(cast(b#0 as int), 2) AS shiftleft(b, 2)#0]
1+
Project [(cast(b#0 as int) << 2) AS (b << 2)#0]
22
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [shiftright(cast(b#0 as int), 2) AS shiftright(b, 2)#0]
1+
Project [(cast(b#0 as int) >> 2) AS (b >> 2)#0]
22
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
Project [shiftrightunsigned(cast(b#0 as int), 2) AS shiftrightunsigned(b, 2)#0]
1+
Project [(cast(b#0 as int) >>> 2) AS (b >>> 2)#0]
22
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, cast((shiftright(spark_grouping_id#0L, 1) & 1) as tinyint) AS grouping(a)#0, cast((shiftright(spark_grouping_id#0L, 0) & 1) as tinyint) AS grouping(b)#0, spark_grouping_id#0L AS grouping_id(a, b)#0L]
1+
Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, cast(((spark_grouping_id#0L >> 1) & 1) as tinyint) AS grouping(a)#0, cast(((spark_grouping_id#0L >> 0) & 1) as tinyint) AS grouping(b)#0, spark_grouping_id#0L AS grouping_id(a, b)#0L]
22
+- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
33
+- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
44
+- LocalRelation <empty>, [id#0L, a#0, b#0]

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4

+36-4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,35 @@ lexer grammar SqlBaseLexer;
6969
public void markUnclosedComment() {
7070
has_unclosed_bracketed_comment = true;
7171
}
72+
73+
/**
74+
* When greater than zero, it's in the middle of parsing ARRAY/MAP/STRUCT type.
75+
*/
76+
public int complex_type_level_counter = 0;
77+
78+
/**
79+
* Increase the counter by one when hits KEYWORD 'ARRAY', 'MAP', 'STRUCT'.
80+
*/
81+
public void incComplexTypeLevelCounter() {
82+
complex_type_level_counter++;
83+
}
84+
85+
/**
86+
* Decrease the counter by one when hits close tag '>' && the counter greater than zero
87+
* which means we are in the middle of complex type parsing. Otherwise, it's a dangling
88+
* GT token and we do nothing.
89+
*/
90+
public void decComplexTypeLevelCounter() {
91+
if (complex_type_level_counter > 0) complex_type_level_counter--;
92+
}
93+
94+
/**
95+
* If the counter is zero, it's a shift right operator. It can be closing tags of an complex
96+
* type definition, such as MAP<INT, ARRAY<INT>>.
97+
*/
98+
public boolean isShiftRightOperator() {
99+
return complex_type_level_counter == 0 ? true : false;
100+
}
72101
}
73102

74103
SEMICOLON: ';';
@@ -100,7 +129,7 @@ ANTI: 'ANTI';
100129
ANY: 'ANY';
101130
ANY_VALUE: 'ANY_VALUE';
102131
ARCHIVE: 'ARCHIVE';
103-
ARRAY: 'ARRAY';
132+
ARRAY: 'ARRAY' {incComplexTypeLevelCounter();};
104133
AS: 'AS';
105134
ASC: 'ASC';
106135
AT: 'AT';
@@ -259,7 +288,7 @@ LOCKS: 'LOCKS';
259288
LOGICAL: 'LOGICAL';
260289
LONG: 'LONG';
261290
MACRO: 'MACRO';
262-
MAP: 'MAP';
291+
MAP: 'MAP' {incComplexTypeLevelCounter();};
263292
MATCHED: 'MATCHED';
264293
MERGE: 'MERGE';
265294
MICROSECOND: 'MICROSECOND';
@@ -362,7 +391,7 @@ STATISTICS: 'STATISTICS';
362391
STORED: 'STORED';
363392
STRATIFY: 'STRATIFY';
364393
STRING: 'STRING';
365-
STRUCT: 'STRUCT';
394+
STRUCT: 'STRUCT' {incComplexTypeLevelCounter();};
366395
SUBSTR: 'SUBSTR';
367396
SUBSTRING: 'SUBSTRING';
368397
SYNC: 'SYNC';
@@ -439,8 +468,11 @@ NEQ : '<>';
439468
NEQJ: '!=';
440469
LT : '<';
441470
LTE : '<=' | '!>';
442-
GT : '>';
471+
GT : '>' {decComplexTypeLevelCounter();};
443472
GTE : '>=' | '!<';
473+
SHIFT_LEFT: '<<';
474+
SHIFT_RIGHT: '>>' {isShiftRightOperator()}?;
475+
SHIFT_RIGHT_UNSIGNED: '>>>' {isShiftRightOperator()}?;
444476

445477
PLUS: '+';
446478
MINUS: '-';

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

+8
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ describeFuncName
395395
| comparisonOperator
396396
| arithmeticOperator
397397
| predicateOperator
398+
| shiftOperator
398399
| BANG
399400
;
400401

@@ -989,6 +990,13 @@ valueExpression
989990
| left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
990991
| left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
991992
| left=valueExpression comparisonOperator right=valueExpression #comparison
993+
| left=valueExpression shiftOperator right=valueExpression #shiftExpression
994+
;
995+
996+
shiftOperator
997+
: SHIFT_LEFT
998+
| SHIFT_RIGHT
999+
| SHIFT_RIGHT_UNSIGNED
9921000
;
9931001

9941002
datetimeUnit

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

+3
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,9 @@ object FunctionRegistry {
800800
expression[BitwiseNot]("~"),
801801
expression[BitwiseOr]("|"),
802802
expression[BitwiseXor]("^"),
803+
expression[ShiftLeft]("<<", true, Some("4.0.0")),
804+
expression[ShiftRight](">>", true, Some("4.0.0")),
805+
expression[ShiftRightUnsigned](">>>", true, Some("4.0.0")),
803806
expression[BitwiseCount]("bit_count"),
804807
expression[BitAndAgg]("bit_and"),
805808
expression[BitOrAgg]("bit_or"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

+69-65
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,41 @@ case class Pow(left: Expression, right: Expression)
12611261
newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
12621262
}
12631263

1264+
sealed trait BitShiftOperation
1265+
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
1266+
1267+
def symbol: String
1268+
def shiftInt: (Int, Int) => Int
1269+
def shiftLong: (Long, Int) => Long
1270+
1271+
override def inputTypes: Seq[AbstractDataType] =
1272+
Seq(TypeCollection(IntegerType, LongType), IntegerType)
1273+
1274+
override def dataType: DataType = left.dataType
1275+
1276+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1277+
defineCodeGen(ctx, ev, (left, right) => s"$left $symbol $right")
1278+
}
1279+
1280+
override protected def nullSafeEval(input1: Any, input2: Any): Any = input1 match {
1281+
case l: jl.Long => shiftLong(l, input2.asInstanceOf[Int])
1282+
case i: jl.Integer => shiftInt(i, input2.asInstanceOf[Int])
1283+
}
1284+
1285+
override def toString: String = {
1286+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(symbol) match {
1287+
case alias if alias == symbol => s"($left $symbol $right)"
1288+
case _ => super.toString
1289+
}
1290+
}
1291+
1292+
override def sql: String = {
1293+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(symbol) match {
1294+
case alias if alias == symbol => s"(${left.sql} $symbol ${right.sql})"
1295+
case _ => super.sql
1296+
}
1297+
}
1298+
}
12641299

12651300
/**
12661301
* Bitwise left shift.
@@ -1269,111 +1304,80 @@ case class Pow(left: Expression, right: Expression)
12691304
* @param right number of bits to left shift.
12701305
*/
12711306
@ExpressionDescription(
1272-
usage = "_FUNC_(base, expr) - Bitwise left shift.",
1307+
usage = "base << exp - Bitwise left shift.",
12731308
examples = """
12741309
Examples:
1275-
> SELECT _FUNC_(2, 1);
1310+
> SELECT shiftleft(2, 1);
1311+
4
1312+
> SELECT 2 << 1;
12761313
4
12771314
""",
1315+
note = """
1316+
`<<` operator is added in Spark 4.0.0 as an alias for `shiftleft`.
1317+
""",
12781318
since = "1.5.0",
12791319
group = "bitwise_funcs")
1280-
case class ShiftLeft(left: Expression, right: Expression)
1281-
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
1282-
1283-
override def inputTypes: Seq[AbstractDataType] =
1284-
Seq(TypeCollection(IntegerType, LongType), IntegerType)
1285-
1286-
override def dataType: DataType = left.dataType
1287-
1288-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
1289-
input1 match {
1290-
case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
1291-
case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
1292-
}
1293-
}
1294-
1295-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1296-
defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
1297-
}
1298-
1320+
case class ShiftLeft(left: Expression, right: Expression) extends BitShiftOperation {
1321+
override def symbol: String = "<<"
1322+
override def shiftInt: (Int, Int) => Int = (x: Int, y: Int) => x << y
1323+
override def shiftLong: (Long, Int) => Long = (x: Long, y: Int) => x << y
1324+
val shift: (Number, Int) => Any = (x: Number, y: Int) => x.longValue() << y
12991325
override protected def withNewChildrenInternal(
13001326
newLeft: Expression, newRight: Expression): ShiftLeft = copy(left = newLeft, right = newRight)
13011327
}
13021328

1303-
13041329
/**
13051330
* Bitwise (signed) right shift.
13061331
*
13071332
* @param left the base number to shift.
13081333
* @param right number of bits to right shift.
13091334
*/
13101335
@ExpressionDescription(
1311-
usage = "_FUNC_(base, expr) - Bitwise (signed) right shift.",
1336+
usage = "base >> expr - Bitwise (signed) right shift.",
13121337
examples = """
13131338
Examples:
1314-
> SELECT _FUNC_(4, 1);
1339+
> SELECT shiftright(4, 1);
1340+
2
1341+
> SELECT 4 >> 1;
13151342
2
13161343
""",
1344+
note = """
1345+
`>>` operator is added in Spark 4.0.0 as an alias for `shiftright`.
1346+
""",
13171347
since = "1.5.0",
13181348
group = "bitwise_funcs")
1319-
case class ShiftRight(left: Expression, right: Expression)
1320-
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
1321-
1322-
override def inputTypes: Seq[AbstractDataType] =
1323-
Seq(TypeCollection(IntegerType, LongType), IntegerType)
1324-
1325-
override def dataType: DataType = left.dataType
1326-
1327-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
1328-
input1 match {
1329-
case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
1330-
case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
1331-
}
1332-
}
1333-
1334-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1335-
defineCodeGen(ctx, ev, (left, right) => s"$left >> $right")
1336-
}
1337-
1349+
case class ShiftRight(left: Expression, right: Expression) extends BitShiftOperation {
1350+
override def symbol: String = ">>"
1351+
override def shiftInt: (Int, Int) => Int = (x: Int, y: Int) => x >> y
1352+
override def shiftLong: (Long, Int) => Long = (x: Long, y: Int) => x >> y
13381353
override protected def withNewChildrenInternal(
13391354
newLeft: Expression, newRight: Expression): ShiftRight = copy(left = newLeft, right = newRight)
13401355
}
13411356

1342-
13431357
/**
13441358
* Bitwise unsigned right shift, for integer and long data type.
13451359
*
13461360
* @param left the base number.
13471361
* @param right the number of bits to right shift.
13481362
*/
13491363
@ExpressionDescription(
1350-
usage = "_FUNC_(base, expr) - Bitwise unsigned right shift.",
1364+
usage = "base >>> expr - Bitwise unsigned right shift.",
13511365
examples = """
13521366
Examples:
1353-
> SELECT _FUNC_(4, 1);
1367+
> SELECT shiftrightunsigned(4, 1);
13541368
2
1369+
> SELECT 4 >>> 1;
1370+
2
1371+
""",
1372+
note = """
1373+
`>>>` operator is added in Spark 4.0.0 as an alias for `shiftrightunsigned`.
13551374
""",
13561375
since = "1.5.0",
13571376
group = "bitwise_funcs")
1358-
case class ShiftRightUnsigned(left: Expression, right: Expression)
1359-
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
1360-
1361-
override def inputTypes: Seq[AbstractDataType] =
1362-
Seq(TypeCollection(IntegerType, LongType), IntegerType)
1363-
1364-
override def dataType: DataType = left.dataType
1365-
1366-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
1367-
input1 match {
1368-
case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
1369-
case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]
1370-
}
1371-
}
1372-
1373-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1374-
defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right")
1375-
}
1376-
1377+
case class ShiftRightUnsigned(left: Expression, right: Expression) extends BitShiftOperation {
1378+
override def symbol: String = ">>>"
1379+
override def shiftInt: (Int, Int) => Int = (x: Int, y: Int) => x >>> y
1380+
override def shiftLong: (Long, Int) => Long = (x: Long, y: Int) => x >>> y
13771381
override protected def withNewChildrenInternal(
13781382
newLeft: Expression, newRight: Expression): ShiftRightUnsigned =
13791383
copy(left = newLeft, right = newRight)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

+11
Original file line numberDiff line numberDiff line change
@@ -2196,6 +2196,17 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
21962196
}
21972197
}
21982198

2199+
override def visitShiftExpression(ctx: ShiftExpressionContext): Expression = withOrigin(ctx) {
2200+
val left = expression(ctx.left)
2201+
val right = expression(ctx.right)
2202+
val operator = ctx.shiftOperator().getChild(0).asInstanceOf[TerminalNode]
2203+
operator.getSymbol.getType match {
2204+
case SqlBaseParser.SHIFT_LEFT => ShiftLeft(left, right)
2205+
case SqlBaseParser.SHIFT_RIGHT => ShiftRight(left, right)
2206+
case SqlBaseParser.SHIFT_RIGHT_UNSIGNED => ShiftRightUnsigned(left, right)
2207+
}
2208+
}
2209+
21992210
/**
22002211
* Create a unary arithmetic expression. The following arithmetic operators are supported:
22012212
* - Plus: '+'

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ trait SQLKeywordUtils extends SparkFunSuite with SQLHelper {
9898
}
9999
(symbol, literals) :: Nil
100100
} else {
101-
val literal = literalDef.replaceAll("'", "").trim
101+
val literal = literalDef.split("\\{")(0).replaceAll("'", "").trim
102102
// The case where a symbol string and its literal string are different,
103103
// e.g., `SETMINUS: 'MINUS';`.
104104
if (symbol != literal) {

sql/core/src/test/resources/sql-functions/sql-expression-schema.md

+6-3
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,12 @@
289289
| org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct<sha(Spark):string> |
290290
| org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct<sha1(Spark):string> |
291291
| org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct<sha2(Spark, 256):string> |
292-
| org.apache.spark.sql.catalyst.expressions.ShiftLeft | shiftleft | SELECT shiftleft(2, 1) | struct<shiftleft(2, 1):int> |
293-
| org.apache.spark.sql.catalyst.expressions.ShiftRight | shiftright | SELECT shiftright(4, 1) | struct<shiftright(4, 1):int> |
294-
| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | shiftrightunsigned | SELECT shiftrightunsigned(4, 1) | struct<shiftrightunsigned(4, 1):int> |
292+
| org.apache.spark.sql.catalyst.expressions.ShiftLeft | << | SELECT shiftleft(2, 1) | struct<(2 << 1):int> |
293+
| org.apache.spark.sql.catalyst.expressions.ShiftLeft | shiftleft | SELECT shiftleft(2, 1) | struct<(2 << 1):int> |
294+
| org.apache.spark.sql.catalyst.expressions.ShiftRight | >> | SELECT shiftright(4, 1) | struct<(4 >> 1):int> |
295+
| org.apache.spark.sql.catalyst.expressions.ShiftRight | shiftright | SELECT shiftright(4, 1) | struct<(4 >> 1):int> |
296+
| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | >>> | SELECT shiftrightunsigned(4, 1) | struct<(4 >>> 1):int> |
297+
| org.apache.spark.sql.catalyst.expressions.ShiftRightUnsigned | shiftrightunsigned | SELECT shiftrightunsigned(4, 1) | struct<(4 >>> 1):int> |
295298
| org.apache.spark.sql.catalyst.expressions.Shuffle | shuffle | SELECT shuffle(array(1, 20, 3, 5)) | struct<shuffle(array(1, 20, 3, 5)):array<int>> |
296299
| org.apache.spark.sql.catalyst.expressions.Signum | sign | SELECT sign(40) | struct<sign(40):double> |
297300
| org.apache.spark.sql.catalyst.expressions.Signum | signum | SELECT signum(40) | struct<SIGNUM(40):double> |

0 commit comments

Comments
 (0)