Skip to content

Commit ded7852

Browse files
committed
Copy only top-level DECLARE statements (#44)
In case of isolate_top_level_statements=True, copy DECLARE statements to subsequent statements only if they are top level.
1 parent a092cda commit ded7852

File tree

4 files changed

+44
-4
lines changed

4 files changed

+44
-4
lines changed

src/pytsql/grammar/tsqlParser.py

+7
Original file line numberDiff line numberDiff line change
@@ -24890,6 +24890,13 @@ def getRuleIndex(self):
2489024890
return tsqlParser.RULE_data_type
2489124891

2489224892

24893+
@staticmethod
24894+
def is_top_level_statement(node: ParserRuleContext):
24895+
"""Check wether node is a top level SQL statement."""
24896+
cur = node.parentCtx
24897+
while isinstance(cur, tsqlParser.Sql_clauseContext) or isinstance(cur, tsqlParser.Sql_clausesContext):
24898+
cur = cur.parentCtx
24899+
return isinstance(cur, tsqlParser.BatchContext)
2489324900

2489424901

2489524902
def data_type(self):

src/pytsql/tsql.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def visitChildren(self, node: antlr4.ParserRuleContext) -> List[str]:
112112
else:
113113
result = super().visitChildren(node)
114114

115-
if isinstance(node, tsqlParser.Declare_statementContext):
115+
if isinstance(node, tsqlParser.Declare_statementContext) and tsqlParser.is_top_level_statement(node):
116116
self.dynamics.extend(result)
117117

118118
return result

tests/integration/test_multiple_statements.py

+17
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,23 @@ def test_stored_procedure_declaration(engine):
182182
executes(statement, engine, None)
183183

184184

185+
def test_top_level_declaration(engine):
186+
statement = """
187+
DROP DATABASE IF EXISTS top_level_declaration
188+
CREATE DATABASE top_level_declaration
189+
USE top_level_declaration
190+
GO
191+
192+
DECLARE @Current AS DATE = '2022-01-01'
193+
GO
194+
SELECT @Current as a INTO dummy01
195+
GO
196+
SELECT @Current as b INTO dummy02
197+
GO
198+
"""
199+
executes(statement, engine, None)
200+
201+
185202
def get_table(
186203
engine: Engine, table_name: str, schema: Optional[str] = None
187204
) -> sa.Table:

tests/unit/test_dynamics.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,36 @@ def test_declaration_in_control_flow():
77
seed = """
88
IF 1 = 1
99
DECLARE @A INT = 5
10-
SELECT * FROM x
10+
SELECT @A
1111
"""
1212
splits = _split(seed)
1313
assert len(splits) == 2
1414

1515
assert_strings_equal_disregarding_whitespace(
1616
splits[0], "IF 1 = 1 DECLARE @A INT = 5"
1717
)
18+
# unfortunately we can't be right here because otherwise we would need to get
19+
# the output of the declaration
1820
assert_strings_equal_disregarding_whitespace(
1921
splits[1],
20-
"""
22+
"""SELECT @A""",
23+
)
24+
25+
26+
def test_select_in_control_flow():
27+
seed = """
28+
IF 1 = 0
29+
BEGIN
2130
DECLARE @A INT = 5
2231
SELECT * FROM x
23-
""",
32+
END
33+
"""
34+
splits = _split(seed)
35+
assert len(splits) == 1
36+
37+
# this is beyond the complexity we want to manage with isolate_top_level_statements=True
38+
assert_strings_equal_disregarding_whitespace(
39+
splits[0], "IF 1 = 0 BEGIN DECLARE @A INT = 5 SELECT * FROM x END"
2440
)
2541

2642

0 commit comments

Comments
 (0)