Implement String/Binary View Support
WillAyd committed Sep 24, 2024
1 parent f397e71 commit d0bfd5b
Showing 5 changed files with 94 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -44,7 +44,7 @@ find_package(tableauhyperapi-cxx CONFIG REQUIRED)

FetchContent_Declare(nanoarrow-project
GIT_REPOSITORY https://github.com/apache/arrow-nanoarrow.git
-GIT_TAG dab87aaea4c2c05d24b745d58e50726bd0553452
GIT_TAG 6f8badb649d8416778d81598867adc7263d735ad
)
FetchContent_MakeAvailable(nanoarrow-project)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -71,7 +71,7 @@ test-requires = [
"pytest",
"pytest-xdist[psutil]",
"pandas>=2.0.0",
"polars~=1.2.0",
"polars",
"narwhals",
"tableauhyperapi",
]
47 changes: 47 additions & 0 deletions src/pantab/writer.cpp
@@ -35,9 +35,11 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema,
return hyperapi::SqlType::boolean();
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_LARGE_BINARY:
case NANOARROW_TYPE_BINARY_VIEW:
return hyperapi::SqlType::bytes();
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
case NANOARROW_TYPE_STRING_VIEW:
return hyperapi::SqlType::text();
case NANOARROW_TYPE_DATE32:
return hyperapi::SqlType::date();
@@ -201,6 +203,45 @@ template <typename OffsetT> class Utf8InsertHelper : public InsertHelper {
}
};

template <bool IsString> class BinaryViewInsertHelper : public InsertHelper {
public:
using InsertHelper::InsertHelper;

void InsertValueAtIndex(size_t idx) override {
if (ArrowArrayViewIsNull(array_view_.get(), idx)) {
// MSVC on cibuildwheel doesn't like this templated optional
hyperapi::internal::ValueInserter{inserter_}.addNull();
return;
}

const union ArrowBinaryView bv =
array_view_->buffer_views[1].data.as_binary_view[idx];
struct ArrowBufferView bin_data = {{NULL}, bv.inlined.size};
if (bv.inlined.size <= NANOARROW_BINARY_VIEW_INLINE_SIZE) {
bin_data.data.as_uint8 = bv.inlined.data;
} else {
const int32_t buf_index =
bv.ref.buffer_index + NANOARROW_BINARY_VIEW_FIXED_BUFFERS;
bin_data.data.data = array_view_->array->buffers[buf_index];
bin_data.data.as_uint8 += bv.ref.offset;
}

if constexpr (IsString) {
#if defined(_WIN32) && defined(_MSC_VER)
const std::string result{bin_data.data.as_char, static_cast<size_t>(bin_data.size_bytes)};
#else
const std::string_view result{bin_data.data.as_char,
static_cast<size_t>(bin_data.size_bytes)};
#endif
hyperapi::internal::ValueInserter{inserter_}.addValue(result);
} else {
const hyperapi::ByteSpan result{bin_data.data.as_uint8,
static_cast<size_t>(bin_data.size_bytes)};
hyperapi::internal::ValueInserter{inserter_}.addValue(result);
}
}
};

class Date32InsertHelper : public InsertHelper {
public:
using InsertHelper::InsertHelper;
@@ -517,6 +558,12 @@ static auto MakeInsertHelper(hyperapi::Inserter &inserter,
return std::make_unique<DecimalInsertHelper>(
inserter, chunk, schema, error, column_position, precision, scale);
}
case NANOARROW_TYPE_BINARY_VIEW:
return std::make_unique<BinaryViewInsertHelper<false>>(
inserter, chunk, schema, error, column_position);
case NANOARROW_TYPE_STRING_VIEW:
return std::make_unique<BinaryViewInsertHelper<true>>(
inserter, chunk, schema, error, column_position);
default:
throw std::invalid_argument(
std::string("MakeInsertHelper: Unsupported Arrow type: ") +
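A note on the layout handled by BinaryViewInsertHelper: an Arrow string/binary view element is a fixed 16-byte value that either stores short data inline (up to 12 bytes) or stores the length, a 4-byte prefix, and a buffer index plus offset pointing into one of the array's variadic data buffers. That is what the bv.inlined / bv.ref branches and the NANOARROW_BINARY_VIEW_INLINE_SIZE / NANOARROW_BINARY_VIEW_FIXED_BUFFERS constants above correspond to. The standalone sketch below illustrates the same decode step; the ViewElement union and DecodeView helper are illustrative stand-ins, not nanoarrow's actual definitions.

// Illustrative sketch only: a hand-rolled stand-in for the 16-byte Arrow
// binary-view element, not nanoarrow's ArrowBinaryView definition.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

namespace {

constexpr int32_t kInlineSize = 12;  // bytes stored directly inside the view

union ViewElement {
  struct {
    int32_t size;
    char data[kInlineSize];  // the value itself, when size <= kInlineSize
  } inlined;
  struct {
    int32_t size;
    char prefix[4];        // first four bytes, kept for cheap comparisons
    int32_t buffer_index;  // which variadic data buffer holds the value
    int32_t offset;        // byte offset of the value within that buffer
  } ref;
};
static_assert(sizeof(ViewElement) == 16, "view elements are 16 bytes");

// Resolve one view element against the array's variadic data buffers.
std::string_view DecodeView(const ViewElement &view,
                            const std::vector<std::string> &data_buffers) {
  if (view.inlined.size <= kInlineSize) {
    return {view.inlined.data, static_cast<size_t>(view.inlined.size)};
  }
  const std::string &buf = data_buffers[view.ref.buffer_index];
  return {buf.data() + view.ref.offset, static_cast<size_t>(view.ref.size)};
}

}  // namespace

int main() {
  const std::vector<std::string> data_buffers = {"xxlonger_than_prefix_size"};

  ViewElement short_value{};  // "foo" fits inline
  short_value.inlined.size = 3;
  std::memcpy(short_value.inlined.data, "foo", 3);

  ViewElement long_value{};  // too long to inline: buffer 0 at offset 2
  long_value.ref.size = 23;
  std::memcpy(long_value.ref.prefix, "long", 4);
  long_value.ref.buffer_index = 0;
  long_value.ref.offset = 2;

  std::cout << DecodeView(short_value, data_buffers) << "\n";  // foo
  std::cout << DecodeView(long_value, data_buffers) << "\n";   // longer_than_prefix_size
  return 0;
}

Because a view array's first two buffers are the validity bitmap and the views themselves, the commit adds NANOARROW_BINARY_VIEW_FIXED_BUFFERS to bv.ref.buffer_index before indexing array_view_->array->buffers; the sketch sidesteps that by passing only the variadic data buffers.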
28 changes: 27 additions & 1 deletion tests/conftest.py
@@ -44,6 +44,8 @@ def basic_arrow_table():
("time64us", pa.time64("us")),
("geography", pa.large_binary()),
("decimal", pa.decimal128(38, 10)),
("string_view", pa.string_view()),
("binary_view", pa.binary_view()),
]
)
tbl = pa.Table.from_arrays(
@@ -106,6 +108,8 @@ def basic_arrow_table():
]
),
pa.array(["1234567890.123456789", "99876543210.987654321", None]),
pa.array(["foo", "longer_than_prefix_size", None], type=pa.string_view()),
pa.array([b"foo", b"longer_than_prefix_size", None], type=pa.binary_view()),
],
schema=schema,
)
@@ -286,6 +290,14 @@ def basic_dataframe():
["1234567890.123456789", "99876543210.987654321", None],
dtype=pd.ArrowDtype(pa.decimal128(38, 10)),
)
"""
df["string_view"] = pd.Series(
["foo", "longer_than_prefix_size", None],
dtype=pd.ArrowDtype(pa.string_view()))
df["binary_view"] = pd.Series(
[b"foo", b"longer_than_prefix_size", None],
dtype=pd.ArrowDtype(pa.binary_view()))
"""

return df

@@ -354,11 +366,23 @@ def roundtripped_pyarrow():
("time64us", pa.time64("us")),
("geography", pa.large_binary()),
("decimal", pa.decimal128(38, 10)),
# ("string_view", pa.large_string()),
# ("binary_view", pa.large_binary()),
]
)
tbl = basic_arrow_table()

-return tbl.cast(schema)
# pyarrow does not support casting from string_view to large_string,
# so we have to handle manually
tbl = tbl.drop_columns(["string_view", "binary_view"])
tbl = tbl.cast(schema)

sv = pa.array(["foo", "longer_than_prefix_size", None], type=pa.large_string())
bv = pa.array([b"foo", b"longer_than_prefix_size", None], type=pa.large_binary())
tbl = tbl.append_column("string_view", sv)
tbl = tbl.append_column("binary_view", bv)

return tbl


def roundtripped_pandas():
@@ -394,6 +418,8 @@ def roundtripped_pandas():
# "interval": "month_day_nano_interval[pyarrow]",
"time64us": "time64[us][pyarrow]",
"geography": "large_binary[pyarrow]",
# "string_view": "string_view[pyarrow]",
# "binary_view": "binary_view[pyarrow]",
}
)
return df
18 changes: 18 additions & 0 deletions tests/test_roundtrip.py
@@ -39,6 +39,12 @@ def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
if table_mode == "a":
expected = compat.concat_frames(expected, expected)

if isinstance(frame, pd.DataFrame) and return_type != "pandas":
expected = compat.drop_columns(expected, ["string_view", "binary_view"])

if return_type == "pandas" and not isinstance(frame, pd.DataFrame):
result = compat.drop_columns(result, ["string_view", "binary_view"])

compat.assert_frame_equal(result, expected)


@@ -77,13 +83,19 @@ def test_multiple_tables(
if not isinstance(table_name, tab_api.TableName) or table_name.schema_name is None:
table_name = tab_api.TableName("public", table_name)

if isinstance(frame, pd.DataFrame) and return_type != "pandas":
expected = compat.drop_columns(expected, ["string_view", "binary_view"])

assert set(result.keys()) == set(
(
tuple(table_name._unescaped_components),
tuple(tab_api.TableName("public", "table2")._unescaped_components),
)
)
for val in result.values():
if return_type == "pandas" and not isinstance(frame, pd.DataFrame):
val = compat.drop_columns(val, ["string_view", "binary_view"])

compat.assert_frame_equal(val, expected)


@@ -119,6 +131,12 @@ def test_empty_roundtrip(

result = pt.frame_from_hyper(tmp_hyper, table=table_name, return_type=return_type)

if isinstance(frame, pd.DataFrame) and return_type != "pandas":
expected = compat.drop_columns(expected, ["string_view", "binary_view"])

if return_type == "pandas" and not isinstance(frame, pd.DataFrame):
result = compat.drop_columns(result, ["string_view", "binary_view"])

expected = compat.drop_columns(expected, ["object"])
expected = compat.empty_like(expected)
compat.assert_frame_equal(result, expected)
