Skip to content

Commit

Permalink
Pyarrow directly (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Jan 26, 2024
1 parent f509305 commit 1bb35c6
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 69 deletions.
40 changes: 31 additions & 9 deletions pantab/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,49 @@
import shutil
import tempfile
import uuid
from typing import Optional, Union
from typing import Any, Literal, Optional, Union

import pandas as pd
import pyarrow as pa
import tableauhyperapi as tab_api

import pantab._types as pantab_types
import pantab.src.pantab as libpantab # type: ignore


def _validate_table_mode(table_mode: str) -> None:
def _validate_table_mode(table_mode: Literal["a", "w"]) -> None:
if table_mode not in {"a", "w"}:
raise ValueError("'table_mode' must be either 'w' or 'a'")


def _get_capsule_from_obj(obj):
"""Returns the Arrow capsule underlying obj"""
# Check first for the Arrow C Data Interface compliance
if hasattr(obj, "__arrow_c_stream__"):
return obj.__arrow_c_stream__()

# pandas < 3.0 did not have the Arrow C Data Interface, so
# convert to PyArrow
try:
import pandas as pd
import pyarrow as pa

if isinstance(obj, pd.DataFrame):
return pa.Table.from_pandas(obj).__arrow_c_stream__()
except ModuleNotFoundError:
pass

# More introspection could happen in the future...but end with TypeError if we
# can not find what we are looking for
raise TypeError(
f"Could not convert object of type '{type(obj)}' to Arrow C Data Interface"
)


def frame_to_hyper(
df: pd.DataFrame,
df,
database: Union[str, pathlib.Path],
*,
table: pantab_types.TableType,
table_mode: str = "w",
table_mode: Literal["a", "w"] = "w",
) -> None:
"""See api.rst for documentation"""
frames_to_hyper(
Expand All @@ -33,9 +55,9 @@ def frame_to_hyper(


def frames_to_hyper(
dict_of_frames: dict[pantab_types.TableType, pd.DataFrame],
dict_of_frames: dict[pantab_types.TableType, Any],
database: Union[str, pathlib.Path],
table_mode: str = "w",
table_mode: Literal["a", "w"] = "w",
*,
hyper_process: Optional[tab_api.HyperProcess] = None,
) -> None:
Expand All @@ -55,7 +77,7 @@ def convert_to_table_name(table: pantab_types.TableType):
return (table.schema_name.name.unescaped, table.name.unescaped)

data = {
convert_to_table_name(key): pa.Table.from_pandas(val)
convert_to_table_name(key): _get_capsule_from_obj(val)
for key, val in dict_of_frames.items()
}
libpantab.write_to_hyper(data, path=str(tmp_db), table_mode=table_mode)
Expand Down
13 changes: 4 additions & 9 deletions pantab/src/pantab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ void assertColumnsEqual(
using SchemaAndTableName = std::tuple<std::string, std::string>;

void write_to_hyper(
const std::map<SchemaAndTableName, nb::object> &dict_of_exportable,
const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules,
const std::string &path, const std::string &table_mode) {
const hyperapi::HyperProcess hyper{
hyperapi::Telemetry::DoNotSendUsageDataToTableau};
Expand All @@ -432,17 +432,12 @@ void write_to_hyper(
hyperapi::Connection connection{hyper.getEndpoint(), path, createMode};
const hyperapi::Catalog &catalog = connection.getCatalog();

for (auto const &[schema_and_table, exportable] : dict_of_exportable) {
for (auto const &[schema_and_table, capsule] : dict_of_capsules) {
const auto hyper_schema = std::get<0>(schema_and_table);
const auto hyper_table = std::get<1>(schema_and_table);
const auto arrow_c_stream = nb::getattr(exportable, "__arrow_c_stream__")();

PyObject *obj = arrow_c_stream.ptr();
if (!PyCapsule_CheckExact(obj)) {
throw std::invalid_argument("Object does not provide capsule");
}
const auto c_stream = static_cast<struct ArrowArrayStream *>(
PyCapsule_GetPointer(obj, "arrow_array_stream"));
PyCapsule_GetPointer(capsule.ptr(), "arrow_array_stream"));
auto stream = nanoarrow::UniqueArrayStream{c_stream};

struct ArrowSchema schema;
Expand Down Expand Up @@ -876,7 +871,7 @@ auto read_from_hyper_query(const std::string &path, const std::string &query)
}

NB_MODULE(pantab, m) { // NOLINT
m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_exportable"),
m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"))
.def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
nb::arg("query"));
Expand Down
97 changes: 90 additions & 7 deletions pantab/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import pathlib

import numpy as np
Expand All @@ -7,7 +8,87 @@
import tableauhyperapi as tab_api


def get_basic_dataframe():
def basic_arrow_table():
    """Build a pyarrow Table containing one column per supported data type.

    Used as a fixture source so reads/writes can be exercised against every
    type pantab handles, including nullable variants and limit values.
    """
    # Pair each column name with its array so the two cannot drift apart.
    columns = [
        ("int16", pa.array([1, 6, 0], type=pa.int16())),
        ("int32", pa.array([2, 7, 0], type=pa.int32())),
        ("int64", pa.array([3, 8, 0], type=pa.int64())),
        ("Int16", pa.array([1, None, None], type=pa.int16())),
        ("Int32", pa.array([2, None, None], type=pa.int32())),
        ("Int64", pa.array([3, None, None], type=pa.int64())),
        ("float32", pa.array([4, 9.0, None], type=pa.float32())),
        ("float64", pa.array([5, 10.0, None], type=pa.float64())),
        ("Float32", pa.array([1.0, 1.0, None], type=pa.float32())),
        ("Float64", pa.array([2.0, 2.0, None], type=pa.float64())),
        ("bool", pa.array([True, False, False], type=pa.bool_())),
        ("boolean", pa.array([True, False, None], type=pa.bool_())),
        (
            "date32",
            pa.array(
                [datetime.date(2024, 1, 1), datetime.date(2024, 1, 1), None],
                type=pa.date32(),
            ),
        ),
        (
            "datetime64",
            pa.array(
                [
                    datetime.datetime(2018, 1, 1, 0, 0, 0),
                    datetime.datetime(2019, 1, 1, 0, 0, 0),
                    None,
                ],
                type=pa.timestamp("us"),
            ),
        ),
        (
            "datetime64_utc",
            pa.array(
                [
                    datetime.datetime(2018, 1, 1, 0, 0, 0),
                    datetime.datetime(2019, 1, 1, 0, 0, 0),
                    None,
                ],
                type=pa.timestamp("us", "utc"),
            ),
        ),
        ("object", pa.array(["foo", "bar", None], type=pa.large_string())),
        ("string", pa.array(["foo", "bar", None], type=pa.string())),
        ("int16_limits", pa.array([-(2**15), 2**15 - 1, 0], type=pa.int16())),
        ("int32_limits", pa.array([-(2**31), 2**31 - 1, 0], type=pa.int32())),
        ("int64_limits", pa.array([-(2**63), 2**63 - 1, 0], type=pa.int64())),
        (
            "float32_limits",
            pa.array([-(2**24), 2**24 - 1, None], type=pa.float32()),
        ),
        (
            "float64_limits",
            pa.array([-(2**53), 2**53 - 1, None], type=pa.float64()),
        ),
        (
            "non-ascii",
            pa.array(
                ["\xef\xff\xdc\xde\xee", "\xfa\xfb\xdd\xaf\xaa", None], type=pa.utf8()
            ),
        ),
        (
            "binary",
            pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None], type=pa.binary()),
        ),
        ("time64us", pa.array([234, 42, None], type=pa.time64("us"))),
    ]

    names = [name for name, _ in columns]
    arrays = [arr for _, arr in columns]
    return pa.Table.from_arrays(arrays, names=names)


def basic_dataframe():
df = pd.DataFrame(
[
[
Expand Down Expand Up @@ -49,7 +130,7 @@ def get_basic_dataframe():
False,
False,
pd.to_datetime("2024-01-01"),
pd.to_datetime("1/1/19"),
pd.to_datetime("2019-01-01"),
pd.to_datetime("2019-01-01", utc=True),
"bar",
"bar",
Expand Down Expand Up @@ -144,22 +225,24 @@ def get_basic_dataframe():
# See pandas GH issue #56994
df["binary"] = pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None], type=pa.binary())
df["binary"] = df["binary"].astype("binary[pyarrow]")
df["time64us"] = pd.DataFrame({"col": pa.array([234, 42], type=pa.time64("us"))})
df["time64us"] = pd.DataFrame(
{"col": pa.array([234, 42, None], type=pa.time64("us"))}
)
df["time64us"] = df["time64us"].astype("time64[us][pyarrow]")

return df


@pytest.fixture
def df():
@pytest.fixture(params=[basic_arrow_table, basic_dataframe])
def frame(request):
"""Fixture to use which should contain all data types."""
return get_basic_dataframe()
return request.param()


@pytest.fixture
def roundtripped():
"""Roundtripped DataFrames should use arrow dtypes by default"""
df = get_basic_dataframe()
df = basic_dataframe()
df = df.astype(
{
"int16": "int16[pyarrow]",
Expand Down
18 changes: 11 additions & 7 deletions pantab/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import pandas as pd
import pandas.testing as tm
import pytest
import tableauhyperapi as tab_api

import pantab


def test_read_doesnt_modify_existing_file(df, tmp_hyper):
pantab.frame_to_hyper(df, tmp_hyper, table="test")
def test_read_doesnt_modify_existing_file(frame, tmp_hyper):
pantab.frame_to_hyper(frame, tmp_hyper, table="test")
last_modified = tmp_hyper.stat().st_mtime

# Try out our read methods
Expand Down Expand Up @@ -51,8 +52,8 @@ def test_reads_non_writeable(datapath):
tm.assert_frame_equal(result, expected)


def test_read_query(df, tmp_hyper):
pantab.frame_to_hyper(df, tmp_hyper, table="test")
def test_read_query(frame, tmp_hyper):
pantab.frame_to_hyper(frame, tmp_hyper, table="test")

query = "SELECT int16 AS i, '_' || int32 AS _i2 FROM test"
result = pantab.frame_from_hyper_query(tmp_hyper, query)
Expand All @@ -63,15 +64,18 @@ def test_read_query(df, tmp_hyper):
tm.assert_frame_equal(result, expected)


def test_empty_read_query(df: pd.DataFrame, roundtripped, tmp_hyper):
def test_empty_read_query(frame, roundtripped, tmp_hyper):
"""
red-green for empty query results
"""
# sql cols need to base case insensitive & unique
table_name = "test"
pantab.frame_to_hyper(df, tmp_hyper, table=table_name)
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name)
query = f"SELECT * FROM {table_name} limit 0"
expected = pd.DataFrame(columns=df.columns)

if not isinstance(frame, pd.DataFrame):
pytest.skip("Need to implement this test properly for pyarrow")
expected = pd.DataFrame(columns=frame.columns)
expected = expected.astype(roundtripped.dtypes)

result = pantab.frame_from_hyper_query(tmp_hyper, query)
Expand Down
28 changes: 19 additions & 9 deletions pantab/tests/test_roundtrip.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
from tableauhyperapi import TableName

import pantab


def test_basic(df, roundtripped, tmp_hyper, table_name, table_mode):
def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode):
# Write twice; depending on mode this should either overwrite or duplicate entries
pantab.frame_to_hyper(df, tmp_hyper, table=table_name, table_mode=table_mode)
pantab.frame_to_hyper(df, tmp_hyper, table=table_name, table_mode=table_mode)
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
result = pantab.frame_from_hyper(tmp_hyper, table=table_name)

expected = roundtripped
Expand All @@ -18,13 +19,13 @@ def test_basic(df, roundtripped, tmp_hyper, table_name, table_mode):
tm.assert_frame_equal(result, expected)


def test_multiple_tables(df, roundtripped, tmp_hyper, table_name, table_mode):
def test_multiple_tables(frame, roundtripped, tmp_hyper, table_name, table_mode):
# Write twice; depending on mode this should either overwrite or duplicate entries
pantab.frames_to_hyper(
{table_name: df, "table2": df}, tmp_hyper, table_mode=table_mode
{table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
)
pantab.frames_to_hyper(
{table_name: df, "table2": df}, tmp_hyper, table_mode=table_mode
{table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
)
result = pantab.frames_from_hyper(tmp_hyper)

Expand All @@ -41,10 +42,19 @@ def test_multiple_tables(df, roundtripped, tmp_hyper, table_name, table_mode):
tm.assert_frame_equal(val, expected)


def test_empty_roundtrip(df, roundtripped, tmp_hyper, table_name, table_mode):
def test_empty_roundtrip(frame, roundtripped, tmp_hyper, table_name, table_mode):
# object case is by definition vague, so lets punt that for now
df = df.drop(columns=["object"])
empty = df.iloc[:0]
if isinstance(frame, pd.DataFrame):
frame = frame.drop(columns=["object"])
elif isinstance(frame, pa.Table):
frame = frame.drop_columns(["object"])
else:
raise NotImplementedError("test not implemented for object")

if isinstance(frame, pa.Table):
empty = frame.schema.empty_table()
else:
empty = frame.iloc[:0]
pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
result = pantab.frame_from_hyper(tmp_hyper, table=table_name)
Expand Down
Loading

0 comments on commit 1bb35c6

Please sign in to comment.