Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to return more than just pandas #243

Merged
merged 5 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions pantab/_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
import shutil
import tempfile
from typing import Optional, Union
from typing import Union

import pandas as pd
import pyarrow as pa
Expand All @@ -12,22 +12,43 @@
TableType = Union[str, tab_api.Name, tab_api.TableName]


def frame_from_hyper_query(
    source: Union[str, pathlib.Path],
    query: str,
    *,
    return_type: str = "pandas",
):
    """See api.rst for documentation.

    Executes ``query`` against the Hyper file at ``source`` and returns the
    result set either as a pyarrow Table (``return_type="pyarrow"``) or as a
    pandas DataFrame backed by ArrowDtype (``return_type="pandas"``, default).
    """
    # Fail fast on an unsupported return_type instead of silently falling
    # back to pandas after the (potentially expensive) native read.
    if return_type not in ("pandas", "pyarrow"):
        raise ValueError(
            f"'return_type' must be 'pandas' or 'pyarrow', got {return_type!r}"
        )

    # Call native library to read tuples from result set; the capsule wraps
    # an Arrow C stream that we import via the (private) capsule interface.
    capsule = libpantab.read_from_hyper_query(str(source), query)
    stream = pa.RecordBatchReader._import_from_c_capsule(capsule)
    tbl = stream.read_all()

    if return_type == "pyarrow":
        return tbl

    df = tbl.to_pandas(types_mapper=pd.ArrowDtype)
    return df


def frame_from_hyper(
    source: Union[str, pathlib.Path],
    *,
    table: TableType,
    return_type: str = "pandas",
):
    """See api.rst for documentation"""
    # A plain string / Name, or a TableName without an explicit schema,
    # defaults to the "public" schema.
    if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
        table = tab_api.TableName("public", table)

    # Delegate to the query-based reader; return_type is forwarded so the
    # caller can choose between pandas and pyarrow results.
    query = f"SELECT * FROM {table}"
    return frame_from_hyper_query(source, query, return_type=return_type)


def frames_from_hyper(
source: Union[str, pathlib.Path],
) -> dict[tab_api.TableName, pd.DataFrame]:
return_type="pandas",
):
"""See api.rst for documentation."""
result: dict[TableType, pd.DataFrame] = {}

