Allow overrides for JSON/Geo types (#247)
WillAyd authored Jan 28, 2024
1 parent 5f02bfc commit e6d747b
Showing 5 changed files with 160 additions and 22 deletions.
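
The new json_columns and geo_columns keyword arguments let callers override the column type that pantab would otherwise infer from the Arrow schema: columns named in json_columns are written to Hyper as SqlType.json(), and columns named in geo_columns as SqlType.geography(), rather than plain TEXT/BYTES. A minimal usage sketch (the frame, file name, and values below are illustrative, adapted from this commit's test fixtures, not part of the diff itself):

import pandas as pd

import pantab

# Illustrative frame: "json" holds JSON documents as strings, "geography"
# holds Hyper's binary geography encoding (bytes copied from conftest.py).
df = pd.DataFrame(
    {
        "json": ['{"foo": 42}', '{"bar": -42}'],
        "geography": [
            b"\x07\xaa\x02\xe0%n\xd9\x01\x01\n\x00\xce\xab\xe8\xfa=\xff\x96\xf0\x8a\x9f\x01",
            b"\x07\xaa\x02\x0c&n\x82\x01\x01\n\x00\xb0\xe2\xd4\xcc>\xd4\xbc\x97\x88\x0f",
        ],
    }
)
df["geography"] = df["geography"].astype("large_binary[pyarrow]")

# Without the overrides these columns would be created as TEXT and BYTES;
# with them, the Hyper table gets JSON and GEOGRAPHY columns.
pantab.frame_to_hyper(
    df,
    "example.hyper",
    table="test",
    json_columns={"json"},
    geo_columns={"geography"},
)
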
26 changes: 21 additions & 5 deletions pantab/_writer.py
@@ -2,7 +2,7 @@
 import shutil
 import tempfile
 import uuid
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 import tableauhyperapi as tab_api
@@ -54,24 +54,33 @@ def frame_to_hyper(
     *,
     table: pantab_types.TableType,
     table_mode: Literal["a", "w"] = "w",
+    json_columns: list[str] = None,
+    geo_columns: list[str] = None,
 ) -> None:
     """See api.rst for documentation"""
     frames_to_hyper(
         {table: df},
         database,
-        table_mode,
+        table_mode=table_mode,
+        json_columns=json_columns,
+        geo_columns=geo_columns,
     )
 
 
 def frames_to_hyper(
     dict_of_frames: dict[pantab_types.TableType, Any],
     database: Union[str, pathlib.Path],
-    table_mode: Literal["a", "w"] = "w",
     *,
-    hyper_process: Optional[tab_api.HyperProcess] = None,
+    table_mode: Literal["a", "w"] = "w",
+    json_columns: set[str] = None,
+    geo_columns: set[str] = None,
 ) -> None:
     """See api.rst for documentation."""
     _validate_table_mode(table_mode)
+    if json_columns is None:
+        json_columns = set()
+    if geo_columns is None:
+        geo_columns = set()
 
     tmp_db = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.hyper"
 
@@ -89,7 +98,14 @@ def convert_to_table_name(table: pantab_types.TableType):
         convert_to_table_name(key): _get_capsule_from_obj(val)
         for key, val in dict_of_frames.items()
     }
-    libpantab.write_to_hyper(data, path=str(tmp_db), table_mode=table_mode)
+
+    libpantab.write_to_hyper(
+        data,
+        path=str(tmp_db),
+        table_mode=table_mode,
+        json_columns=json_columns,
+        geo_columns=geo_columns,
+    )
 
     # In Python 3.9+ we can just pass the path object, but due to bpo 32689
     # and subsequent typeshed changes it is easier to just pass as str for now
36 changes: 26 additions & 10 deletions pantab/src/pantab.cpp
@@ -8,6 +8,7 @@
 #include <nanoarrow/nanoarrow.hpp>
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/map.h>
+#include <nanobind/stl/set.h>
 #include <nanobind/stl/string.h>
 #include <nanobind/stl/tuple.h>
@@ -492,7 +493,9 @@ using SchemaAndTableName = std::tuple<std::string, std::string>;
 
 void write_to_hyper(
     const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules,
-    const std::string &path, const std::string &table_mode) {
+    const std::string &path, const std::string &table_mode,
+    const std::set<std::string> &json_columns,
+    const std::set<std::string> &geo_columns) {
   const std::unordered_map<std::string, std::string> params = {
       {"log_config", ""}};
   const hyperapi::HyperProcess hyper{
@@ -524,15 +527,27 @@
   struct ArrowError error;
   std::vector<hyperapi::TableDefinition::Column> hyper_columns;
   for (int64_t i = 0; i < schema.n_children; i++) {
-    const auto hypertype =
-        hyperTypeFromArrowSchema(schema.children[i], &error);
-
-    // Almost all arrow types are nullable
-    const hyperapi::TableDefinition::Column column{
-        std::string(schema.children[i]->name), hypertype,
-        hyperapi::Nullability::Nullable};
+    const auto col_name = std::string{schema.children[i]->name};
+    if (json_columns.find(col_name) != json_columns.end()) {
+      const auto hypertype = hyperapi::SqlType::json();
+      const hyperapi::TableDefinition::Column column{
+          col_name, hypertype, hyperapi::Nullability::Nullable};
+
+      hyper_columns.emplace_back(std::move(column));
+    } else if (geo_columns.find(col_name) != geo_columns.end()) {
+      const auto hypertype = hyperapi::SqlType::geography();
+      const hyperapi::TableDefinition::Column column{
+          col_name, hypertype, hyperapi::Nullability::Nullable};
+
+      hyper_columns.emplace_back(std::move(column));
+    } else {
+      const auto hypertype =
+          hyperTypeFromArrowSchema(schema.children[i], &error);
+      const hyperapi::TableDefinition::Column column{
+          col_name, hypertype, hyperapi::Nullability::Nullable};
 
-    hyper_columns.emplace_back(std::move(column));
+      hyper_columns.emplace_back(std::move(column));
+    }
   }
 
   const hyperapi::TableName table_name{hyper_schema, hyper_table};
@@ -982,7 +997,8 @@ auto read_from_hyper_query(const std::string &path, const std::string &query)
 
 NB_MODULE(pantab, m) { // NOLINT
   m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
-        nb::arg("path"), nb::arg("table_mode"))
+        nb::arg("path"), nb::arg("table_mode"), nb::arg("json_columns"),
+        nb::arg("geo_columns"))
       .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
            nb::arg("query"));
   PyDateTime_IMPORT;
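
In Python terms, the per-column typing dispatch that write_to_hyper now performs is roughly the following (a sketch for illustration only; column_sql_type is a hypothetical name, not part of the codebase):

def column_sql_type(col_name, inferred_type, json_columns, geo_columns):
    # Explicit overrides win; everything else keeps the type inferred
    # from the Arrow schema. Either way the column is declared nullable.
    if col_name in json_columns:
        return "JSON"
    if col_name in geo_columns:
        return "GEOGRAPHY"
    return inferred_type


# A text column named "json" becomes JSON only when explicitly overridden:
assert column_sql_type("json", "TEXT", {"json"}, set()) == "JSON"
assert column_sql_type("json", "TEXT", set(), set()) == "TEXT"
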
27 changes: 27 additions & 0 deletions pantab/tests/conftest.py
@@ -36,9 +36,11 @@ def basic_arrow_table():
             ("float32_limits", pa.float32()),
             ("float64_limits", pa.float64()),
             ("non-ascii", pa.utf8()),
+            ("json", pa.large_string()),
             ("binary", pa.binary()),
             ("interval", pa.month_day_nano_interval()),
             ("time64us", pa.time64("us")),
+            ("geography", pa.large_binary()),
         ]
     )
     tbl = pa.Table.from_arrays(
@@ -82,6 +84,7 @@ def basic_arrow_table():
             pa.array(
                 ["\xef\xff\xdc\xde\xee", "\xfa\xfb\xdd\xaf\xaa", None],
             ),
+            pa.array(['{"foo": 42}', '{"bar": -42}', None]),
             pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None]),
             pa.array(
                 [
@@ -91,6 +94,13 @@
                 ]
             ),
             pa.array([234, 42, None]),
+            pa.array(
+                [
+                    b"\x07\xaa\x02\xe0%n\xd9\x01\x01\n\x00\xce\xab\xe8\xfa=\xff\x96\xf0\x8a\x9f\x01",
+                    b"\x07\xaa\x02\x0c&n\x82\x01\x01\n\x00\xb0\xe2\xd4\xcc>\xd4\xbc\x97\x88\x0f",
+                    None,
+                ]
+            ),
         ],
         schema=schema,
     )
@@ -125,6 +135,7 @@ def basic_dataframe():
                 -(2**24),
                 -(2**53),
                 "\xef\xff\xdc\xde\xee",
+                '{"foo": 42}',
             ],
             [
                 6,
@@ -150,6 +161,7 @@
                 2**24 - 1,
                 2**53 - 1,
                 "\xfa\xfb\xdd\xaf\xaa",
+                '{"bar": -42}',
             ],
             [
                 0,
@@ -175,6 +187,7 @@
                 np.nan,
                 np.nan,
                 np.nan,
+                pd.NA,
             ],
         ],
         columns=[
@@ -201,6 +214,7 @@
             "float32_limits",
             "float64_limits",
             "non-ascii",
+            "json",
         ],
     )
 
@@ -229,6 +243,7 @@ def basic_dataframe():
             "float32_limits": np.float64,
             "float64_limits": np.float64,
             "non-ascii": "string",
+            "json": "string",
         }
     )
 
@@ -248,6 +263,14 @@
         {"col": pa.array([234, 42, None], type=pa.time64("us"))}
     )
     df["time64us"] = df["time64us"].astype("time64[us][pyarrow]")
+    df["geography"] = pa.array(
+        [
+            b"\x07\xaa\x02\xe0%n\xd9\x01\x01\n\x00\xce\xab\xe8\xfa=\xff\x96\xf0\x8a\x9f\x01",
+            b"\x07\xaa\x02\x0c&n\x82\x01\x01\n\x00\xb0\xe2\xd4\xcc>\xd4\xbc\x97\x88\x0f",
+            None,
+        ]
+    )
+    df["geography"] = df["geography"].astype("large_binary[pyarrow]")
 
     return df

@@ -293,9 +316,11 @@ def roundtripped_pyarrow():
             ("float32_limits", pa.float64()),
             ("float64_limits", pa.float64()),
             ("non-ascii", pa.large_string()),
+            ("json", pa.large_string()),
             ("binary", pa.large_binary()),
             ("interval", pa.month_day_nano_interval()),
             ("time64us", pa.time64("us")),
+            ("geography", pa.large_binary()),
         ]
     )
     tbl = basic_arrow_table()
@@ -329,10 +354,12 @@ def roundtripped_pandas():
             "float32_limits": "double[pyarrow]",
             "float64_limits": "double[pyarrow]",
             "non-ascii": "large_string[pyarrow]",
+            "json": "large_string[pyarrow]",
             "string": "large_string[pyarrow]",
             "binary": "large_binary[pyarrow]",
             # "interval": "month_day_nano_interval[pyarrow]",
             "time64us": "time64[us][pyarrow]",
+            "geography": "large_binary[pyarrow]",
         }
     )
     return df
48 changes: 42 additions & 6 deletions pantab/tests/test_roundtrip.py
@@ -11,8 +11,22 @@ def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
     expected = compat.drop_columns(expected, ["interval"])
 
     # Write twice; depending on mode this should either overwrite or duplicate entries
-    pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
-    pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
 
     result = pantab.frame_from_hyper(
         tmp_hyper, table=table_name, return_type=return_type
@@ -34,10 +48,18 @@ def test_multiple_tables(
 
     # Write twice; depending on mode this should either overwrite or duplicate entries
     pantab.frames_to_hyper(
-        {table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
+        {table_name: frame, "table2": frame},
+        tmp_hyper,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
     )
     pantab.frames_to_hyper(
-        {table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
+        {table_name: frame, "table2": frame},
+        tmp_hyper,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
     )
 
     result = pantab.frames_from_hyper(tmp_hyper, return_type=return_type)
@@ -65,8 +87,22 @@ def test_empty_roundtrip(
     # object case is by definition vague, so lets punt that for now
     frame = compat.drop_columns(frame, ["object"])
     empty = compat.empty_like(frame)
-    pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
-    pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
+    pantab.frame_to_hyper(
+        empty,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
+    pantab.frame_to_hyper(
+        empty,
+        tmp_hyper,
+        table=table_name,
+        table_mode=table_mode,
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
 
     result = pantab.frame_from_hyper(
         tmp_hyper, table=table_name, return_type=return_type
45 changes: 44 additions & 1 deletion pantab/tests/test_writer.py
@@ -3,7 +3,14 @@
 
 import pandas as pd
 import pytest
-from tableauhyperapi import Connection, CreateMode, HyperProcess, Telemetry
+from tableauhyperapi import (
+    Connection,
+    CreateMode,
+    HyperProcess,
+    SqlType,
+    TableName,
+    Telemetry,
+)
 
 import pantab
 
@@ -109,3 +116,39 @@ def test_utc_bug(tmp_hyper):
     expected: {frame.utc_time}
     actual: {[c[0] for c in resp]}
     """
+
+
+def test_geo_and_json_columns_writes_proper_type(tmp_hyper, frame):
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table="test",
+    )
+
+    with HyperProcess(Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper:
+        with Connection(
+            hyper.endpoint, tmp_hyper, CreateMode.CREATE_IF_NOT_EXISTS
+        ) as connection:
+            table_def = connection.catalog.get_table_definition(TableName("test"))
+            json_col = table_def.get_column_by_name("json")
+            geo_col = table_def.get_column_by_name("geography")
+            assert json_col.type == SqlType.text()
+            assert geo_col.type == SqlType.bytes()
+
+    pantab.frame_to_hyper(
+        frame,
+        tmp_hyper,
+        table="test",
+        json_columns={"json"},
+        geo_columns={"geography"},
+    )
+
+    with HyperProcess(Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper:
+        with Connection(
+            hyper.endpoint, tmp_hyper, CreateMode.CREATE_IF_NOT_EXISTS
+        ) as connection:
+            table_def = connection.catalog.get_table_definition(TableName("test"))
+            json_col = table_def.get_column_by_name("json")
+            geo_col = table_def.get_column_by_name("geography")
+            assert json_col.type == SqlType.json()
+            assert geo_col.type == SqlType.geography()
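
Note, as the roundtripped fixtures in conftest.py show, the overrides only affect the written Hyper schema: on read, a JSON column still comes back as a string column and a GEOGRAPHY column as binary. Roughly (illustrative names, continuing the write sketch near the top of this page):

result = pantab.frame_from_hyper("example.hyper", table="test")
# "json" reads back as large_string[pyarrow] and "geography" as
# large_binary[pyarrow]; the Hyper-side SQL types are not round-tripped.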
