diff --git a/pantab/_writer.py b/pantab/_writer.py
index 76d30fd7..da263e7a 100644
--- a/pantab/_writer.py
+++ b/pantab/_writer.py
@@ -2,7 +2,7 @@
 import shutil
 import tempfile
 import uuid
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 import tableauhyperapi as tab_api
 
@@ -54,24 +54,33 @@ def frame_to_hyper(
     *,
     table: pantab_types.TableType,
     table_mode: Literal["a", "w"] = "w",
+    json_columns: set[str] = None,
+    geo_columns: set[str] = None,
 ) -> None:
     """See api.rst for documentation"""
     frames_to_hyper(
         {table: df},
         database,
-        table_mode,
+        table_mode=table_mode,
+        json_columns=json_columns,
+        geo_columns=geo_columns,
     )
 
 
 def frames_to_hyper(
     dict_of_frames: dict[pantab_types.TableType, Any],
     database: Union[str, pathlib.Path],
-    table_mode: Literal["a", "w"] = "w",
     *,
-    hyper_process: Optional[tab_api.HyperProcess] = None,
+    table_mode: Literal["a", "w"] = "w",
+    json_columns: set[str] = None,
+    geo_columns: set[str] = None,
 ) -> None:
     """See api.rst for documentation."""
     _validate_table_mode(table_mode)
+    if json_columns is None:
+        json_columns = set()
+    if geo_columns is None:
+        geo_columns = set()
 
     tmp_db = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.hyper"
 
@@ -89,7 +98,14 @@ def convert_to_table_name(table: pantab_types.TableType):
         convert_to_table_name(key): _get_capsule_from_obj(val)
         for key, val in dict_of_frames.items()
     }
-    libpantab.write_to_hyper(data, path=str(tmp_db), table_mode=table_mode)
+
+    libpantab.write_to_hyper(
+        data,
+        path=str(tmp_db),
+        table_mode=table_mode,
+        json_columns=json_columns,
+        geo_columns=geo_columns,
+    )
 
     # In Python 3.9+ we can just pass the path object, but due to bpo 32689
     # and subsequent typeshed changes it is easier to just pass as str for now
diff --git a/pantab/src/pantab.cpp b/pantab/src/pantab.cpp
index aa0974a0..cee3d0e9 100644
--- a/pantab/src/pantab.cpp
+++ b/pantab/src/pantab.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <set>
 #include
 #include
 