Expand All @@ -45,22 +66,7 @@ def frames_from_hyper(
result[table] = frame_from_hyper(
source=source,
table=table,
return_type=return_type,
)

return result


def frame_from_hyper_query(
    source: Union[str, pathlib.Path],
    query: str,
    *,
    hyper_process: Optional[tab_api.HyperProcess] = None,
) -> pd.DataFrame:
    """See api.rst for documentation."""
    # NOTE(review): `hyper_process` is accepted but never used in this body —
    # presumably kept for backwards API compatibility; confirm before removing.
    # Call native library to read tuples from result set
    capsule = libpantab.read_from_hyper_query(str(source), query)
    # The capsule wraps an Arrow C stream; import it via the (private)
    # PyCapsule interface and materialize the whole result set.
    stream = pa.RecordBatchReader._import_from_c_capsule(capsule)
    tbl = stream.read_all()
    # Back the resulting DataFrame's columns with ArrowDtype so the Arrow
    # types survive the conversion to pandas.
    df = tbl.to_pandas(types_mapper=pd.ArrowDtype)

    return df
Empty file removed pantab/tests/__init__.py
Empty file.
234 changes: 179 additions & 55 deletions pantab/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,86 +3,87 @@

import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
import tableauhyperapi as tab_api


def basic_arrow_table():
    """Build the canonical pyarrow Table used as test input across the suite.

    The span in this view is diff residue (deleted and added lines interleaved);
    this is the reconstructed post-merge version: the per-array ``type=``
    arguments were dropped in favor of a single explicit ``schema`` passed to
    ``Table.from_arrays`` (kept only where inference would be ambiguous, e.g.
    date32, timestamps).
    """
    schema = pa.schema(
        [
            ("int16", pa.int16()),
            ("int32", pa.int32()),
            ("int64", pa.int64()),
            ("Int16", pa.int16()),
            ("Int32", pa.int32()),
            ("Int64", pa.int64()),
            ("float32", pa.float32()),
            ("float64", pa.float64()),
            ("Float32", pa.float32()),
            ("Float64", pa.float64()),
            ("bool", pa.bool_()),
            ("boolean", pa.bool_()),
            ("date32", pa.date32()),
            ("datetime64", pa.timestamp("us")),
            ("datetime64_utc", pa.timestamp("us", "utc")),
            ("object", pa.large_string()),
            ("string", pa.string()),
            ("int16_limits", pa.int16()),
            ("int32_limits", pa.int32()),
            ("int64_limits", pa.int64()),
            ("float32_limits", pa.float32()),
            ("float64_limits", pa.float64()),
            ("non-ascii", pa.utf8()),
            ("binary", pa.binary()),
            ("time64us", pa.time64("us")),
        ]
    )
    tbl = pa.Table.from_arrays(
        [
            pa.array([1, 6, 0]),
            pa.array([2, 7, 0]),
            pa.array([3, 8, 0]),
            pa.array([1, None, None]),
            pa.array([2, None, None]),
            pa.array([3, None, None]),
            pa.array([4, 9.0, None]),
            pa.array([5, 10.0, None]),
            pa.array([1.0, 1.0, None]),
            pa.array([2.0, 2.0, None]),
            pa.array([True, False, False]),
            pa.array([True, False, None]),
            pa.array(
                [datetime.date(2024, 1, 1), datetime.date(2024, 1, 1), None],
                type=pa.date32(),
            ),
            pa.array(
                [
                    datetime.datetime(2018, 1, 1, 0, 0, 0),
                    datetime.datetime(2019, 1, 1, 0, 0, 0),
                    None,
                ],
                type=pa.timestamp("us"),
            ),
            pa.array(
                [
                    datetime.datetime(2018, 1, 1, 0, 0, 0),
                    datetime.datetime(2019, 1, 1, 0, 0, 0),
                    None,
                ],
                type=pa.timestamp("us", "utc"),
            ),
            pa.array(["foo", "bar", None]),
            pa.array(["foo", "bar", None]),
            pa.array([-(2**15), 2**15 - 1, 0]),
            pa.array([-(2**31), 2**31 - 1, 0]),
            pa.array([-(2**63), 2**63 - 1, 0]),
            pa.array([-(2**24), 2**24 - 1, None]),
            pa.array([-(2**53), 2**53 - 1, None]),
            pa.array(
                ["\xef\xff\xdc\xde\xee", "\xfa\xfb\xdd\xaf\xaa", None],
            ),
            pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None]),
            pa.array([234, 42, None]),
        ],
        schema=schema,
    )

    return tbl
Expand Down Expand Up @@ -239,8 +240,42 @@ def frame(request):
return request.param()


@pytest.fixture
def roundtripped():
def roundtripped_pyarrow():
    """Arrow table expected after roundtripping ``basic_arrow_table``.

    The expected schema is wider than the input's in places — every float
    column comes back as float64, and string/binary columns come back as
    their large_* variants.
    """
    expected_fields = [
        ("int16", pa.int16()),
        ("int32", pa.int32()),
        ("int64", pa.int64()),
        ("Int16", pa.int16()),
        ("Int32", pa.int32()),
        ("Int64", pa.int64()),
        ("float32", pa.float64()),
        ("float64", pa.float64()),
        ("Float32", pa.float64()),
        ("Float64", pa.float64()),
        ("bool", pa.bool_()),
        ("boolean", pa.bool_()),
        ("date32", pa.date32()),
        ("datetime64", pa.timestamp("us")),
        ("datetime64_utc", pa.timestamp("us", "UTC")),
        ("object", pa.large_string()),
        ("string", pa.large_string()),
        ("int16_limits", pa.int16()),
        ("int32_limits", pa.int32()),
        ("int64_limits", pa.int64()),
        ("float32_limits", pa.float64()),
        ("float64_limits", pa.float64()),
        ("non-ascii", pa.large_string()),
        ("binary", pa.large_binary()),
        ("time64us", pa.time64("us")),
    ]
    expected_schema = pa.schema(expected_fields)
    return basic_arrow_table().cast(expected_schema)


