diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19c799b1..48c45f00 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,7 +44,7 @@ find_package(tableauhyperapi-cxx CONFIG REQUIRED)
 
 FetchContent_Declare(nanoarrow-project
   GIT_REPOSITORY https://github.com/apache/arrow-nanoarrow.git
-  GIT_TAG dab87aaea4c2c05d24b745d58e50726bd0553452
+  GIT_TAG 97e7c61d95456f4753f58cc9e6742800b08b378a
 )
 FetchContent_MakeAvailable(nanoarrow-project)
 
diff --git a/pyproject.toml b/pyproject.toml
index 479feaa5..88297531 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,7 +71,7 @@ test-requires = [
     "pytest",
     "pytest-xdist[psutil]",
     "pandas>=2.0.0",
-    "polars~=1.2.0",
+    "polars",
     "narwhals",
     "tableauhyperapi",
 ]
diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp
index 2739ebeb..164226ab 100644
--- a/src/pantab/writer.cpp
+++ b/src/pantab/writer.cpp
@@ -35,9 +35,11 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema,
     return hyperapi::SqlType::boolean();
   case NANOARROW_TYPE_BINARY:
   case NANOARROW_TYPE_LARGE_BINARY:
+  case NANOARROW_TYPE_BINARY_VIEW:
     return hyperapi::SqlType::bytes();
   case NANOARROW_TYPE_STRING:
   case NANOARROW_TYPE_LARGE_STRING:
+  case NANOARROW_TYPE_STRING_VIEW:
     return hyperapi::SqlType::text();
   case NANOARROW_TYPE_DATE32:
     return hyperapi::SqlType::date();
@@ -201,6 +203,47 @@ template class Utf8InsertHelper : public InsertHelper {
   }
 };
 
+template <bool IsString> class BinaryViewInsertHelper : public InsertHelper {
+public:
+  using InsertHelper::InsertHelper;
+
+  void InsertValueAtIndex(size_t idx) override {
+    if (ArrowArrayViewIsNull(array_view_.get(), idx)) {
+      // MSVC on cibuildwheel doesn't like this templated optional
+      hyperapi::internal::ValueInserter{inserter_}.addNull();
+      return;
+    }
+
+    const union ArrowBinaryView bv =
+        array_view_->buffer_views[1].data.as_binary_view[idx];
+    struct ArrowBufferView bin_data = {{NULL}, bv.inlined.size};
+    if (bv.inlined.size <= NANOARROW_BINARY_VIEW_INLINE_SIZE) {
+      bin_data.data.as_uint8 = bv.inlined.data;
+    } else {
+      const int32_t buf_index =
+          bv.ref.buffer_index + NANOARROW_BINARY_VIEW_FIXED_BUFFERS;
+      bin_data.data.data = array_view_->array->buffers[buf_index];
+      bin_data.data.as_uint8 += bv.ref.offset;
+    }
+
+    if constexpr (IsString) {
+#if defined(_WIN32) && defined(_MSC_VER)
+      const std::string result(bin_data.data.as_char,
+                               static_cast<size_t>(bin_data.size_bytes));
+#else
+      const std::string_view result{bin_data.data.as_char,
+                                    static_cast<size_t>(bin_data.size_bytes)};
+
+#endif
+      hyperapi::internal::ValueInserter{inserter_}.addValue(result);
+    } else {
+      const hyperapi::ByteSpan result{bin_data.data.as_uint8,
+                                      static_cast<size_t>(bin_data.size_bytes)};
+      hyperapi::internal::ValueInserter{inserter_}.addValue(result);
+    }
+  }
+};
+
 class Date32InsertHelper : public InsertHelper {
 public:
   using InsertHelper::InsertHelper;
@@ -529,6 +572,12 @@ static auto MakeInsertHelper(hyperapi::Inserter &inserter,
     return std::make_unique<DecimalInsertHelper>(
         inserter, chunk, schema, error, column_position, precision, scale);
   }
+  case NANOARROW_TYPE_BINARY_VIEW:
+    return std::make_unique<BinaryViewInsertHelper<false>>(
+        inserter, chunk, schema, error, column_position);
+  case NANOARROW_TYPE_STRING_VIEW:
+    return std::make_unique<BinaryViewInsertHelper<true>>(
+        inserter, chunk, schema, error, column_position);
   default:
     throw std::invalid_argument(
         std::string("MakeInsertHelper: Unsupported Arrow type: ") +
diff --git a/tests/conftest.py b/tests/conftest.py
index f867211c..2dd78f45 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -44,6 +44,8 @@ def basic_arrow_table():
             ("time64us", pa.time64("us")),
             ("geography", pa.large_binary()),
             ("decimal", pa.decimal128(38, 10)),
+            ("string_view", pa.string_view()),
+            ("binary_view", pa.binary_view()),
         ]
     )
     tbl = pa.Table.from_arrays(
@@ -106,6 +108,8 @@ def basic_arrow_table():
                 ]
             ),
             pa.array(["1234567890.123456789", "99876543210.987654321", None]),
+            pa.array(["foo", "longer_than_prefix_size", None], type=pa.string_view()),
+            pa.array([b"foo", b"longer_than_prefix_size", None], type=pa.binary_view()),
         ],
         schema=schema,
     )
@@ -286,6 +290,14 @@ def basic_dataframe():
         ["1234567890.123456789", "99876543210.987654321", None],
         dtype=pd.ArrowDtype(pa.decimal128(38, 10)),
     )
+    """
+    df["string_view"] = pd.Series(
+        ["foo", "longer_than_prefix_size", None],
+        dtype=pd.ArrowDtype(pa.string_view())),
+    df["binary_view"] = pd.Series(
+        [b"foo", b"longer_than_prefix_size", None],
+        dtype=pd.ArrowDtype(pa.binary_view())),
+    """
 
     return df
 
@@ -354,11 +366,23 @@ def roundtripped_pyarrow():
             ("time64us", pa.time64("us")),
             ("geography", pa.large_binary()),
             ("decimal", pa.decimal128(38, 10)),
+            # ("string_view", pa.large_string()),
+            # ("binary_view", pa.large_binary()),
         ]
     )
 
     tbl = basic_arrow_table()
-    return tbl.cast(schema)
+    # pyarrow does not support casting from string_view to large_string,
+    # so we have to handle it manually
+    tbl = tbl.drop_columns(["string_view", "binary_view"])
+    tbl = tbl.cast(schema)
+
+    sv = pa.array(["foo", "longer_than_prefix_size", None], type=pa.large_string())
+    bv = pa.array([b"foo", b"longer_than_prefix_size", None], type=pa.large_binary())
+    tbl = tbl.append_column("string_view", sv)
+    tbl = tbl.append_column("binary_view", bv)
+
+    return tbl
 
 
 def roundtripped_pandas():
@@ -394,6 +418,8 @@ def roundtripped_pandas():
            # "interval": "month_day_nano_interval[pyarrow]",
            "time64us": "time64[us][pyarrow]",
            "geography": "large_binary[pyarrow]",
+           # "string_view": "string_view[pyarrow]",
+           # "binary_view": "binary_view[pyarrow]",
         }
     )
     return df
@@ -518,7 +544,10 @@ def add_non_writeable_column(frame):
         return frame
     elif isinstance(frame, pl.DataFrame):
         frame = frame.with_columns(
-            pl.Series(name="should_fail", values=[list((1, 2))])
+            pl.Series(
+                name="should_fail",
+                values=[list((1, 2)), list((1, 2)), list((1, 2))],
+            )
         )
         return frame
     else:
diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py
index 95710311..4893cd2d 100644
--- a/tests/test_roundtrip.py
+++ b/tests/test_roundtrip.py
@@ -40,6 +40,12 @@ def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
     if table_mode == "a":
         expected = compat.concat_frames(expected, expected)
 
+    if isinstance(frame, pd.DataFrame) and return_type != "pandas":
+        expected = compat.drop_columns(expected, ["string_view", "binary_view"])
+
+    if return_type == "pandas" and not isinstance(frame, pd.DataFrame):
+        result = compat.drop_columns(result, ["string_view", "binary_view"])
+
     compat.assert_frame_equal(result, expected)
 
 
@@ -78,6 +84,9 @@ def test_multiple_tables(
     if not isinstance(table_name, tab_api.TableName) or table_name.schema_name is None:
         table_name = tab_api.TableName("public", table_name)
 
+    if isinstance(frame, pd.DataFrame) and return_type != "pandas":
+        expected = compat.drop_columns(expected, ["string_view", "binary_view"])
+
     assert set(result.keys()) == set(
         (
             tuple(table_name._unescaped_components),
@@ -85,6 +94,9 @@
         )
     )
     for val in result.values():
+        if return_type == "pandas" and not isinstance(frame, pd.DataFrame):
+            val = compat.drop_columns(val, ["string_view", "binary_view"])
+
         compat.assert_frame_equal(val, expected)
 
 
@@ -120,6 +132,12 @@ def test_empty_roundtrip(
 
     result = pt.frame_from_hyper(tmp_hyper, table=table_name, return_type=return_type)
 
+    if isinstance(frame, pd.DataFrame) and return_type != "pandas":
+        expected = compat.drop_columns(expected, ["string_view", "binary_view"])
+
+    if return_type == "pandas" and not isinstance(frame, pd.DataFrame):
+        result = compat.drop_columns(result, ["string_view", "binary_view"])
+
     expected = compat.drop_columns(expected, ["object"])
     expected = compat.empty_like(expected)
     compat.assert_frame_equal(result, expected)
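
Usage note (not part of the patch): a minimal round-trip sketch of the behavior this change enables, assuming pantab's public frame_to_hyper/frame_from_hyper API and a throwaway "example.hyper" path; the column names, values, and expected result types mirror the test fixtures above.

import pyarrow as pa
import pantab as pt

# Arrow string_view/binary_view columns keep short values inline and point
# longer values into separate data buffers, which is why the test data mixes
# "foo" with "longer_than_prefix_size".
tbl = pa.table(
    {
        "string_view": pa.array(
            ["foo", "longer_than_prefix_size", None], type=pa.string_view()
        ),
        "binary_view": pa.array(
            [b"foo", b"longer_than_prefix_size", None], type=pa.binary_view()
        ),
    }
)

# With this change the view columns are written as Hyper TEXT / BYTES ...
pt.frame_to_hyper(tbl, "example.hyper", table="views")

# ... and, as in the roundtripped_pyarrow fixture, come back as large_string /
# large_binary rather than the original view types.
result = pt.frame_from_hyper("example.hyper", table="views", return_type="pyarrow")
assert result.schema.field("string_view").type == pa.large_string()
assert result.schema.field("binary_view").type == pa.large_binary()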