Skip to content

Commit

Permalink
Remove runtime dependency on tableauhyperapi
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Sep 20, 2024
1 parent 19f2237 commit 177606e
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 40 deletions.
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dependencies = [
# in the future we need not require pyarrow as pandas implements the
# PyCapsule interface. See pandas PR #56587
"pyarrow>=14.0.0",
"tableauhyperapi>=0.0.19691",
]

[project.urls]
Expand Down Expand Up @@ -68,7 +67,13 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64"
skip = "*musllinux*"

test-command = "python -m pytest {project}/tests"
test-requires = ["pytest", "pandas>=2.0.0", "polars~=1.2.0", "narwhals"]
test-requires = [
"pytest",
"pandas>=2.0.0",
"polars~=1.2.0",
"narwhals",
"tableauhyperapi",
]

[tool.ruff]
line-length = 88
Expand Down
23 changes: 5 additions & 18 deletions src/pantab/_reader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import pathlib
import shutil
import tempfile
from typing import Literal, Optional, Union

import pyarrow as pa
import tableauhyperapi as tab_api

import pantab._types as pantab_types
import pantab._types as pt_types
import pantab.libpantab as libpantab


Expand Down Expand Up @@ -43,13 +40,13 @@ def frame_from_hyper_query(
def frame_from_hyper(
source: Union[str, pathlib.Path],
*,
table: pantab_types.TableNameType,
table: pt_types.TableNameType,
return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
process_params: Optional[dict[str, str]] = None,
):
"""See api.rst for documentation"""
if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
table = tab_api.TableName("public", table)
if "." not in str(table):
table = f'"public".{table}'

query = f"SELECT * FROM {table}"
return frame_from_hyper_query(
Expand All @@ -65,17 +62,7 @@ def frames_from_hyper(
"""See api.rst for documentation."""
result = {}

table_names = []
with tempfile.TemporaryDirectory() as tmp_dir, tab_api.HyperProcess(
tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
parameters={"log_config": ""},
) as hpe:
tmp_db = shutil.copy(source, tmp_dir)
with tab_api.Connection(hpe.endpoint, tmp_db) as connection:
for schema in connection.catalog.get_schema_names():
for table in connection.catalog.get_table_names(schema=schema):
table_names.append(table)

table_names = libpantab.get_table_names(str(source))
for table in table_names:
result[table] = frame_from_hyper(
source=source,
Expand Down
23 changes: 20 additions & 3 deletions src/pantab/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
from typing import Union
from typing import Optional, Protocol, Union, runtime_checkable

import tableauhyperapi as tab_api

TableNameType = Union[str, tab_api.Name, tab_api.TableName]
@runtime_checkable
class TableauName(Protocol):
    """Structural stand-in for a Tableau Hyper API ``Name`` object.

    Any object exposing an ``unescaped`` attribute satisfies this protocol,
    which lets pantab accept such objects without importing tableauhyperapi.
    """

    @property
    def unescaped(self) -> str:
        """Return the identifier as a plain string."""
        ...


# Deliberately not ``runtime_checkable``: it only exists to type
# ``TableauTableName.schema_name``. Any object with a ``name`` attribute
# (including a table name) would satisfy it structurally, so an
# ``isinstance`` check against it would not be meaningful.
class TableauSchemaName(Protocol):
    """Structural stand-in for the schema part of a qualified table name."""

    @property
    def name(self) -> TableauName:
        """Return the schema's own name."""
        ...


@runtime_checkable
class TableauTableName(Protocol):
    """Structural stand-in for a Tableau Hyper API ``TableName`` object.

    Any object exposing both ``name`` and ``schema_name`` attributes
    satisfies this protocol.
    """

    @property
    def name(self) -> TableauName:
        """Return the table's own name."""
        ...

    @property
    def schema_name(self) -> Optional[TableauSchemaName]:
        """Return the schema qualifier, or ``None`` when unqualified.

        The returned object exposes the schema's name via ``.name``; callers
        read ``schema_name.name.unescaped`` to obtain the schema string.
        """
        ...


TableNameType = Union[str, TableauName, TableauTableName]
23 changes: 13 additions & 10 deletions src/pantab/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import uuid
from typing import Any, Literal, Optional, Union

import tableauhyperapi as tab_api

import pantab._types as pantab_types
import pantab._types as pt_types
import pantab.libpantab as libpantab


Expand Down Expand Up @@ -52,7 +50,7 @@ def frame_to_hyper(
df,
database: Union[str, pathlib.Path],
*,
table: pantab_types.TableNameType,
table: pt_types.TableNameType,
table_mode: Literal["a", "w"] = "w",
not_null_columns: Optional[set[str]] = None,
json_columns: Optional[set[str]] = None,
Expand All @@ -72,7 +70,7 @@ def frame_to_hyper(


def frames_to_hyper(
dict_of_frames: dict[pantab_types.TableNameType, Any],
dict_of_frames: dict[pt_types.TableNameType, Any],
database: Union[str, pathlib.Path],
*,
table_mode: Literal["a", "w"] = "w",
Expand All @@ -98,12 +96,17 @@ def frames_to_hyper(
if table_mode == "a" and pathlib.Path(database).exists():
shutil.copy(database, tmp_db)

def convert_to_table_name(table: pantab_types.TableNameType):
def convert_to_table_name(table: pt_types.TableNameType):
# nanobind expects a tuple of (schema, table) strings
if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
table = tab_api.TableName("public", table)

return (table.schema_name.name.unescaped, table.name.unescaped)
if isinstance(table, pt_types.TableauTableName):
if table.schema_name:
return (table.schema_name.name.unescaped, table.name.unescaped)
else:
return ("public", table.name.unescaped)
elif isinstance(table, pt_types.TableauName):
return ("public", table.unescaped)

return ("public", table)

data = {
convert_to_table_name(key): _get_capsule_from_obj(val)
Expand Down
42 changes: 36 additions & 6 deletions src/pantab/libpantab.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
#include <unordered_map>

#include <hyperapi/hyperapi.hpp>
#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>

#include "datetime.h"
#include "reader.hpp"
#include "writer.hpp"

namespace nb = nanobind;

NB_MODULE(libpantab, m) { // NOLINT
m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"),
nb::arg("json_columns"), nb::arg("geo_columns"),
nb::arg("process_params"))
m.def(
"get_table_names",
[](const std::string &path) {
std::unordered_map<std::string, std::string> params{
{"log_config", ""}};
const hyperapi::HyperProcess hyper{
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "",
std::move(params)};
hyperapi::Connection connection(hyper.getEndpoint(), path);

std::vector<std::string> result;
for (const auto &schema_name :
connection.getCatalog().getSchemaNames()) {
for (const auto &table_name :
connection.getCatalog().getTableNames(schema_name)) {
const auto schema_prefix = table_name.getSchemaName();
if (schema_prefix) {
result.emplace_back(schema_prefix->getName().getUnescaped() +
"." + table_name.getName().getUnescaped());
} else {
result.emplace_back(table_name.getName().getUnescaped());
}
}
}

return result;
},
nb::arg("path"))
.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"),
nb::arg("json_columns"), nb::arg("geo_columns"),
nb::arg("process_params"))
.def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
nb::arg("query"), nb::arg("process_params"));
PyDateTime_IMPORT;
}
7 changes: 6 additions & 1 deletion tests/test_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def test_multiple_tables(
if not isinstance(table_name, TableName) or table_name.schema_name is None:
table_name = TableName("public", table_name)

assert set(result.keys()) == set((table_name, TableName("public", "table2")))
assert set(result.keys()) == set(
(
".".join(table_name._unescaped_components),
".".join(TableName("public", "table2")._unescaped_components),
)
)
for val in result.values():
compat.assert_frame_equal(val, expected)

Expand Down

0 comments on commit 177606e

Please sign in to comment.