From 48f6541c8da7771a839a2776f90a28daea43ab43 Mon Sep 17 00:00:00 2001 From: William Ayd Date: Thu, 18 Jan 2024 17:48:40 -0500 Subject: [PATCH] Pythonless read (#226) --- CMakeLists.txt | 2 +- pantab/_reader.py | 40 +--- pantab/src/pantab.cpp | 365 ++++++++++++++++++++------------- pantab/tests/conftest.py | 6 +- pantab/tests/test_reader.py | 4 +- pantab/tests/test_roundtrip.py | 8 - 6 files changed, 235 insertions(+), 190 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d3a64826..b61250db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ find_package(tableauhyperapi-cxx CONFIG REQUIRED) FetchContent_Declare(nanoarrow-project GIT_REPOSITORY https://github.com/apache/arrow-nanoarrow.git - GIT_TAG apache-arrow-nanoarrow-0.3.0 + GIT_TAG b3c952a3e21c2b47df85dbede3444f852614a3e2 ) FetchContent_MakeAvailable(nanoarrow-project) diff --git a/pantab/_reader.py b/pantab/_reader.py index a85b62d7..32915724 100644 --- a/pantab/_reader.py +++ b/pantab/_reader.py @@ -4,6 +4,7 @@ from typing import Dict, Optional, Union import pandas as pd +import pyarrow as pa import tableauhyperapi as tab_api import pantab.src.pantab as libpantab # type: ignore @@ -20,25 +21,8 @@ def frame_from_hyper( if isinstance(table, (str, tab_api.Name)) or not table.schema_name: table = tab_api.TableName("public", table) - data, columns, dtypes = libpantab.read_from_hyper_table( - str(source), - table.schema_name.name.unescaped, # TODO: this probably allows injection - table.name.unescaped, - ) - df = pd.DataFrame(data, columns=columns) - dtype_map = {k: v for k, v in zip(columns, dtypes) if v != "datetime64[ns, UTC]"} - df = df.astype(dtype_map) - - tz_aware_columns = { - col for col, dtype in zip(columns, dtypes) if dtype == "datetime64[ns, UTC]" - } - for col in tz_aware_columns: - try: - df[col] = df[col].dt.tz_localize("UTC") - except AttributeError: # happens when df[col] is empty - df[col] = df[col].astype("datetime64[ns, UTC]") - - return df + query = f"SELECT * FROM {table}" + return frame_from_hyper_query(source, query) def frames_from_hyper( @@ -74,19 +58,9 @@ def frame_from_hyper_query( ) -> pd.DataFrame: """See api.rst for documentation.""" # Call native library to read tuples from result set - df = pd.DataFrame(libpantab.read_from_hyper_query(str(source), query)) - data, columns, dtypes = libpantab.read_from_hyper_query(str(source), query) - df = pd.DataFrame(data, columns=columns) - dtype_map = {k: v for k, v in zip(columns, dtypes) if v != "datetime64[ns, UTC]"} - df = df.astype(dtype_map) - - tz_aware_columns = { - col for col, dtype in zip(columns, dtypes) if dtype == "datetime64[ns, UTC]" - } - for col in tz_aware_columns: - try: - df[col] = df[col].dt.tz_localize("UTC") - except AttributeError: # happens when df[col] is empty - df[col] = df[col].astype("datetime64[ns, UTC]") + capsule = libpantab.read_from_hyper_query(str(source), query) + stream = pa.RecordBatchReader._import_from_c_capsule(capsule) + tbl = stream.read_all() + df = tbl.to_pandas(types_mapper=pd.ArrowDtype) return df diff --git a/pantab/src/pantab.cpp b/pantab/src/pantab.cpp index ceb48d6e..628d1a87 100644 --- a/pantab/src/pantab.cpp +++ b/pantab/src/pantab.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -18,6 +19,8 @@ #include #include +#include "nanoarrow/nanoarrow.h" +#include "nanoarrow/nanoarrow_types.h" #include "numpy_datetime.h" namespace nb = nanobind; @@ -440,240 +443,316 @@ void write_to_hyper( class ReadHelper { public: - ReadHelper() = default; + ReadHelper(struct ArrowArray *array) : array_(array) {} virtual ~ReadHelper() = default; - virtual auto Read(const hyperapi::Value &) -> nb::object = 0; + virtual auto Read(const hyperapi::Value &) -> void = 0; + +protected: + struct ArrowArray *array_; }; class IntegralReadHelper : public ReadHelper { - auto Read(const hyperapi::Value &value) -> nb::object override { + using ReadHelper::ReadHelper; + + auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { - return nb::none(); + if (ArrowArrayAppendNull(array_, 1)) { + throw std::runtime_error("ArrowAppendNull failed"); + } + return; } - return nb::int_(value.get()); + if (ArrowArrayAppendInt(array_, value.get())) { + throw std::runtime_error("ArrowAppendInt failed"); + }; } }; class FloatReadHelper : public ReadHelper { - auto Read(const hyperapi::Value &value) -> nb::object override { + using ReadHelper::ReadHelper; + + auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { - return nb::none(); + if (ArrowArrayAppendNull(array_, 1)) { + throw std::runtime_error("ArrowAppendNull failed"); + } + return; } - return nb::float_(value.get()); + if (ArrowArrayAppendDouble(array_, value.get())) { + throw std::runtime_error("ArrowAppendDouble failed"); + }; } }; class BooleanReadHelper : public ReadHelper { - auto Read(const hyperapi::Value &value) -> nb::object override { - // TODO: bool support added in nanobind >= 1..9.0 - // return nb::bool_(value.get()); + using ReadHelper::ReadHelper; + + auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { - return nb::none(); + if (ArrowArrayAppendNull(array_, 1)) { + throw std::runtime_error("ArrowAppendNull failed"); + } + return; } - return nb::int_(value.get()); + if (ArrowArrayAppendInt(array_, value.get())) { + throw std::runtime_error("ArrowAppendBool failed"); + }; } }; class StringReadHelper : public ReadHelper { - auto Read(const hyperapi::Value &value) -> nb::object override { + using ReadHelper::ReadHelper; + + auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { - return nb::none(); + if (ArrowArrayAppendNull(array_, 1)) { + throw std::runtime_error("ArrowAppendNull failed"); + } + return; } - return nb::str(value.get().c_str()); + + const auto strval = value.get(); + const ArrowStringView string_view{strval.c_str(), + static_cast(strval.size())}; + + if (ArrowArrayAppendString(array_, string_view)) { + throw std::runtime_error("ArrowAppendString failed"); + }; } }; class DateReadHelper : public ReadHelper { - auto Read(const hyperapi::Value &value) -> nb::object override { + using ReadHelper::ReadHelper; + + auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { - return nb::none(); + if (ArrowArrayAppendNull(array_, 1)) { + throw std::runtime_error("ArrowAppendNull failed"); + } + return; } + // TODO: need some bounds /overflow checking + // tableau uses uint32 but we have int32 + constexpr int32_t tableau_to_unix_days = 2440588; const auto hyper_date = value.get(); - const auto year = hyper_date.getYear(); - const auto month = hyper_date.getMonth(); - const auto day = hyper_date.getDay(); + const auto raw_value = static_cast(hyper_date.getRaw()); + const auto arrow_value = raw_value - tableau_to_unix_days; - PyObject *result = PyDate_FromDate(year, month, day); - if (result == nullptr) { - throw std::invalid_argument("could not parse date"); + struct ArrowBuffer *data_buffer = ArrowArrayBuffer(array_, 1); + if (ArrowBufferAppendInt32(data_buffer, arrow_value)) { + throw std::runtime_error("Failed to append date32 value"); } - return nb::object(result, nb::detail::steal_t{}); + + struct ArrowBitmap *validity_bitmap = ArrowArrayValidityBitmap(array_); + if (ArrowBitmapAppend(validity_bitmap, true, 1)) { + throw std::runtime_error("Could not append validity buffer for date32"); + }; + array_->length++; } }; template class DatetimeReadHelper : public ReadHelper { - auto Read(const hyperapi::Value &value) -> nb::object override { + using ReadHelper::ReadHelper; + + auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { - return nb::none(); + if (ArrowArrayAppendNull(array_, 1)) { + throw std::runtime_error("ArrowAppendNull failed"); + } + return; } using timestamp_t = typename std::conditional::type; const auto hyper_ts = value.get(); - const auto hyper_date = hyper_ts.getDate(); - const auto hyper_time = hyper_ts.getTime(); - const auto year = hyper_date.getYear(); - const auto month = hyper_date.getMonth(); - const auto day = hyper_date.getDay(); - const auto hour = hyper_time.getHour(); - const auto min = hyper_time.getMinute(); - const auto sec = hyper_time.getSecond(); - const auto usec = hyper_time.getMicrosecond(); - - PyObject *result = - PyDateTime_FromDateAndTime(year, month, day, hour, min, sec, usec); - if (result == nullptr) { - throw std::invalid_argument("could not parse timestamp"); + + // TODO: need some bounds /overflow checking + // tableau uses uint64 but we have int64 + constexpr int64_t tableau_to_unix_usec = + 2440588LL * 24 * 60 * 60 * 1000 * 1000; + const auto raw_usec = static_cast(hyper_ts.getRaw()); + const auto arrow_value = raw_usec - tableau_to_unix_usec; + + struct ArrowBuffer *data_buffer = ArrowArrayBuffer(array_, 1); + if (ArrowBufferAppendInt64(data_buffer, arrow_value)) { + throw std::runtime_error("Failed to append timestamp64 value"); } - return nb::object(result, nb::detail::steal_t{}); + + struct ArrowBitmap *validity_bitmap = ArrowArrayValidityBitmap(array_); + if (ArrowBitmapAppend(validity_bitmap, true, 1)) { + throw std::runtime_error( + "Could not append validity buffer for timestamp"); + }; + array_->length++; } }; -static auto makeReadHelper(hyperapi::SqlType sqltype) +static auto makeReadHelper(const ArrowSchemaView *schema_view, + struct ArrowArray *array) -> std::unique_ptr { - if ((sqltype == hyperapi::SqlType::smallInt()) || - (sqltype == hyperapi::SqlType::integer()) || - (sqltype == hyperapi::SqlType::bigInt())) { - return std::unique_ptr(new IntegralReadHelper()); - } else if (sqltype == hyperapi::SqlType::doublePrecision()) { - return std::unique_ptr(new FloatReadHelper()); - } else if ((sqltype == hyperapi::SqlType::text())) { - return std::unique_ptr(new StringReadHelper()); - } else if (sqltype == hyperapi::SqlType::boolean()) { - return std::unique_ptr(new BooleanReadHelper()); - } else if (sqltype == hyperapi::SqlType::date()) { - return std::unique_ptr(new DateReadHelper()); - } else if (sqltype == hyperapi::SqlType::timestamp()) { - return std::unique_ptr(new DatetimeReadHelper()); - } else if (sqltype == hyperapi::SqlType::timestampTZ()) { - return std::unique_ptr(new DatetimeReadHelper()); + switch (schema_view->type) { + case NANOARROW_TYPE_INT16: + case NANOARROW_TYPE_INT32: + case NANOARROW_TYPE_INT64: + return std::unique_ptr(new IntegralReadHelper(array)); + case NANOARROW_TYPE_DOUBLE: + return std::unique_ptr(new FloatReadHelper(array)); + case NANOARROW_TYPE_LARGE_STRING: + return std::unique_ptr(new StringReadHelper(array)); + case NANOARROW_TYPE_BOOL: + return std::unique_ptr(new BooleanReadHelper(array)); + case NANOARROW_TYPE_DATE32: + return std::unique_ptr(new DateReadHelper(array)); + case NANOARROW_TYPE_TIMESTAMP: + if (strcmp("", schema_view->timezone)) { + return std::unique_ptr(new DatetimeReadHelper(array)); + } else { + return std::unique_ptr(new DatetimeReadHelper(array)); + } + default: + throw nb::type_error("unknownn arrow type provided"); } +} - throw nb::type_error(("cannot read sql type: " + sqltype.toString()).c_str()); +static auto +arrowTypeFromHyper(const hyperapi::SqlType &sqltype) -> enum ArrowType { + if (sqltype == hyperapi::SqlType::smallInt()){return NANOARROW_TYPE_INT16;} +else if (sqltype == hyperapi::SqlType::integer()) { + return NANOARROW_TYPE_INT32; +} +else if (sqltype == hyperapi::SqlType::bigInt()) { + return NANOARROW_TYPE_INT64; +} +else if (sqltype == hyperapi::SqlType::doublePrecision()) { + return NANOARROW_TYPE_DOUBLE; +} +else if (sqltype == hyperapi::SqlType::text()) { + return NANOARROW_TYPE_LARGE_STRING; +} +else if (sqltype == hyperapi::SqlType::boolean()) { + return NANOARROW_TYPE_BOOL; +} +else if (sqltype == hyperapi::SqlType::timestamp()) { + return NANOARROW_TYPE_TIMESTAMP; +} +else if (sqltype == hyperapi::SqlType::timestampTZ()) { + return NANOARROW_TYPE_TIMESTAMP; // todo: how to encode tz info? +} +else if (sqltype == hyperapi::SqlType::date()) { + return NANOARROW_TYPE_DATE32; } -static auto pandasDtypeFromHyper(const hyperapi::SqlType &sqltype) - -> std::string { - if (sqltype == hyperapi::SqlType::smallInt()) { - return "int16[pyarrow]"; - } else if (sqltype == hyperapi::SqlType::integer()) { - return "int32[pyarrow]"; - } else if (sqltype == hyperapi::SqlType::bigInt()) { - return "int64[pyarrow]"; - } else if (sqltype == hyperapi::SqlType::doublePrecision()) { - return "double[pyarrow]"; - } else if (sqltype == hyperapi::SqlType::text()) { - return "string[pyarrow]"; - } else if (sqltype == hyperapi::SqlType::boolean()) { - return "boolean[pyarrow]"; - } else if (sqltype == hyperapi::SqlType::timestamp()) { - return "timestamp[us][pyarrow]"; - } else if (sqltype == hyperapi::SqlType::timestampTZ()) { - return "timestamp[us, UTC][pyarrow]"; - } else if (sqltype == hyperapi::SqlType::date()) { - return "date32[pyarrow]"; - } +throw nb::type_error( + ("unimplemented pandas dtype for type: " + sqltype.toString()).c_str()); +} - throw nb::type_error( - ("unimplemented pandas dtype for type: " + sqltype.toString()).c_str()); +static auto releaseArrowStream(void *ptr) noexcept -> void { + auto stream = static_cast(ptr); + if (stream->release != nullptr) { + ArrowArrayStreamRelease(stream); + } } -using ColumnNames = std::vector; -using ResultBody = std::vector>; -// In a future version of pantab it would be nice to not require pandas dtypes -// However, the current reader just creates PyObjects and loses that information -// when passing back to the Python runtime; hence the explicit passing -using PandasDtypes = std::vector; /// /// read_from_hyper_query is slightly different than read_from_hyper_table /// because the former detects a schema from the hyper Result object /// which does not hold nullability information /// auto read_from_hyper_query(const std::string &path, const std::string &query) - -> std::tuple { - std::vector> result; + -> nb::capsule { hyperapi::HyperProcess hyper{ hyperapi::Telemetry::DoNotSendUsageDataToTableau}; hyperapi::Connection connection(hyper.getEndpoint(), path); - std::vector columnNames; - std::vector pandasDtypes; - std::vector> read_helpers; - hyperapi::Result hyperResult = connection.executeQuery(query); const auto resultSchema = hyperResult.getSchema(); - for (const auto &column : resultSchema.getColumns()) { - read_helpers.push_back(makeReadHelper(column.getType())); + + auto schema = std::unique_ptr{new (struct ArrowSchema)}; + + ArrowSchemaInit(schema.get()); + if (ArrowSchemaSetTypeStruct(schema.get(), resultSchema.getColumnCount())) { + throw std::runtime_error("ArrowSchemaSetTypeStruct failed"); + } + + const auto column_count = resultSchema.getColumnCount(); + for (size_t i = 0; i < column_count; i++) { + const auto column = resultSchema.getColumn(i); auto name = column.getName().getUnescaped(); - columnNames.push_back(name); + if (ArrowSchemaSetName(schema->children[i], name.c_str())) { + throw std::runtime_error("ArrowSchemaSetName failed"); + } - // the query result set does not tell us if columns are nullable or not auto const sqltype = column.getType(); - pandasDtypes.push_back(pandasDtypeFromHyper(sqltype)); - } - for (const hyperapi::Row &row : hyperResult) { - std::vector rowdata; - size_t column_idx = 0; - for (const hyperapi::Value &value : row) { - const auto &read_helper = read_helpers[column_idx]; - rowdata.push_back(read_helper->Read(value)); - column_idx++; + if (sqltype.getTag() == hyperapi::TypeTag::TimestampTZ) { + if (ArrowSchemaSetTypeDateTime(schema->children[i], + NANOARROW_TYPE_TIMESTAMP, + NANOARROW_TIME_UNIT_MICRO, "UTC")) { + throw std::runtime_error("ArrowSchemaSetDateTime failed"); + } + } else if (sqltype.getTag() == hyperapi::TypeTag::Timestamp) { + if (ArrowSchemaSetTypeDateTime(schema->children[i], + NANOARROW_TYPE_TIMESTAMP, + NANOARROW_TIME_UNIT_MICRO, nullptr)) { + throw std::runtime_error("ArrowSchemaSetDateTime failed"); + } + } else { + const enum ArrowType arrow_type = arrowTypeFromHyper(sqltype); + if (ArrowSchemaSetType(schema->children[i], arrow_type)) { + throw std::runtime_error("ArrowSchemaSetType failed"); + } } - result.push_back(rowdata); } - return std::make_tuple(result, columnNames, pandasDtypes); -} - -auto read_from_hyper_table(const std::string &path, const std::string &schema, - const std::string &table) - -> std::tuple { - std::vector> result; - hyperapi::HyperProcess hyper{ - hyperapi::Telemetry::DoNotSendUsageDataToTableau}; - hyperapi::Connection connection(hyper.getEndpoint(), path); - hyperapi::TableName extractTable{schema, table}; - const hyperapi::Catalog &catalog = connection.getCatalog(); - const hyperapi::TableDefinition tableDef = - catalog.getTableDefinition(extractTable); - - std::vector columnNames; - std::vector pandasDtypes; - std::vector> read_helpers; - - for (auto &column : tableDef.getColumns()) { - read_helpers.push_back(makeReadHelper(column.getType())); - auto name = column.getName().getUnescaped(); - columnNames.push_back(name); + auto array = std::unique_ptr{new (struct ArrowArray)}; + if (ArrowArrayInitFromSchema(array.get(), schema.get(), nullptr)) { + throw std::runtime_error("ArrowSchemaInitFromSchema failed"); + } + std::vector> read_helpers{column_count}; + for (size_t i = 0; i < column_count; i++) { + struct ArrowSchemaView schema_view; + if (ArrowSchemaViewInit(&schema_view, schema->children[i], nullptr)) { + throw std::runtime_error("ArrowSchemaViewInit failed"); + } - auto const sqltype = column.getType(); - pandasDtypes.push_back(pandasDtypeFromHyper(sqltype)); + auto read_helper = makeReadHelper(&schema_view, array->children[i]); + read_helpers[i] = std::move(read_helper); } - hyperapi::Result hyperResult = - connection.executeQuery("SELECT * FROM " + extractTable.toString()); + if (ArrowArrayStartAppending(array.get())) { + throw std::runtime_error("ArrowArrayStartAppending failed"); + } for (const hyperapi::Row &row : hyperResult) { - std::vector rowdata; size_t column_idx = 0; for (const hyperapi::Value &value : row) { const auto &read_helper = read_helpers[column_idx]; - rowdata.push_back(read_helper->Read(value)); + read_helper->Read(value); column_idx++; } - result.push_back(rowdata); + if (ArrowArrayFinishElement(array.get())) { + throw std::runtime_error("ArrowArrayFinishElement failed"); + } + } + if (ArrowArrayFinishBuildingDefault(array.get(), nullptr)) { + throw std::runtime_error("ArrowArrayFinishBuildingDefault failed"); + } + + auto stream = + (struct ArrowArrayStream *)malloc(sizeof(struct ArrowArrayStream)); + if (ArrowBasicArrayStreamInit(stream, schema.get(), 1)) { + free(stream); + throw std::runtime_error("ArrowBasicArrayStreamInit failed"); } + ArrowBasicArrayStreamSetArray(stream, 0, array.get()); - return std::make_tuple(result, columnNames, pandasDtypes); + nb::capsule result{stream, "arrow_array_stream", &releaseArrowStream}; + return result; } NB_MODULE(pantab, m) { // NOLINT m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_exportable"), nb::arg("path"), nb::arg("table_mode")) .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"), - nb::arg("query")) - .def("read_from_hyper_table", &read_from_hyper_table, nb::arg("path"), - nb::arg("schema"), nb::arg("table")); + nb::arg("query")); PyDateTime_IMPORT; } diff --git a/pantab/tests/conftest.py b/pantab/tests/conftest.py index e1ab9641..08c9231e 100644 --- a/pantab/tests/conftest.py +++ b/pantab/tests/conftest.py @@ -170,14 +170,14 @@ def roundtripped(): "datetime64": "timestamp[us][pyarrow]", "datetime64_utc": "timestamp[us, UTC][pyarrow]", # "timedelta64": "timedelta64[ns]", - "object": "string[pyarrow]", + "object": "large_string[pyarrow]", "int16_limits": "int16[pyarrow]", "int32_limits": "int32[pyarrow]", "int64_limits": "int64[pyarrow]", "float32_limits": "double[pyarrow]", "float64_limits": "double[pyarrow]", - "non-ascii": "string[pyarrow]", - "string": "string[pyarrow]", + "non-ascii": "large_string[pyarrow]", + "string": "large_string[pyarrow]", } ) return df diff --git a/pantab/tests/test_reader.py b/pantab/tests/test_reader.py index 466089e0..0a237bb9 100644 --- a/pantab/tests/test_reader.py +++ b/pantab/tests/test_reader.py @@ -54,7 +54,7 @@ def test_reads_non_writeable(datapath): "double[pyarrow]" ) expected["Non-Nullable String"] = expected["Non-Nullable String"].astype( - "string[pyarrow]" + "large_string[pyarrow]" ) tm.assert_frame_equal(result, expected) @@ -67,7 +67,7 @@ def test_read_query(df, tmp_hyper): result = pantab.frame_from_hyper_query(tmp_hyper, query) expected = pd.DataFrame([[1, "_2"], [6, "_7"], [0, "_0"]], columns=["i", "_i2"]) - expected = expected.astype({"i": "int16[pyarrow]", "_i2": "string[pyarrow]"}) + expected = expected.astype({"i": "int16[pyarrow]", "_i2": "large_string[pyarrow]"}) tm.assert_frame_equal(result, expected) diff --git a/pantab/tests/test_roundtrip.py b/pantab/tests/test_roundtrip.py index 32c9f065..10572715 100644 --- a/pantab/tests/test_roundtrip.py +++ b/pantab/tests/test_roundtrip.py @@ -15,10 +15,6 @@ def test_basic(df, roundtripped, tmp_hyper, table_name, table_mode): if table_mode == "a": expected = pd.concat([expected, expected]).reset_index(drop=True) - # TODO: somehow concat turns string[pyarrow] into string python - for col in ("object", "non-ascii", "string"): - expected[col] = expected[col].astype("string[pyarrow]") - tm.assert_frame_equal(result, expected) @@ -36,10 +32,6 @@ def test_multiple_tables(df, roundtripped, tmp_hyper, table_name, table_mode): if table_mode == "a": expected = pd.concat([expected, expected]).reset_index(drop=True) - # TODO: somehow concat turns string[pyarrow] into string python - for col in ("object", "non-ascii", "string"): - expected[col] = expected[col].astype("string[pyarrow]") - # some test trickery here if not isinstance(table_name, TableName) or table_name.schema_name is None: table_name = TableName("public", table_name)