@@ -492,7 +493,9 @@ using SchemaAndTableName = std::tuple<std::string, std::string>;
 
 void write_to_hyper(
     const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules,
-    const std::string &path, const std::string &table_mode) {
+    const std::string &path, const std::string &table_mode,
+    const std::set<std::string> &json_columns,
+    const std::set<std::string> &geo_columns) {
   const std::unordered_map<std::string, std::string> params = {
       {"log_config", ""}};
   const hyperapi::HyperProcess hyper{
@@ -524,15 +527,27 @@ void write_to_hyper(
     struct ArrowError error;
     std::vector<hyperapi::TableDefinition::Column> hyper_columns;
     for (int64_t i = 0; i < schema.n_children; i++) {
-      const auto hypertype =
-          hyperTypeFromArrowSchema(schema.children[i], &error);
-
-      // Almost all arrow types are nullable
-      const hyperapi::TableDefinition::Column column{
-          std::string(schema.children[i]->name), hypertype,
-          hyperapi::Nullability::Nullable};
+      const auto col_name = std::string{schema.children[i]->name};
+      if (json_columns.find(col_name) != json_columns.end()) {
+        const auto hypertype = hyperapi::SqlType::json();
+        const hyperapi::TableDefinition::Column column{
+            col_name, hypertype, hyperapi::Nullability::Nullable};
+
+        hyper_columns.emplace_back(std::move(column));
+      } else if (geo_columns.find(col_name) != geo_columns.end()) {
+        const auto hypertype = hyperapi::SqlType::geography();
+        const hyperapi::TableDefinition::Column column{
+            col_name, hypertype, hyperapi::Nullability::Nullable};
+
+        hyper_columns.emplace_back(std::move(column));
+      } else {
+        const auto hypertype =
+            hyperTypeFromArrowSchema(schema.children[i], &error);
+        const hyperapi::TableDefinition::Column column{
+            col_name, hypertype, hyperapi::Nullability::Nullable};
 
-      hyper_columns.emplace_back(std::move(column));
+        hyper_columns.emplace_back(std::move(column));
+      }
     }
 
     const hyperapi::TableName table_name{hyper_schema, hyper_table};
@@ -982,7 +997,8 @@ auto read_from_hyper_query(const std::string &path, const std::string &query)
 
 NB_MODULE(pantab, m) { // NOLINT
   m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
-        nb::arg("path"), nb::arg("table_mode"))
+        nb::arg("path"), nb::arg("table_mode"), nb::arg("json_columns"),
+        nb::arg("geo_columns"))
       .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
            nb::arg("query"));
   PyDateTime_IMPORT;
diff --git a/pantab/tests/conftest.py b/pantab/tests/conftest.py
index dcd8d077..f64f9269 100644
--- a/pantab/tests/conftest.py
+++ b/pantab/tests/conftest.py
@@ -36,9 +36,11 @@ def basic_arrow_table():
             ("float32_limits", pa.float32()),
             ("float64_limits", pa.float64()),
             ("non-ascii", pa.utf8()),
+            ("json", pa.large_string()),
             ("binary", pa.binary()),
             ("interval", pa.month_day_nano_interval()),
             ("time64us", pa.time64("us")),
+            ("geography", pa.large_binary()),
         ]
     )
     tbl = pa.Table.from_arrays(
@@ -82,6 +84,7 @@ def basic_arrow_table():
             pa.array(
                 ["\xef\xff\xdc\xde\xee", "\xfa\xfb\xdd\xaf\xaa", None],
             ),
+            pa.array(['{"foo": 42}', '{"bar": -42}', None]),
             pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None]),
             pa.array(
                 [
@@ -91,6 +94,13 @@ def basic_arrow_table():
                 ]
             ),
             pa.array([234, 42, None]),
+            pa.array(
+                [
+                    b"\x07\xaa\x02\xe0%n\xd9\x01\x01\n\x00\xce\xab\xe8\xfa=\xff\x96\xf0\x8a\x9f\x01",
+                    b"\x07\xaa\x02\x0c&n\x82\x01\x01\n\x00\xb0\xe2\xd4\xcc>\xd4\xbc\x97\x88\x0f",
+                    None,
+                ]
+            ),
         ],
         schema=schema,
     )
@@ -125,6 +135,7 @@ def basic_dataframe():
             -(2**24),
             -(2**53),
             "\xef\xff\xdc\xde\xee",
+            '{"foo": 42}',
         ],
         [
             6,
@@ -150,6 +161,7 @@ def basic_dataframe():
             2**24 - 1,
             2**53 - 1,
             "\xfa\xfb\xdd\xaf\xaa",
+            '{"bar": -42}',
         ],
         [
             0,
@@ -175,6 +187,7 @@ def basic_dataframe():
             np.nan,
             np.nan,
             np.nan,
+            pd.NA,
         ],
     ],
     columns=[
@@ -201,6 +214,7 @@ def basic_dataframe():
         "float32_limits",
         "float64_limits",
         "non-ascii",
+        "json",
     ],
 )
@@ -229,6 +243,7 @@ def basic_dataframe():
         "float32_limits": np.float64,
         "float64_limits": np.float64,
         "non-ascii": "string",
+        "json": "string",
     }
 )
@@ -248,6 +263,14 @@ def basic_dataframe():
         {"col": pa.array([234, 42, None], type=pa.time64("us"))}
     )
     df["time64us"] = df["time64us"].astype("time64[us][pyarrow]")
+    df["geography"] = pa.array(
+        [
+            b"\x07\xaa\x02\xe0%n\xd9\x01\x01\n\x00\xce\xab\xe8\xfa=\xff\x96\xf0\x8a\x9f\x01",
+            b"\x07\xaa\x02\x0c&n\x82\x01\x01\n\x00\xb0\xe2\xd4\xcc>\xd4\xbc\x97\x88\x0f",
+            None,
+        ]
+    )
+    df["geography"] = df["geography"].astype("large_binary[pyarrow]")
 
     return df
@@ -293,9 +316,11 @@ def roundtripped_pyarrow():
             ("float32_limits", pa.float64()),
             ("float64_limits", pa.float64()),
             ("non-ascii", pa.large_string()),