def roundtripped_pandas():
"""Roundtripped DataFrames should use arrow dtypes by default"""
df = basic_dataframe()
df = df.astype(
Expand All @@ -259,7 +294,6 @@ def roundtripped():
"boolean": "boolean[pyarrow]",
"datetime64": "timestamp[us][pyarrow]",
"datetime64_utc": "timestamp[us, UTC][pyarrow]",
# "timedelta64": "timedelta64[ns]",
"object": "large_string[pyarrow]",
"int16_limits": "int16[pyarrow]",
"int32_limits": "int32[pyarrow]",
Expand All @@ -275,6 +309,17 @@ def roundtripped():
return df


@pytest.fixture(
    params=[
        ("pandas", roundtripped_pandas),
        ("pyarrow", roundtripped_pyarrow),
    ]
)
def roundtripped(request):
    """Parametrized (return_type, expected_frame) pair for roundtrip tests."""
    return_type, frame_factory = request.param
    return (return_type, frame_factory())


@pytest.fixture
def tmp_hyper(tmp_path):
"""A temporary file name to write / read a Hyper extract from."""
Expand Down Expand Up @@ -305,3 +350,82 @@ def table_name(request):
def datapath():
    """Location of data files in test folder."""
    tests_dir = pathlib.Path(__file__).parent
    return tests_dir / "data"


class Compat:
    """Static helpers letting tests treat pandas and pyarrow frames uniformly.

    Each helper dispatches on the concrete frame type and raises
    NotImplementedError for anything it does not recognize.
    """

    @staticmethod
    def assert_frame_equal(result, expected):
        # Compare two frames of the same concrete type for equality.
        assert isinstance(result, type(expected))
        if isinstance(result, pd.DataFrame):
            tm.assert_frame_equal(result, expected)
            return
        if isinstance(result, pa.Table):
            assert result.equals(expected, check_metadata=True)
            return
        raise NotImplementedError("assert_frame_equal not implemented for type")

    @staticmethod
    def concat_frames(frame1, frame2):
        # Vertically concatenate two frames of the same concrete type.
        assert isinstance(frame1, type(frame2))
        if isinstance(frame1, pd.DataFrame):
            combined = pd.concat([frame1, frame2])
            return combined.reset_index(drop=True)
        if isinstance(frame1, pa.Table):
            return pa.concat_tables([frame1, frame2])
        raise NotImplementedError("concat_frames not implemented for type")

    @staticmethod
    def empty_like(frame):
        # Return a zero-row frame with the same columns/schema as `frame`.
        if isinstance(frame, pd.DataFrame):
            return frame.iloc[:0]
        if isinstance(frame, pa.Table):
            return frame.schema.empty_table()
        raise NotImplementedError("empty_like not implemented for type")

    @staticmethod
    def drop_columns(frame, columns):
        # Return `frame` without the named columns.
        if isinstance(frame, pd.DataFrame):
            return frame.drop(columns=columns)
        if isinstance(frame, pa.Table):
            return frame.drop_columns(columns)
        raise NotImplementedError("drop_columns not implemented for type")

    @staticmethod
    def select_columns(frame, columns):
        # Return a frame containing only the named columns.
        if isinstance(frame, pd.DataFrame):
            return frame[columns]
        if isinstance(frame, pa.Table):
            return frame.select(columns)
        raise NotImplementedError("select_columns not implemented for type")

    @staticmethod
    def cast_column_to_type(frame, column, type_):
        # Cast one column to `type_`; mutates the pandas frame in place.
        if isinstance(frame, pd.DataFrame):
            frame[column] = frame[column].astype(type_)
            return frame
        if isinstance(frame, pa.Table):
            target_schema = pa.schema([pa.field(column, type_)])
            return frame.cast(target_schema)
        raise NotImplementedError("cast_column_to_type not implemented for type")

    @staticmethod
    def add_non_writeable_column(frame):
        # Append a list-valued column that writers are expected to reject.
        if isinstance(frame, pd.DataFrame):
            frame["should_fail"] = pd.Series([list((1, 2))])
            return frame
        if isinstance(frame, pa.Table):
            new_column = pa.array([[1, 2], None, None])
            return frame.append_column("should_fail", new_column)
        raise NotImplementedError("test not implemented for object")


@pytest.fixture()
def compat():
    """Expose the Compat helper class to tests as a fixture."""
    return Compat
Loading
Loading