From 177606e026a9219d33233b96e2fa5729381b1abe Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Thu, 19 Sep 2024 23:57:16 -0400 Subject: [PATCH] Remove runtime dependency on tableauhyperapi --- pyproject.toml | 9 +++++++-- src/pantab/_reader.py | 23 +++++----------------- src/pantab/_types.py | 23 +++++++++++++++++++--- src/pantab/_writer.py | 23 ++++++++++++---------- src/pantab/libpantab.cpp | 42 ++++++++++++++++++++++++++++++++++------ tests/test_roundtrip.py | 7 ++++++- 6 files changed, 87 insertions(+), 40 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3b7eeea5..e1f8a2a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ # in the future we need not require pyarrow as pandas implements the # PyCapsule interface. See pandas PR #56587 "pyarrow>=14.0.0", - "tableauhyperapi>=0.0.19691", ] [project.urls] @@ -68,7 +67,13 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64" skip = "*musllinux*" test-command = "python -m pytest {project}/tests" -test-requires = ["pytest", "pandas>=2.0.0", "polars~=1.2.0", "narwhals"] +test-requires = [ + "pytest", + "pandas>=2.0.0", + "polars~=1.2.0", + "narwhals", + "tableauhyperapi", +] [tool.ruff] line-length = 88 diff --git a/src/pantab/_reader.py b/src/pantab/_reader.py index 59597538..ff85edc0 100644 --- a/src/pantab/_reader.py +++ b/src/pantab/_reader.py @@ -1,12 +1,9 @@ import pathlib -import shutil -import tempfile from typing import Literal, Optional, Union import pyarrow as pa -import tableauhyperapi as tab_api -import pantab._types as pantab_types +import pantab._types as pt_types import pantab.libpantab as libpantab @@ -43,13 +40,13 @@ def frame_from_hyper_query( def frame_from_hyper( source: Union[str, pathlib.Path], *, - table: pantab_types.TableNameType, + table: pt_types.TableNameType, return_type: Literal["pandas", "polars", "pyarrow"] = "pandas", process_params: Optional[dict[str, str]] = None, ): """See api.rst for documentation""" - if isinstance(table, (str, tab_api.Name)) or not table.schema_name: - table = tab_api.TableName("public", table) + if "." not in str(table): + table = f'"public".{table}' query = f"SELECT * FROM {table}" return frame_from_hyper_query( @@ -65,17 +62,7 @@ def frames_from_hyper( """See api.rst for documentation.""" result = {} - table_names = [] - with tempfile.TemporaryDirectory() as tmp_dir, tab_api.HyperProcess( - tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, - parameters={"log_config": ""}, - ) as hpe: - tmp_db = shutil.copy(source, tmp_dir) - with tab_api.Connection(hpe.endpoint, tmp_db) as connection: - for schema in connection.catalog.get_schema_names(): - for table in connection.catalog.get_table_names(schema=schema): - table_names.append(table) - + table_names = libpantab.get_table_names(str(source)) for table in table_names: result[table] = frame_from_hyper( source=source, diff --git a/src/pantab/_types.py b/src/pantab/_types.py index 102fae12..fb10e42b 100644 --- a/src/pantab/_types.py +++ b/src/pantab/_types.py @@ -1,5 +1,22 @@ -from typing import Union +from typing import Optional, Protocol, Union, runtime_checkable -import tableauhyperapi as tab_api -TableNameType = Union[str, tab_api.Name, tab_api.TableName] +@runtime_checkable +class TableauName(Protocol): + @property + def unescaped(self) -> str: + ... + + +@runtime_checkable +class TableauTableName(Protocol): + @property + def name(self) -> TableauName: + ... + + @property + def schema_name(self) -> Optional[TableauName]: + ... + + +TableNameType = Union[str, TableauName, TableauTableName] diff --git a/src/pantab/_writer.py b/src/pantab/_writer.py index a44bb703..ce009b6e 100644 --- a/src/pantab/_writer.py +++ b/src/pantab/_writer.py @@ -4,9 +4,7 @@ import uuid from typing import Any, Literal, Optional, Union -import tableauhyperapi as tab_api - -import pantab._types as pantab_types +import pantab._types as pt_types import pantab.libpantab as libpantab @@ -52,7 +50,7 @@ def frame_to_hyper( df, database: Union[str, pathlib.Path], *, - table: pantab_types.TableNameType, + table: pt_types.TableNameType, table_mode: Literal["a", "w"] = "w", not_null_columns: Optional[set[str]] = None, json_columns: Optional[set[str]] = None, @@ -72,7 +70,7 @@ def frame_to_hyper( def frames_to_hyper( - dict_of_frames: dict[pantab_types.TableNameType, Any], + dict_of_frames: dict[pt_types.TableNameType, Any], database: Union[str, pathlib.Path], *, table_mode: Literal["a", "w"] = "w", @@ -98,12 +96,17 @@ def frames_to_hyper( if table_mode == "a" and pathlib.Path(database).exists(): shutil.copy(database, tmp_db) - def convert_to_table_name(table: pantab_types.TableNameType): + def convert_to_table_name(table: pt_types.TableNameType): # nanobind expects a tuple of (schema, table) strings - if isinstance(table, (str, tab_api.Name)) or not table.schema_name: - table = tab_api.TableName("public", table) - - return (table.schema_name.name.unescaped, table.name.unescaped) + if isinstance(table, pt_types.TableauTableName): + if table.schema_name: + return (table.schema_name.name.unescaped, table.name.unescaped) + else: + return ("public", table.name.unescaped) + elif isinstance(table, pt_types.TableauName): + return ("public", table.unescaped) + + return ("public", table) data = { convert_to_table_name(key): _get_capsule_from_obj(val) diff --git a/src/pantab/libpantab.cpp b/src/pantab/libpantab.cpp index 5f31b8cd..b9825c34 100644 --- a/src/pantab/libpantab.cpp +++ b/src/pantab/libpantab.cpp @@ -1,17 +1,47 @@ +#include + +#include #include +#include -#include "datetime.h" #include "reader.hpp" #include "writer.hpp" namespace nb = nanobind; NB_MODULE(libpantab, m) { // NOLINT - m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), - nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"), - nb::arg("json_columns"), nb::arg("geo_columns"), - nb::arg("process_params")) + m.def( + "get_table_names", + [](const std::string &path) { + std::unordered_map params{ + {"log_config", ""}}; + const hyperapi::HyperProcess hyper{ + hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", + std::move(params)}; + hyperapi::Connection connection(hyper.getEndpoint(), path); + + std::vector result; + for (const auto &schema_name : + connection.getCatalog().getSchemaNames()) { + for (const auto &table_name : + connection.getCatalog().getTableNames(schema_name)) { + const auto schema_prefix = table_name.getSchemaName(); + if (schema_prefix) { + result.emplace_back(schema_prefix->getName().getUnescaped() + + "." + table_name.getName().getUnescaped()); + } else { + result.emplace_back(table_name.getName().getUnescaped()); + } + } + } + + return result; + }, + nb::arg("path")) + .def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), + nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"), + nb::arg("json_columns"), nb::arg("geo_columns"), + nb::arg("process_params")) .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"), nb::arg("query"), nb::arg("process_params")); - PyDateTime_IMPORT; } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 4c61ad7d..7b9e9203 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -73,7 +73,12 @@ def test_multiple_tables( if not isinstance(table_name, TableName) or table_name.schema_name is None: table_name = TableName("public", table_name) - assert set(result.keys()) == set((table_name, TableName("public", "table2"))) + assert set(result.keys()) == set( + ( + ".".join(table_name._unescaped_components), + ".".join(TableName("public", "table2")._unescaped_components), + ) + ) for val in result.values(): compat.assert_frame_equal(val, expected)