+            ("json", pa.large_string()),
             ("binary", pa.large_binary()),
             ("interval", pa.month_day_nano_interval()),
             ("time64us", pa.time64("us")),
+            ("geography", pa.large_binary()),
         ]
     )
     tbl = basic_arrow_table()
@@ -329,10 +354,12 @@ def roundtripped_pandas():
             "float32_limits": "double[pyarrow]",
             "float64_limits": "double[pyarrow]",
             "non-ascii": "large_string[pyarrow]",
+            "json": "large_string[pyarrow]",
             "string": "large_string[pyarrow]",
             "binary": "large_binary[pyarrow]",
             # "interval": "month_day_nano_interval[pyarrow]",
             "time64us": "time64[us][pyarrow]",
+            "geography": "large_binary[pyarrow]",
         }
     )
     return df
diff --git a/pantab/tests/test_roundtrip.py b/pantab/tests/test_roundtrip.py
index f8717dbf..b56379b8 100644
--- a/pantab/tests/test_roundtrip.py
+++ b/pantab/tests/test_roundtrip.py
@@ -11,8 +11,22 @@ def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
         expected = compat.drop_columns(expected, ["interval"])
 
     # Write twice; depending on mode this should either overwrite or duplicate entries
-    pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
-    pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
 
     result = pantab.frame_from_hyper(
         tmp_hyper, table=table_name, return_type=return_type
@@ -34,10 +48,18 @@ def test_multiple_tables(
 
     # Write twice; depending on mode this should either overwrite or duplicate entries
     pantab.frames_to_hyper(
-        {table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
+        {table_name: frame, "table2": frame},
+        tmp_hyper,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
     )
     pantab.frames_to_hyper(
-        {table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
+        {table_name: frame, "table2": frame},
+        tmp_hyper,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
     )
 
     result = pantab.frames_from_hyper(tmp_hyper, return_type=return_type)
@@ -65,8 +87,22 @@ def test_empty_roundtrip(
     # object case is by definition vague, so lets punt that for now
     frame = compat.drop_columns(frame, ["object"])
     empty = compat.empty_like(frame)
-    pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
-    pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
+    pantab.frame_to_hyper(
+        empty,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
+    pantab.frame_to_hyper(
+        empty,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
 
     result = pantab.frame_from_hyper(
         tmp_hyper, table=table_name, return_type=return_type
diff --git a/pantab/tests/test_writer.py b/pantab/tests/test_writer.py
index 0ad11dd5..93dc7807 100644
--- a/pantab/tests/test_writer.py
+++ b/pantab/tests/test_writer.py
@@ -3,7 +3,14 @@
 import pandas as pd
 import pytest
-from tableauhyperapi import Connection, CreateMode, HyperProcess, Telemetry
+from tableauhyperapi import (
+    Connection,
+    CreateMode,
+    HyperProcess,
+    SqlType,
+    TableName,
+    Telemetry,
+)
 
 import pantab
 
@@ -109,3 +116,39 @@ def test_utc_bug(tmp_hyper):
     expected: {frame.utc_time}
     actual: {[c[0] for c in resp]}
     """
+
+
+def test_geo_and_json_columns_writes_proper_type(tmp_hyper, frame):
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table="test",
+    )
+
+    with HyperProcess(Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper:
+        with Connection(
+            hyper.endpoint, tmp_hyper, CreateMode.CREATE_IF_NOT_EXISTS
+        ) as connection:
+            table_def = connection.catalog.get_table_definition(TableName("test"))
+            json_col = table_def.get_column_by_name("json")
+            geo_col = table_def.get_column_by_name("geography")
+            assert json_col.type == SqlType.text()
+            assert geo_col.type == SqlType.bytes()
+
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table="test",
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
+
+    with HyperProcess(Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper:
+        with Connection(
+            hyper.endpoint, tmp_hyper, CreateMode.CREATE_IF_NOT_EXISTS
+        ) as connection:
+            table_def = connection.catalog.get_table_definition(TableName("test"))
+            json_col = table_def.get_column_by_name("json")
+            geo_col = table_def.get_column_by_name("geography")
+            assert json_col.type == SqlType.json()
+            assert geo_col.type == SqlType.geography()