From b7474c47878a9816fcc76d8122a76896c231eb7f Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Tue, 1 Oct 2024 13:06:36 -0400 Subject: [PATCH] Support writing dictionary-encoded strings --- src/pantab/writer.cpp | 58 +++++++++++++++++++++++++++++++++++++++++++ tests/conftest.py | 9 ++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index 7e7f824a..ea7fe199 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -64,6 +64,25 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema, constexpr int16_t scale = 0; return hyperapi::SqlType::numeric(precision, scale); } + case NANOARROW_TYPE_DICTIONARY: { + struct ArrowSchemaView value_view {}; + struct ArrowError error {}; + NANOARROW_THROW_NOT_OK( + ArrowSchemaViewInit(&value_view, schema->dictionary, &error)); + + // only support dictionary-encoded string values for now + switch (value_view.type) { + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: + return hyperapi::SqlType::text(); + default: + throw std::invalid_argument( + std::string( + "Can only encode dictionaries with string value types, got: ") + + ArrowTypeString(value_view.type)); + } + } default: throw std::invalid_argument(std::string("Unsupported Arrow type: ") + ArrowTypeString(schema_view.type)); @@ -467,6 +486,26 @@ class DecimalInsertHelper : public InsertHelper { int32_t scale_; }; +class DictionaryInsertHelper : public InsertHelper { +public: + using InsertHelper::InsertHelper; + + void InsertValueAtIndex(int64_t idx) override { + if (CheckNull(idx)) { + InsertNull(); + return; + } + + const auto key = ArrowArrayViewGetIntUnsafe(GetArrayView(), idx); + const auto value = + ArrowArrayViewGetStringUnsafe(GetArrayView()->dictionary, key); + + const hyperapi::string_view result{value.data, + static_cast<size_t>(value.size_bytes)}; + InsertValue(std::move(result)); + } +}; + static auto MakeInsertHelper(hyperapi::Inserter 
&inserter, struct ArrowArray *chunk, const struct ArrowSchema *schema, @@ -593,6 +632,25 @@ static auto MakeInsertHelper(hyperapi::Inserter &inserter, case NANOARROW_TYPE_STRING_VIEW: return std::make_unique>( inserter, chunk, child_schema, error, column_position); + case NANOARROW_TYPE_DICTIONARY: { + struct ArrowSchemaView value_view {}; + NANOARROW_THROW_NOT_OK( + ArrowSchemaViewInit(&value_view, child_schema->dictionary, error)); + + // only support dictionary-encoded string values for now + switch (value_view.type) { + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: + return std::make_unique<DictionaryInsertHelper>( + inserter, chunk, child_schema, error, column_position); + default: + throw std::invalid_argument( + std::string("MakeInsertHelper: Can only encode dictionaries with " + "string value types, got: ") + + ArrowTypeString(value_view.type)); + } + } default: throw std::invalid_argument( std::string("MakeInsertHelper: Unsupported Arrow type: ") + diff --git a/tests/conftest.py b/tests/conftest.py index 2dd78f45..f6fe69b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ def basic_arrow_table(): ("decimal", pa.decimal128(38, 10)), ("string_view", pa.string_view()), ("binary_view", pa.binary_view()), + ("categorical", pa.dictionary(pa.int8(), pa.utf8())), ] ) tbl = pa.Table.from_arrays( @@ -110,6 +111,7 @@ def basic_arrow_table(): pa.array(["1234567890.123456789", "99876543210.987654321", None]), pa.array(["foo", "longer_than_prefix_size", None], type=pa.string_view()), pa.array([b"foo", b"longer_than_prefix_size", None], type=pa.binary_view()), + pa.array(["foo", "foo", None]), ], schema=schema, ) @@ -291,6 +293,7 @@ def basic_dataframe(): dtype=pd.ArrowDtype(pa.decimal128(38, 10)), ) """ + https://github.com/pandas-dev/pandas/issues/59883 df["string_view"] = pd.Series( ["foo", "longer_than_prefix_size", None], dtype=pd.ArrowDtype(pa.string_view())), @@ -298,6 +301,7 @@ def basic_dataframe(): [b"foo", 
b"longer_than_prefix_size", None], dtype=pd.ArrowDtype(pa.binary_view())), """ + df["categorical"] = pd.Series(["foo", "foo", pd.NA]).astype(pd.CategoricalDtype()) return df @@ -374,13 +378,15 @@ def roundtripped_pyarrow(): # pyarrow does not support casting from string_view to large_string, # so we have to handle manually - tbl = tbl.drop_columns(["string_view", "binary_view"]) + tbl = tbl.drop_columns(["string_view", "binary_view", "categorical"]) tbl = tbl.cast(schema) sv = (pa.array(["foo", "longer_than_prefix_size", None], type=pa.large_string()),) bv = pa.array([b"foo", b"longer_than_prefix_size", None], type=pa.large_binary()) + cat = pa.array(["foo", "foo", None], type=pa.large_string()) tbl = tbl.append_column("string_view", sv) tbl = tbl.append_column("binary_view", bv) + tbl = tbl.append_column("categorical", cat) return tbl @@ -420,6 +426,7 @@ def roundtripped_pandas(): "geography": "large_binary[pyarrow]", # "string_view": "string_view[pyarrow]", # "binary_view": "binary_view[pyarrow]", + "categorical": "large_string[pyarrow]", } ) return df