Polars support #245

Merged 4 commits on Jan 26, 2024
2 changes: 1 addition & 1 deletion environment.yml
@@ -10,6 +10,7 @@ dependencies:
- pandas
- pandas-stubs
- pip
- polars
- pyarrow
- python
- pytest
@@ -19,4 +20,3 @@ dependencies:
- sphinx_rtd_theme
- pip:
- asv
- tableauhyperapi>=0.0.14567
14 changes: 10 additions & 4 deletions pantab/_reader.py
@@ -3,7 +3,6 @@
import tempfile
from typing import Union

import pandas as pd
import pyarrow as pa
import tableauhyperapi as tab_api

@@ -26,9 +25,16 @@ def frame_from_hyper_query(

if return_type == "pyarrow":
return tbl
elif return_type == "polars":
import polars as pl

df = tbl.to_pandas(types_mapper=pd.ArrowDtype)
return df
return pl.from_arrow(tbl)
elif return_type == "pandas":
import pandas as pd

return tbl.to_pandas(types_mapper=pd.ArrowDtype)

raise NotImplementedError("Please choose an appropriate 'return_type' value")


def frame_from_hyper(
@@ -50,7 +56,7 @@ def frames_from_hyper(
return_type="pandas",
):
"""See api.rst for documentation."""
result: dict[TableType, pd.DataFrame] = {}
result = {}

table_names = []
with tempfile.TemporaryDirectory() as tmp_dir, tab_api.HyperProcess(
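A quick usage sketch of the new `return_type` dispatch for anyone trying the branch locally. The database path, query text, and table name are placeholders, and the argument order follows the current `frame_from_hyper_query` signature:

```python
import pantab as pt

# Placeholder database and query for illustration only
db_path = "example.hyper"
query = "SELECT * FROM my_table"

# "polars" now routes through pl.from_arrow(tbl); polars is imported lazily,
# just like pandas, so neither library is required unless requested.
df_pl = pt.frame_from_hyper_query(db_path, query, return_type="polars")

# "pyarrow" still returns the Arrow table untouched, "pandas" keeps using
# ArrowDtype-backed columns, and anything else raises NotImplementedError.
tbl = pt.frame_from_hyper_query(db_path, query, return_type="pyarrow")
```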
9 changes: 9 additions & 0 deletions pantab/_writer.py
@@ -32,6 +32,15 @@ def _get_capsule_from_obj(obj):
except ModuleNotFoundError:
pass

# see polars GH issue #12530 - PyCapsule interface not yet implemented
try:
import polars as pl

if isinstance(obj, pl.DataFrame):
return obj.to_arrow().__arrow_c_stream__()
except ModuleNotFoundError:
pass

# More introspection could happen in the future...but end with TypeError if we
# can not find what we are looking for
raise TypeError(
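Because polars does not expose the Arrow PyCapsule interface yet (polars GH issue #12530), the writer hops through pyarrow: `to_arrow()` materializes a `pyarrow.Table`, which does implement `__arrow_c_stream__` as of pyarrow 14. A standalone sketch of that fallback, separate from the actual `_get_capsule_from_obj` dispatch:

```python
import polars as pl

def capsule_from_polars(df: pl.DataFrame):
    # polars itself has no __arrow_c_stream__ yet, so convert to a
    # pyarrow.Table first and take the stream capsule from there.
    return df.to_arrow().__arrow_c_stream__()

capsule = capsule_from_polars(pl.DataFrame({"a": [1, 2, 3]}))
print(type(capsule))  # <class 'PyCapsule'>
```

Once polars grows native PyCapsule support, the extra `to_arrow()` copy (and this branch) can presumably go away.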
47 changes: 42 additions & 5 deletions pantab/src/pantab.cpp
@@ -211,6 +211,31 @@ class Date32InsertHelper : public InsertHelper {
}
};

template <enum ArrowTimeUnit TU> class TimeInsertHelper : public InsertHelper {
public:
using InsertHelper::InsertHelper;

void insertValueAtIndex(size_t idx) override {
if (ArrowArrayViewIsNull(&array_view_, idx)) {
// MSVC on cibuildwheel doesn't like this templated optional
// inserter_->add(std::optional<T>{std::nullopt});
hyperapi::internal::ValueInserter{*inserter_}.addNull();
return;
}

int64_t value = ArrowArrayViewGetIntUnsafe(&array_view_, idx);
// TODO: check for overflow in these branches
if constexpr (TU == NANOARROW_TIME_UNIT_SECOND) {
value *= 1'000'000;
} else if constexpr (TU == NANOARROW_TIME_UNIT_MILLI) {
value *= 1000;
} else if constexpr (TU == NANOARROW_TIME_UNIT_NANO) {
value /= 1000;
}
hyperapi::internal::ValueInserter{*inserter_}.addValue(value);
}
};

template <enum TimeUnit TU, bool TZAware>
class TimestampInsertHelper : public InsertHelper {
public:
@@ -367,13 +392,25 @@ static auto makeInsertHelper(std::shared_ptr<hyperapi::Inserter> inserter,
"This code block should not be hit - contact a developer");
case NANOARROW_TYPE_TIME64:
switch (schema_view.time_unit) {
// must be a smarter way to do this!
case NANOARROW_TIME_UNIT_SECOND: // untested
return std::unique_ptr<InsertHelper>(
new TimeInsertHelper<NANOARROW_TIME_UNIT_SECOND>(
inserter, chunk, schema, error, column_position));
case NANOARROW_TIME_UNIT_MILLI: // untested
return std::unique_ptr<InsertHelper>(
new TimeInsertHelper<NANOARROW_TIME_UNIT_MILLI>(
inserter, chunk, schema, error, column_position));
case NANOARROW_TIME_UNIT_MICRO:
return std::unique_ptr<InsertHelper>(new IntegralInsertHelper<int64_t>(
inserter, chunk, schema, error, column_position));
default:
throw std::invalid_argument(
"Only microsecond-precision timestamp writes are implemented!");
return std::unique_ptr<InsertHelper>(
new TimeInsertHelper<NANOARROW_TIME_UNIT_MICRO>(
inserter, chunk, schema, error, column_position));
case NANOARROW_TIME_UNIT_NANO:
return std::unique_ptr<InsertHelper>(
new TimeInsertHelper<NANOARROW_TIME_UNIT_NANO>(
inserter, chunk, schema, error, column_position));
}
break;
default:
throw std::invalid_argument("makeInsertHelper: Unsupported Arrow type: " +
std::to_string(schema_view.type));
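The new `TimeInsertHelper` rescales each time value to microseconds before handing it to the Hyper inserter (microsecond input passes through unchanged). A small Python sketch of the same arithmetic; the unit strings and function name are illustrative, and, as the C++ TODO notes, overflow is not checked:

```python
# Mirrors TimeInsertHelper: values are rescaled to microseconds.
_TO_MICROS = {
    "s": lambda v: v * 1_000_000,   # NANOARROW_TIME_UNIT_SECOND
    "ms": lambda v: v * 1_000,      # NANOARROW_TIME_UNIT_MILLI
    "us": lambda v: v,              # NANOARROW_TIME_UNIT_MICRO (passthrough)
    "ns": lambda v: v // 1_000,     # NANOARROW_TIME_UNIT_NANO (truncates)
}

def time_to_micros(value: int, unit: str) -> int:
    # Overflow on the multiplications is not checked, matching the C++ TODO.
    return _TO_MICROS[unit](value)

assert time_to_micros(1, "s") == 1_000_000
assert time_to_micros(1_234, "ns") == 1  # sub-microsecond precision is dropped
```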
39 changes: 37 additions & 2 deletions pantab/tests/conftest.py
@@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pandas.testing as tm
import polars as pl
import pyarrow as pa
import pytest
import tableauhyperapi as tab_api
@@ -26,7 +27,7 @@ def basic_arrow_table():
("boolean", pa.bool_()),
("date32", pa.date32()),
("datetime64", pa.timestamp("us")),
("datetime64_utc", pa.timestamp("us", "utc")),
("datetime64_utc", pa.timestamp("us", "UTC")),
("object", pa.large_string()),
("string", pa.string()),
("int16_limits", pa.int16()),
@@ -234,7 +235,13 @@ def basic_dataframe():
return df


@pytest.fixture(params=[basic_arrow_table, basic_dataframe])
def basic_polars_frame():
tbl = basic_arrow_table()
df = pl.from_arrow(tbl)
return df


@pytest.fixture(params=[basic_arrow_table, basic_dataframe, basic_polars_frame])
def frame(request):
"""Fixture to use which should contain all data types."""
return request.param()
@@ -309,10 +316,16 @@ def roundtripped_pandas():
return df


def roundtripped_polars():
df = basic_polars_frame()
return df


@pytest.fixture(
params=[
("pandas", roundtripped_pandas),
("pyarrow", roundtripped_pyarrow),
("polars", roundtripped_polars),
]
)
def roundtripped(request):
@@ -362,6 +375,8 @@ def assert_frame_equal(result, expected):
elif isinstance(result, pa.Table):
assert result.equals(expected, check_metadata=True)
return
elif isinstance(result, pl.DataFrame):
assert result.equals(expected)
else:
raise NotImplementedError("assert_frame_equal not implemented for type")

@@ -372,6 +387,8 @@ def concat_frames(frame1, frame2):
return pd.concat([frame1, frame2]).reset_index(drop=True)
elif isinstance(frame1, pa.Table):
return pa.concat_tables([frame1, frame2])
elif isinstance(frame1, pl.DataFrame):
return pl.concat([frame1, frame2])
else:
raise NotImplementedError("concat_frames not implemented for type")

@@ -381,6 +398,8 @@ def empty_like(frame):
return frame.iloc[:0]
elif isinstance(frame, pa.Table):
return frame.schema.empty_table()
elif isinstance(frame, pl.DataFrame):
return frame.filter(False)
else:
raise NotImplementedError("empty_like not implemented for type")

@@ -390,6 +409,8 @@ def drop_columns(frame, columns):
return frame.drop(columns=columns)
elif isinstance(frame, pa.Table):
return frame.drop_columns(columns)
elif isinstance(frame, pl.DataFrame):
return frame.drop(columns=columns)
else:
raise NotImplementedError("drop_columns not implemented for type")

@@ -399,6 +420,8 @@ def select_columns(frame, columns):
return frame[columns]
elif isinstance(frame, pa.Table):
return frame.select(columns)
elif isinstance(frame, pl.DataFrame):
return frame.select(columns)
else:
raise NotImplementedError("select_columns not implemented for type")

@@ -410,6 +433,13 @@ def cast_column_to_type(frame, column, type_):
elif isinstance(frame, pa.Table):
schema = pa.schema([pa.field(column, type_)])
return frame.cast(schema)
elif isinstance(frame, pl.DataFrame):
# hacky :-(
if type_ == "int64":
frame = frame.cast({column: pl.Int64()})
elif type_ == "float":
frame = frame.cast({column: pl.Float64()})
return frame
else:
raise NotImplementedError("cast_column_to_type not implemented for type")

@@ -422,6 +452,11 @@ def add_non_writeable_column(frame):
new_column = pa.array([[1, 2], None, None])
frame = frame.append_column("should_fail", new_column)
return frame
elif isinstance(frame, pl.DataFrame):
frame = frame.with_columns(
pl.Series(name="should_fail", values=[list((1, 2))])
)
return frame
else:
raise NotImplementedError("test not implemented for object")

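For context on how the new fixtures get exercised, a standalone roundtrip sketch with a polars frame. The `frame_to_hyper`/`frame_from_hyper` keyword names follow pantab's existing API and are assumed unchanged by this PR; the real tests parametrize over all three frame types via the fixtures above:

```python
import pantab as pt
import polars as pl

def test_polars_roundtrip_sketch(tmp_path):
    df = pl.DataFrame({"int64": [1, 2, 3], "string": ["a", "b", "c"]})
    database = tmp_path / "test.hyper"

    # Write path goes through _get_capsule_from_obj's new polars branch
    pt.frame_to_hyper(df, database, table="test")

    # Read path goes through pl.from_arrow in the reader
    result = pt.frame_from_hyper(database, table="test", return_type="polars")
    assert result.equals(df)
```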
6 changes: 2 additions & 4 deletions pyproject.toml
@@ -28,12 +28,10 @@ classifiers = [
keywords = ["tableau", "visualization", "pandas", "dataframe"]

dependencies = [
"pandas>=2.0.0",
"tableauhyperapi>=0.0.14567",
"numpy",
# in the future we need not require pyarrow as pandas implements the
# PyCapsule interface. See pandas PR #56587
"pyarrow>=14.0.0",
"tableauhyperapi>=0.0.14567",
]

[project.urls]
@@ -70,7 +68,7 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64"
skip = "*musllinux*"

test-command = "pytest --import-mode=importlib {project}/pantab/tests"
test-requires = ["pytest"]
test-requires = ["pytest", "pandas>=2.0.0", "polars", "numpy"]

[tool.ruff]
line-length = 88
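With `pandas>=2.0.0` and `numpy` dropped from the install requirements (and polars only listed for the wheel test matrix), pantab's hard runtime dependencies are just pyarrow and tableauhyperapi. A small sketch of what that means in practice; `importlib.util.find_spec` is only used here to probe the environment:

```python
import importlib.util

# pandas and polars are now imported lazily inside the reader/writer, only
# when the matching return_type or frame type is actually used.
for optional in ("pandas", "polars"):
    if importlib.util.find_spec(optional) is None:
        print(f"{optional} missing: return_type={optional!r} would raise "
              "ModuleNotFoundError, other return types still work")
    else:
        print(f"{optional} available: return_type={optional!r} is usable")
```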