diff --git a/pantab/_writer.py b/pantab/_writer.py index f6b0b296..9dd49034 100644 --- a/pantab/_writer.py +++ b/pantab/_writer.py @@ -2,27 +2,49 @@ import shutil import tempfile import uuid -from typing import Optional, Union +from typing import Any, Literal, Optional, Union -import pandas as pd -import pyarrow as pa import tableauhyperapi as tab_api import pantab._types as pantab_types import pantab.src.pantab as libpantab # type: ignore -def _validate_table_mode(table_mode: str) -> None: +def _validate_table_mode(table_mode: Literal["a", "w"]) -> None: if table_mode not in {"a", "w"}: raise ValueError("'table_mode' must be either 'w' or 'a'") +def _get_capsule_from_obj(obj): + """Returns the Arrow capsule underlying obj""" + # Check first for the Arrow C Data Interface compliance + if hasattr(obj, "__arrow_c_stream__"): + return obj.__arrow_c_stream__() + + # pandas < 3.0 did not have the Arrow C Data Interface, so + # convert to PyArrow + try: + import pandas as pd + import pyarrow as pa + + if isinstance(obj, pd.DataFrame): + return pa.Table.from_pandas(obj).__arrow_c_stream__() + except ModuleNotFoundError: + pass + + # More introspection could happen in the future...but end with TypeError if we + # can not find what we are looking for + raise TypeError( + f"Could not convert object of type '{type(obj)}' to Arrow C Data Interface" + ) + + def frame_to_hyper( - df: pd.DataFrame, + df, database: Union[str, pathlib.Path], *, table: pantab_types.TableType, - table_mode: str = "w", + table_mode: Literal["a", "w"] = "w", ) -> None: """See api.rst for documentation""" frames_to_hyper( @@ -33,9 +55,9 @@ def frame_to_hyper( def frames_to_hyper( - dict_of_frames: dict[pantab_types.TableType, pd.DataFrame], + dict_of_frames: dict[pantab_types.TableType, Any], database: Union[str, pathlib.Path], - table_mode: str = "w", + table_mode: Literal["a", "w"] = "w", *, hyper_process: Optional[tab_api.HyperProcess] = None, ) -> None: @@ -55,7 +77,7 @@ def convert_to_table_name(table: pantab_types.TableType): return (table.schema_name.name.unescaped, table.name.unescaped) data = { - convert_to_table_name(key): pa.Table.from_pandas(val) + convert_to_table_name(key): _get_capsule_from_obj(val) for key, val in dict_of_frames.items() } libpantab.write_to_hyper(data, path=str(tmp_db), table_mode=table_mode) diff --git a/pantab/src/pantab.cpp b/pantab/src/pantab.cpp index d3ce895c..40cf872d 100644 --- a/pantab/src/pantab.cpp +++ b/pantab/src/pantab.cpp @@ -418,7 +418,7 @@ void assertColumnsEqual( using SchemaAndTableName = std::tuple; void write_to_hyper( - const std::map &dict_of_exportable, + const std::map &dict_of_capsules, const std::string &path, const std::string &table_mode) { const hyperapi::HyperProcess hyper{ hyperapi::Telemetry::DoNotSendUsageDataToTableau}; @@ -432,17 +432,12 @@ void write_to_hyper( hyperapi::Connection connection{hyper.getEndpoint(), path, createMode}; const hyperapi::Catalog &catalog = connection.getCatalog(); - for (auto const &[schema_and_table, exportable] : dict_of_exportable) { + for (auto const &[schema_and_table, capsule] : dict_of_capsules) { const auto hyper_schema = std::get<0>(schema_and_table); const auto hyper_table = std::get<1>(schema_and_table); - const auto arrow_c_stream = nb::getattr(exportable, "__arrow_c_stream__")(); - PyObject *obj = arrow_c_stream.ptr(); - if (!PyCapsule_CheckExact(obj)) { - throw std::invalid_argument("Object does not provide capsule"); - } const auto c_stream = static_cast( - PyCapsule_GetPointer(obj, "arrow_array_stream")); + PyCapsule_GetPointer(capsule.ptr(), "arrow_array_stream")); auto stream = nanoarrow::UniqueArrayStream{c_stream}; struct ArrowSchema schema; @@ -876,7 +871,7 @@ auto read_from_hyper_query(const std::string &path, const std::string &query) } NB_MODULE(pantab, m) { // NOLINT - m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_exportable"), + m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), nb::arg("path"), nb::arg("table_mode")) .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"), nb::arg("query")); diff --git a/pantab/tests/conftest.py b/pantab/tests/conftest.py index f5362a6f..b42c21a6 100644 --- a/pantab/tests/conftest.py +++ b/pantab/tests/conftest.py @@ -1,3 +1,4 @@ +import datetime import pathlib import numpy as np @@ -7,7 +8,87 @@ import tableauhyperapi as tab_api -def get_basic_dataframe(): +def basic_arrow_table(): + tbl = pa.Table.from_arrays( + [ + pa.array([1, 6, 0], type=pa.int16()), + pa.array([2, 7, 0], type=pa.int32()), + pa.array([3, 8, 0], type=pa.int64()), + pa.array([1, None, None], type=pa.int16()), + pa.array([2, None, None], type=pa.int32()), + pa.array([3, None, None], type=pa.int64()), + pa.array([4, 9.0, None], type=pa.float32()), + pa.array([5, 10.0, None], type=pa.float64()), + pa.array([1.0, 1.0, None], type=pa.float32()), + pa.array([2.0, 2.0, None], type=pa.float64()), + pa.array([True, False, False], type=pa.bool_()), + pa.array([True, False, None], type=pa.bool_()), + pa.array( + [datetime.date(2024, 1, 1), datetime.date(2024, 1, 1), None], + type=pa.date32(), + ), + pa.array( + [ + datetime.datetime(2018, 1, 1, 0, 0, 0), + datetime.datetime(2019, 1, 1, 0, 0, 0), + None, + ], + type=pa.timestamp("us"), + ), + pa.array( + [ + datetime.datetime(2018, 1, 1, 0, 0, 0), + datetime.datetime(2019, 1, 1, 0, 0, 0), + None, + ], + type=pa.timestamp("us", "utc"), + ), + pa.array(["foo", "bar", None], type=pa.large_string()), + pa.array(["foo", "bar", None], type=pa.string()), + pa.array([-(2**15), 2**15 - 1, 0], type=pa.int16()), + pa.array([-(2**31), 2**31 - 1, 0], type=pa.int32()), + pa.array([-(2**63), 2**63 - 1, 0], type=pa.int64()), + pa.array([-(2**24), 2**24 - 1, None], type=pa.float32()), + pa.array([-(2**53), 2**53 - 1, None], type=pa.float64()), + pa.array( + ["\xef\xff\xdc\xde\xee", "\xfa\xfb\xdd\xaf\xaa", None], type=pa.utf8() + ), + pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None], type=pa.binary()), + pa.array([234, 42, None], type=pa.time64("us")), + ], + names=[ + "int16", + "int32", + "int64", + "Int16", + "Int32", + "Int64", + "float32", + "float64", + "Float32", + "Float64", + "bool", + "boolean", + "date32", + "datetime64", + "datetime64_utc", + "object", + "string", + "int16_limits", + "int32_limits", + "int64_limits", + "float32_limits", + "float64_limits", + "non-ascii", + "binary", + "time64us", + ], + ) + + return tbl + + +def basic_dataframe(): df = pd.DataFrame( [ [ @@ -49,7 +130,7 @@ def get_basic_dataframe(): False, False, pd.to_datetime("2024-01-01"), - pd.to_datetime("1/1/19"), + pd.to_datetime("2019-01-01"), pd.to_datetime("2019-01-01", utc=True), "bar", "bar", @@ -144,22 +225,24 @@ def get_basic_dataframe(): # See pandas GH issue #56994 df["binary"] = pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None], type=pa.binary()) df["binary"] = df["binary"].astype("binary[pyarrow]") - df["time64us"] = pd.DataFrame({"col": pa.array([234, 42], type=pa.time64("us"))}) + df["time64us"] = pd.DataFrame( + {"col": pa.array([234, 42, None], type=pa.time64("us"))} + ) df["time64us"] = df["time64us"].astype("time64[us][pyarrow]") return df -@pytest.fixture -def df(): +@pytest.fixture(params=[basic_arrow_table, basic_dataframe]) +def frame(request): """Fixture to use which should contain all data types.""" - return get_basic_dataframe() + return request.param() @pytest.fixture def roundtripped(): """Roundtripped DataFrames should use arrow dtypes by default""" - df = get_basic_dataframe() + df = basic_dataframe() df = df.astype( { "int16": "int16[pyarrow]", diff --git a/pantab/tests/test_reader.py b/pantab/tests/test_reader.py index 678c2b19..009410e4 100644 --- a/pantab/tests/test_reader.py +++ b/pantab/tests/test_reader.py @@ -2,13 +2,14 @@ import pandas as pd import pandas.testing as tm +import pytest import tableauhyperapi as tab_api import pantab -def test_read_doesnt_modify_existing_file(df, tmp_hyper): - pantab.frame_to_hyper(df, tmp_hyper, table="test") +def test_read_doesnt_modify_existing_file(frame, tmp_hyper): + pantab.frame_to_hyper(frame, tmp_hyper, table="test") last_modified = tmp_hyper.stat().st_mtime # Try out our read methods @@ -51,8 +52,8 @@ def test_reads_non_writeable(datapath): tm.assert_frame_equal(result, expected) -def test_read_query(df, tmp_hyper): - pantab.frame_to_hyper(df, tmp_hyper, table="test") +def test_read_query(frame, tmp_hyper): + pantab.frame_to_hyper(frame, tmp_hyper, table="test") query = "SELECT int16 AS i, '_' || int32 AS _i2 FROM test" result = pantab.frame_from_hyper_query(tmp_hyper, query) @@ -63,15 +64,18 @@ def test_read_query(df, tmp_hyper): tm.assert_frame_equal(result, expected) -def test_empty_read_query(df: pd.DataFrame, roundtripped, tmp_hyper): +def test_empty_read_query(frame, roundtripped, tmp_hyper): """ red-green for empty query results """ # sql cols need to base case insensitive & unique table_name = "test" - pantab.frame_to_hyper(df, tmp_hyper, table=table_name) + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name) query = f"SELECT * FROM {table_name} limit 0" - expected = pd.DataFrame(columns=df.columns) + + if not isinstance(frame, pd.DataFrame): + pytest.skip("Need to implement this test properly for pyarrow") + expected = pd.DataFrame(columns=frame.columns) expected = expected.astype(roundtripped.dtypes) result = pantab.frame_from_hyper_query(tmp_hyper, query) diff --git a/pantab/tests/test_roundtrip.py b/pantab/tests/test_roundtrip.py index 4f12c2f2..62f513b1 100644 --- a/pantab/tests/test_roundtrip.py +++ b/pantab/tests/test_roundtrip.py @@ -1,14 +1,15 @@ import pandas as pd import pandas.testing as tm +import pyarrow as pa from tableauhyperapi import TableName import pantab -def test_basic(df, roundtripped, tmp_hyper, table_name, table_mode): +def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode): # Write twice; depending on mode this should either overwrite or duplicate entries - pantab.frame_to_hyper(df, tmp_hyper, table=table_name, table_mode=table_mode) - pantab.frame_to_hyper(df, tmp_hyper, table=table_name, table_mode=table_mode) + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode) + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode) result = pantab.frame_from_hyper(tmp_hyper, table=table_name) expected = roundtripped @@ -18,13 +19,13 @@ def test_basic(df, roundtripped, tmp_hyper, table_name, table_mode): tm.assert_frame_equal(result, expected) -def test_multiple_tables(df, roundtripped, tmp_hyper, table_name, table_mode): +def test_multiple_tables(frame, roundtripped, tmp_hyper, table_name, table_mode): # Write twice; depending on mode this should either overwrite or duplicate entries pantab.frames_to_hyper( - {table_name: df, "table2": df}, tmp_hyper, table_mode=table_mode + {table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode ) pantab.frames_to_hyper( - {table_name: df, "table2": df}, tmp_hyper, table_mode=table_mode + {table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode ) result = pantab.frames_from_hyper(tmp_hyper) @@ -41,10 +42,19 @@ def test_multiple_tables(df, roundtripped, tmp_hyper, table_name, table_mode): tm.assert_frame_equal(val, expected) -def test_empty_roundtrip(df, roundtripped, tmp_hyper, table_name, table_mode): +def test_empty_roundtrip(frame, roundtripped, tmp_hyper, table_name, table_mode): # object case is by definition vague, so lets punt that for now - df = df.drop(columns=["object"]) - empty = df.iloc[:0] + if isinstance(frame, pd.DataFrame): + frame = frame.drop(columns=["object"]) + elif isinstance(frame, pa.Table): + frame = frame.drop_columns(["object"]) + else: + raise NotImplementedError("test not implemented for object") + + if isinstance(frame, pa.Table): + empty = frame.schema.empty_table() + else: + empty = frame.iloc[:0] pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode) pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode) result = pantab.frame_from_hyper(tmp_hyper, table=table_name) diff --git a/pantab/tests/test_writer.py b/pantab/tests/test_writer.py index f6a1c2c4..48f9b98a 100644 --- a/pantab/tests/test_writer.py +++ b/pantab/tests/test_writer.py @@ -2,53 +2,69 @@ from datetime import datetime, timezone import pandas as pd +import pyarrow as pa import pytest from tableauhyperapi import Connection, CreateMode, HyperProcess, Telemetry import pantab -def test_bad_table_mode_raises(df, tmp_hyper): +def test_bad_table_mode_raises(frame, tmp_hyper): msg = "'table_mode' must be either 'w' or 'a'" with pytest.raises(ValueError, match=msg): pantab.frame_to_hyper( - df, + frame, tmp_hyper, table="test", table_mode="x", ) with pytest.raises(ValueError, match=msg): - pantab.frames_to_hyper({"a": df}, tmp_hyper, table_mode="x") + pantab.frames_to_hyper({"a": frame}, tmp_hyper, table_mode="x") @pytest.mark.parametrize( - "new_dtype,hyper_type_name", [("int64", "BIGINT"), (float, "DOUBLE PRECISION")] + "new_dtype,hyper_type_name", [("int64", "BIGINT"), ("float", "DOUBLE PRECISION")] ) def test_append_mode_raises_column_dtype_mismatch( - new_dtype, hyper_type_name, df, tmp_hyper, table_name + new_dtype, hyper_type_name, frame, tmp_hyper, table_name ): - pantab.frame_to_hyper(df, tmp_hyper, table=table_name) - - df["int16"] = df["int16"].astype(new_dtype) + if isinstance(frame, pd.DataFrame): + frame = frame[["int16"]].copy() + else: + frame = frame.select(["int16"]) + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name) + + if isinstance(frame, pd.DataFrame): + frame["int16"] = frame["int16"].astype(new_dtype) + elif isinstance(frame, pa.Table): + schema = pa.schema([pa.field("int16", new_dtype)]) + frame = frame.cast(schema) + else: + raise NotImplementedError("test not implemented for object") msg = f"Column type mismatch at index 0; new: {hyper_type_name} old: SMALLINT" with pytest.raises(ValueError, match=msg): - pantab.frame_to_hyper(df, tmp_hyper, table=table_name, table_mode="a") + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode="a") -def test_append_mode_raises_ncolumns_mismatch(df, tmp_hyper, table_name): - pantab.frame_to_hyper(df, tmp_hyper, table=table_name) +def test_append_mode_raises_ncolumns_mismatch(frame, tmp_hyper, table_name): + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name) - df = df.drop(columns=["int16"]) + if isinstance(frame, pd.DataFrame): + frame = frame.drop(columns=["int16"]) + elif isinstance(frame, pa.Table): + frame = frame.drop_columns(["int16"]) + else: + raise NotImplementedError("test not implemented for object") msg = "Number of columns" with pytest.raises(ValueError, match=msg): - pantab.frame_to_hyper(df, tmp_hyper, table=table_name, table_mode="a") + pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode="a") -def test_failed_write_doesnt_overwrite_file(df, tmp_hyper, monkeypatch, table_mode): +def test_failed_write_doesnt_overwrite_file(frame, tmp_hyper, monkeypatch, table_mode): pantab.frame_to_hyper( - df, + frame, tmp_hyper, table="test", table_mode=table_mode, @@ -56,51 +72,61 @@ def test_failed_write_doesnt_overwrite_file(df, tmp_hyper, monkeypatch, table_mo last_modified = tmp_hyper.stat().st_mtime # Pick a dtype we know will fail - df["should_fail"] = pd.Series([tuple((1, 2))]) + if isinstance(frame, pd.DataFrame): + frame["should_fail"] = pd.Series([list((1, 2))]) + elif isinstance(frame, pa.Table): + new_column = pa.array([[1, 2], None, None]) + frame = frame.append_column("should_fail", new_column) + else: + raise NotImplementedError("test not implemented for object") # Try out our write methods - with pytest.raises(Exception): - pantab.frame_to_hyper(df, tmp_hyper, table="test", table_mode=table_mode) - pantab.frames_to_hyper({"test": df}, tmp_hyper, table_mode=table_mode) + msg = "Unsupported Arrow type" + with pytest.raises(ValueError, match=msg): + pantab.frame_to_hyper(frame, tmp_hyper, table="test", table_mode=table_mode) + pantab.frames_to_hyper({"test": frame}, tmp_hyper, table_mode=table_mode) # Neither should not update file stats assert last_modified == tmp_hyper.stat().st_mtime def test_duplicate_columns_raises(tmp_hyper): - df = pd.DataFrame([[1, 1]], columns=[1, 1]) + frame = pd.DataFrame([[1, 1]], columns=[1, 1]) msg = r"Duplicate column names found: \[1, 1\]" with pytest.raises(ValueError, match=msg): - pantab.frame_to_hyper(df, tmp_hyper, table="foo") + pantab.frame_to_hyper(frame, tmp_hyper, table="foo") with pytest.raises(ValueError, match=msg): - pantab.frames_to_hyper({"test": df}, tmp_hyper) + pantab.frames_to_hyper({"test": frame}, tmp_hyper) def test_unsupported_dtype_raises(tmp_hyper): - df = pd.DataFrame([[pd.Timedelta("1D")]]) + frame = pd.DataFrame([[pd.Timedelta("1D")]]) msg = re.escape("Unsupported Arrow type") with pytest.raises(ValueError, match=msg): - pantab.frame_to_hyper(df, tmp_hyper, table="test") + pantab.frame_to_hyper(frame, tmp_hyper, table="test") def test_utc_bug(tmp_hyper): """ Red-Green for UTC bug """ - df = pd.DataFrame( + frame = pd.DataFrame( {"utc_time": [datetime.now(timezone.utc), pd.Timestamp("today", tz="UTC")]} ) - pantab.frame_to_hyper(df, tmp_hyper, table="exp") + pantab.frame_to_hyper(frame, tmp_hyper, table="exp") with HyperProcess(Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper: with Connection( hyper.endpoint, tmp_hyper, CreateMode.CREATE_IF_NOT_EXISTS ) as connection: resp = connection.execute_list_query("select utc_time from exp") assert all( - [actual[0].year == expected.year for actual, expected in zip(resp, df.utc_time)] + [ + actual[0].year == expected.year + for actual, expected in zip(resp, frame.utc_time) + ] ), f""" - expected: {df.utc_time} + expected: {frame.utc_time} actual: {[c[0] for c in resp]} """