From e5dd22935c9f3977645927f048bfe3145d9d0247 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0lkin=20Balkanay?= Date: Fri, 5 Apr 2024 19:33:07 +0300 Subject: [PATCH] add transformation for object_agg function (#1) * add transformation for object_agg function * add more tests for object_agg * rm duplicate tests * ruff format --- fakesnow/fakes.py | 3 +- fakesnow/transforms.py | 18 +++-- tests/test_fakes.py | 14 ++++ tests/test_info_schema.py | 2 +- tests/test_transforms.py | 145 ++------------------------------------ 5 files changed, 38 insertions(+), 144 deletions(-) diff --git a/fakesnow/fakes.py b/fakesnow/fakes.py index 1ac4459..c25b5c6 100644 --- a/fakesnow/fakes.py +++ b/fakesnow/fakes.py @@ -206,7 +206,7 @@ def _execute( .transform(transforms.timestamp_ntz_ns) .transform(transforms.float_to_double) .transform(transforms.integer_precision) - # TODO(selman): Broken, failes on CTAS queries with CASTs; + # TODO(selman): Broken, fails on CTAS queries with CASTs; # CREATE TABLE SOME_TABLE AS ( # SELECT # R1 AS C1, @@ -234,6 +234,7 @@ def _execute( .transform(transforms.to_variant) .transform(transforms.show_users) .transform(transforms.create_user) + .transform(transforms.object_agg) ) sql = transformed.sql(dialect="duckdb") result_sql = None diff --git a/fakesnow/transforms.py b/fakesnow/transforms.py index 5728cb6..20ec71b 100644 --- a/fakesnow/transforms.py +++ b/fakesnow/transforms.py @@ -1425,13 +1425,13 @@ def to_json_extract_scalar(expression: exp.JSONExtract) -> exp.Expression: left = expression.left right = expression.right - # = + # = if is_json_extract(left) and isinstance(right, exp.Literal) and right.is_string: json_extract_scalar = exp.Paren(this=to_json_extract_scalar(cast(exp.JSONExtract, left))) return exp.EQ(this=json_extract_scalar, expression=right) - # = + # = elif is_json_extract(right) and isinstance(left, exp.Literal) and left.is_string: json_extract_scalar = exp.Paren(this=to_json_extract_scalar(cast(exp.JSONExtract, right))) @@ -1441,7 +1441,7 @@ def to_json_extract_scalar(expression: exp.JSONExtract) -> exp.Expression: def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expression: - """Snowflake does implicit casting on JSON extract value on an IN caluse; + """Snowflake does implicit casting on JSON extract value on an IN clause; Snowflake; SELECT @@ -1483,7 +1483,7 @@ def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expressio via `->` operator thus FALSE in first case of the first query. To keep things simple, if all values in IN clause are string literals, we - can extract JSON value as string/VARCHAR to achive similar behaviour. + can extract JSON value as string/VARCHAR to achieve similar behaviour. SELECT TO_JSON({'k': '10'}) AS D, @@ -1514,3 +1514,13 @@ def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expressio json_extract_scalar = exp.JSONExtractScalar(this=je.this, expression=path) return exp.In(this=json_extract_scalar, expressions=expression.expressions) + + +def object_agg(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Anonymous) + and isinstance(expression.this, str) + and expression.this.upper() == "OBJECT_AGG" + ): + return exp.Anonymous(this="JSON_GROUP_OBJECT", expressions=expression.expressions) + return expression diff --git a/tests/test_fakes.py b/tests/test_fakes.py index 9455f8e..6fa3f00 100644 --- a/tests/test_fakes.py +++ b/tests/test_fakes.py @@ -1508,6 +1508,20 @@ def test_json_extract_cast_as_varchar(dcur: snowflake.connector.cursor.DictCurso assert dcur.fetchall() == [{"C_STR_NUMBER": 100, "C_NUM_NUMBER": 100}] +def test_json_group_object(dcur: snowflake.connector.cursor.DictCursor): + dcur.execute("create table table1 (id number, key varchar, value varchar)") + values = [(1, "a", "1"), (1, "b", "2"), (1, "c", "3"), (2, "e", "1"), (2, "f", "1"), (3, "a", "2")] + + dcur.executemany("insert into table1 values (%s, %s, %s)", values) + expected = [ + {"ID": 1, "OBJ": '{\n "a": "1",\n "b": "2",\n "c": "3"\n}'}, + {"ID": 2, "OBJ": '{\n "e": "1",\n "f": "1"\n}'}, + {"ID": 3, "OBJ": '{\n "a": "2"\n}'}, + ] + dcur.execute("select id, object_agg(key, value) as obj from table1 group by 1") + assert dindent(dcur.fetchall()) == expected + + def test_write_pandas_quoted_column_names(conn: snowflake.connector.SnowflakeConnection): with conn.cursor(snowflake.connector.cursor.DictCursor) as dcur: # colunmn names with spaces diff --git a/tests/test_info_schema.py b/tests/test_info_schema.py index 66eb69c..31c1009 100644 --- a/tests/test_info_schema.py +++ b/tests/test_info_schema.py @@ -107,7 +107,7 @@ def test_info_schema_columns_other(cur: snowflake.connector.cursor.SnowflakeCurs ] -@pytest.mark.xfail(reason="NOTE(selman): removed extact_extract_text_length transformation") +@pytest.mark.xfail(reason="NOTE(selman): removed extract_extract_text_length transformation") def test_info_schema_columns_text(cur: snowflake.connector.cursor.SnowflakeCursor): # see https://docs.snowflake.com/en/sql-reference/data-types-text cur.execute( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 71874e0..bf57254 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -31,6 +31,7 @@ json_extract_eq_string_literal, json_extract_in_string_literals, json_extract_precedence, + object_agg, object_construct, random, regex_replace, @@ -831,144 +832,6 @@ def test_trim_cast_varchar() -> None: ) -def test__get_to_number_args() -> None: - e = sqlglot.parse_one("to_number('100')", read="snowflake") - assert isinstance(e, exp.ToNumber) - assert _get_to_number_args(e) == (None, None, None) - - e = sqlglot.parse_one("to_number('100', 10)", read="snowflake") - assert isinstance(e, exp.ToNumber) - assert _get_to_number_args(e) == (None, exp.Literal(this="10", is_string=False), None) - - e = sqlglot.parse_one("to_number('100', 10,2)", read="snowflake") - assert isinstance(e, exp.ToNumber) - assert _get_to_number_args(e) == ( - None, - exp.Literal(this="10", is_string=False), - exp.Literal(this="2", is_string=False), - ) - - e = sqlglot.parse_one("to_number('100', 'TM9')", read="snowflake") - assert isinstance(e, exp.ToNumber) - assert _get_to_number_args(e) == (exp.Literal(this="TM9", is_string=True), None, None) - - e = sqlglot.parse_one("to_number('100', 'TM9', 10)", read="snowflake") - assert isinstance(e, exp.ToNumber) - assert _get_to_number_args(e) == ( - exp.Literal(this="TM9", is_string=True), - exp.Literal(this="10", is_string=False), - None, - ) - - e = sqlglot.parse_one("to_number('100', 'TM9', 10, 2)", read="snowflake") - assert isinstance(e, exp.ToNumber) - assert _get_to_number_args(e) == ( - exp.Literal(this="TM9", is_string=True), - exp.Literal(this="10", is_string=False), - exp.Literal(this="2", is_string=False), - ) - - -def test_to_number() -> None: - assert ( - sqlglot.parse_one("SELECT to_number('100')", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(38, 0))" - ) - - assert ( - sqlglot.parse_one("SELECT to_number('100', 10)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(10, 0))" - ) - - assert ( - sqlglot.parse_one("SELECT to_number('100', 10,2)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(10, 2))" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_number('100', 'TM9')", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_number('100', 'TM9', 10)", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_number('100', 'TM9', 10, 2)", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - -def test_to_number_decimal() -> None: - assert ( - sqlglot.parse_one("SELECT to_decimal('100')", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(38, 0))" - ) - - assert ( - sqlglot.parse_one("SELECT to_decimal('100', 10)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(10, 0))" - ) - - assert ( - sqlglot.parse_one("SELECT to_decimal('100', 10,2)", read="snowflake") - .transform(to_decimal) - .sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(10, 2))" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_decimal('100', 'TM9')", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_decimal('100', 'TM9', 10)", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_decimal('100', 'TM9', 10, 2)", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - -def test_to_number_numeric() -> None: - assert ( - sqlglot.parse_one("SELECT to_numeric('100')", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(38, 0))" - ) - - assert ( - sqlglot.parse_one("SELECT to_numeric('100', 10)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(10, 0))" - ) - - assert ( - sqlglot.parse_one("SELECT to_numeric('100', 10,2)", read="snowflake") - .transform(to_decimal) - .sql(dialect="duckdb") - == "SELECT CAST('100' AS DECIMAL(10, 2))" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_numeric('100', 'TM9')", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_numeric('100', 'TM9', 10)", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - with pytest.raises(NotImplementedError): - sqlglot.parse_one("SELECT to_numeric('100', 'TM9', 10, 2)", read="snowflake").transform(to_decimal).sql( - dialect="duckdb" - ) - - def test_upper_case_unquoted_identifiers() -> None: assert ( sqlglot.parse_one("select name, name as fname from table1").transform(upper_case_unquoted_identifiers).sql() @@ -1082,3 +945,9 @@ def test_to_variant() -> None: .sql(dialect="duckdb") == "SELECT TO_JSON('str')" ) + + +def test_object_agg() -> None: + sql = "SELECT ID, OBJECT_AGG(KEY, VALUE) AS POSTCALC FROM POSTCALC_WITH_UPDATED_AT GROUP BY 1" + expected = "SELECT ID, JSON_GROUP_OBJECT(KEY, VALUE) AS POSTCALC FROM POSTCALC_WITH_UPDATED_AT GROUP BY 1" + assert sqlglot.parse_one(sql, read="snowflake").transform(object_agg).sql(dialect="duckdb") == expected