From 904fea1545fe0b4f2a1f4a8cfef1d0bbd103e19f Mon Sep 17 00:00:00 2001
From: Will Ayd
Date: Thu, 18 Jan 2024 18:07:06 -0500
Subject: [PATCH] VARCHAR support

---
 pantab/src/pantab.cpp       | 46 +++++++++++++------------------------
 pantab/tests/test_reader.py | 43 +++++++++++++++++++++++++++++++---
 2 files changed, 56 insertions(+), 33 deletions(-)

diff --git a/pantab/src/pantab.cpp b/pantab/src/pantab.cpp
index 628d1a87..ccfd175a 100644
--- a/pantab/src/pantab.cpp
+++ b/pantab/src/pantab.cpp
@@ -615,36 +615,22 @@ static auto makeReadHelper(const ArrowSchemaView *schema_view,
   }
 }
 
-static auto
-arrowTypeFromHyper(const hyperapi::SqlType &sqltype) -> enum ArrowType {
-  if (sqltype == hyperapi::SqlType::smallInt()){return NANOARROW_TYPE_INT16;}
-else if (sqltype == hyperapi::SqlType::integer()) {
-  return NANOARROW_TYPE_INT32;
-}
-else if (sqltype == hyperapi::SqlType::bigInt()) {
-  return NANOARROW_TYPE_INT64;
-}
-else if (sqltype == hyperapi::SqlType::doublePrecision()) {
-  return NANOARROW_TYPE_DOUBLE;
-}
-else if (sqltype == hyperapi::SqlType::text()) {
-  return NANOARROW_TYPE_LARGE_STRING;
-}
-else if (sqltype == hyperapi::SqlType::boolean()) {
-  return NANOARROW_TYPE_BOOL;
-}
-else if (sqltype == hyperapi::SqlType::timestamp()) {
-  return NANOARROW_TYPE_TIMESTAMP;
-}
-else if (sqltype == hyperapi::SqlType::timestampTZ()) {
-  return NANOARROW_TYPE_TIMESTAMP; // todo: how to encode tz info?
-}
-else if (sqltype == hyperapi::SqlType::date()) {
-  return NANOARROW_TYPE_DATE32;
-}
-
-throw nb::type_error(
-    ("unimplemented pandas dtype for type: " + sqltype.toString()).c_str());
+static auto arrowTypeFromHyper(const hyperapi::SqlType &sqltype)
+    -> enum ArrowType {
+  switch (sqltype.getTag()){
+  case hyperapi::TypeTag::SmallInt : return NANOARROW_TYPE_INT16;
+  case hyperapi::TypeTag::Int : return NANOARROW_TYPE_INT32;
+  case hyperapi::TypeTag::BigInt : return NANOARROW_TYPE_INT64;
+  case hyperapi::TypeTag::Double : return NANOARROW_TYPE_DOUBLE;
+  case hyperapi::TypeTag::Varchar : case hyperapi::TypeTag::Char :
+  case hyperapi::TypeTag::Text : return NANOARROW_TYPE_LARGE_STRING;
+  case hyperapi::TypeTag::Bool : return NANOARROW_TYPE_BOOL;
+  case hyperapi::TypeTag::Date : return NANOARROW_TYPE_DATE32;
+  case hyperapi::TypeTag::Timestamp : case hyperapi::TypeTag::
+      TimestampTZ : return NANOARROW_TYPE_TIMESTAMP;
+  default : throw nb::type_error(
+      ("Reader not implemented for type: " + sqltype.toString()).c_str());
+  }
 }
 
 static auto releaseArrowStream(void *ptr) noexcept -> void {
diff --git a/pantab/tests/test_reader.py b/pantab/tests/test_reader.py
index 0a237bb9..72e6139c 100644
--- a/pantab/tests/test_reader.py
+++ b/pantab/tests/test_reader.py
@@ -1,7 +1,7 @@
 import pandas as pd
 import pandas.testing as tm
 import pytest
-from tableauhyperapi import TableName
+import tableauhyperapi as tab_api
 
 import pantab
 
@@ -31,7 +31,7 @@ def test_reports_unsupported_type(datapath):
 
 def test_read_non_roundtrippable(datapath):
     result = pantab.frame_from_hyper(
-        datapath / "dates.hyper", table=TableName("Extract", "Extract")
+        datapath / "dates.hyper", table=tab_api.TableName("Extract", "Extract")
     )
     expected = pd.DataFrame(
         [["1900-01-01", "2000-01-01"], [pd.NaT, "2050-01-01"]],
@@ -43,7 +43,8 @@ def test_read_non_roundtrippable(datapath):
 
 def test_reads_non_writeable(datapath):
     result = pantab.frame_from_hyper(
-        datapath / "non_pantab_writeable.hyper", table=TableName("public", "table")
+        datapath / "non_pantab_writeable.hyper",
+        table=tab_api.TableName("public", "table"),
     )
 
     expected = pd.DataFrame(
@@ -85,3 +86,39 @@ def test_empty_read_query(df: pd.DataFrame, roundtripped, tmp_hyper):
 
     result = pantab.frame_from_hyper_query(tmp_hyper, query)
     tm.assert_frame_equal(result, expected)
+
+
+def test_read_varchar(tmp_hyper):
+    column_name = "VARCHAR Column"
+    table_name = tab_api.TableName("public", "table")
+    table = tab_api.TableDefinition(
+        table_name=table_name,
+        columns=[
+            tab_api.TableDefinition.Column(
+                name=column_name,
+                type=tab_api.SqlType.varchar(42),
+                nullability=tab_api.NOT_NULLABLE,
+            )
+        ],
+    )
+
+    with tab_api.HyperProcess(
+        telemetry=tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
+    ) as hyper:
+        with tab_api.Connection(
+            endpoint=hyper.endpoint,
+            database=tmp_hyper,
+            create_mode=tab_api.CreateMode.CREATE_AND_REPLACE,
+        ) as connection:
+            connection.catalog.create_table(table_definition=table)
+
+            with tab_api.Inserter(connection, table) as inserter:
+                inserter.add_rows([["foo"], ["bar"]])
+                inserter.execute()
+
+    expected = pd.DataFrame(
+        [["foo"], ["bar"]], columns=[column_name], dtype="large_string[pyarrow]"
+    )
+
+    result = pantab.frame_from_hyper(tmp_hyper, table=table_name)
+    tm.assert_frame_equal(result, expected)