diff --git a/src/pantab/reader.cpp b/src/pantab/reader.cpp index 0f2f29bf..de5f931e 100644 --- a/src/pantab/reader.cpp +++ b/src/pantab/reader.cpp @@ -1,7 +1,6 @@ #include "reader.hpp" -#include -#include +#include #include #include @@ -249,52 +248,21 @@ class IntervalReadHelper : public ReadHelper { }; // The Tableau Hyper API requires Numeric to be templated at compile time -// but the values are only known at runtime. To work around this limitation -// we generate a map of functions using precision and scale values in the -// range of 0..38 -using funcMapType = - std::map, - std::function>; - -template struct NumericReader { - static void read(funcMapType &func_map) { - NumericReader::read(func_map); - NumericReader::read(func_map); - NumericReader::read(func_map); - } -}; - -template struct NumericReader { - static void read(funcMapType &func_map) { - func_map.emplace( - std::make_pair(Precision, Scale), [](const hyperapi::Value &value) { - const auto decimal_value = - value.get>(); - auto decimal_string = decimal_value.toString(); - // C++20 std::erase would really simplify this! - decimal_string.erase( - std::remove(decimal_string.begin(), decimal_string.end(), '.'), - decimal_string.end()); - return decimal_string; - }); - } -}; - -static funcMapType initializeNumericReaderMap() { - funcMapType numeric_creators; - NumericReader<38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, - 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, - 5, 4, 3, 2, 1, 0>::read(numeric_creators); - return numeric_creators; -}; +// but the values are only known at runtime. This solution is adopted from +// https://stackoverflow.com/questions/78888913/creating-cartesian-product-from-integer-range-template-argument/78889229?noredirect=1#comment139097273_78889229 +template constexpr auto to_integral_variant(std::size_t n) { + return [&](std::index_sequence) { + using ResType = std::variant...>; + ResType all[] = {ResType{std::integral_constant{}}...}; + return all[n]; + }(std::make_index_sequence()); +} class DecimalReadHelper : public ReadHelper { public: explicit DecimalReadHelper(struct ArrowArray *array, int32_t precision, int32_t scale) - : ReadHelper(array), precision_(precision), scale_(scale) { - numeric_read_mapper_ = initializeNumericReaderMap(); - } + : ReadHelper(array), precision_(precision), scale_(scale) {} auto Read(const hyperapi::Value &value) -> void override { if (value.isNull()) { @@ -308,9 +276,30 @@ class DecimalReadHelper : public ReadHelper { struct ArrowDecimal decimal; ArrowDecimalInit(&decimal, bitwidth, precision_, scale_); - const auto readFunc = - numeric_read_mapper_.at(std::make_pair(precision_, scale_)); - const auto decimal_string = readFunc(value); + constexpr auto MaxPrecision = 39; // of-by-one error in solution? + if (precision_ >= MaxPrecision) { + throw nb::value_error("Numeric precision may not exceed 38!"); + } + if (scale_ >= MaxPrecision) { + throw nb::value_error("Numeric scale may not exceed 38!"); + } + + const auto decimal_string = std::visit( + [&value](auto P, auto S) -> std::string { + if constexpr (S() <= P()) { + const auto decimal_value = value.get>(); + auto value_string = decimal_value.toString(); + // C++20 std::erase would really simplify this! + value_string.erase( + std::remove(value_string.begin(), value_string.end(), '.'), + value_string.end()); + return value_string; + } + throw "unreachable"; + }, + to_integral_variant(precision_), + to_integral_variant(scale_)); + const struct ArrowStringView sv { decimal_string.data(), static_cast(decimal_string.size()) }; @@ -326,7 +315,6 @@ class DecimalReadHelper : public ReadHelper { } private: - funcMapType numeric_read_mapper_; int32_t precision_; int32_t scale_; }; diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index 3e06fc97..9a3e1275 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -1,13 +1,12 @@ #include "writer.hpp" -#include "nanoarrow/nanoarrow.h" -#include "nanoarrow/nanoarrow_types.h" #include #include #include #include #include +#include static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema, ArrowError *error) @@ -53,7 +52,10 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema, case NANOARROW_TYPE_TIME64: return hyperapi::SqlType::time(); case NANOARROW_TYPE_DECIMAL128: { - // TODO: don't hardcode precision and scale + // TODO: here we have hardcoded the precision and scale + // because the Tableau SqlType constructor requires it... + // but it doesn't appear like these are actually used? + // We still always get the values from the SchemaView at runtime constexpr int16_t precision = 38; constexpr int16_t scale = 0; return hyperapi::SqlType::numeric(precision, scale); @@ -327,41 +329,15 @@ class IntervalInsertHelper : public InsertHelper { }; // The Tableau Hyper API requires Numeric to be templated at compile time -// but the values are only known at runtime. To work around this limitation -// we generate a map of functions using precision and scale values in the -// range of 0..38 -using funcMapType = - std::map, - std::function>; - -template struct NumericCreatorInserter { - static void insert(funcMapType &func_map) { - NumericCreatorInserter::insert(func_map); - NumericCreatorInserter::insert(func_map); - NumericCreatorInserter::insert(func_map); - } -}; - -template -struct NumericCreatorInserter { - static void insert(funcMapType &func_map) { - func_map.emplace(std::make_pair(Precision, Scale), - [](hyperapi::Inserter &inserter, const std::string &str) { - hyperapi::string_view hsv(str); // for MSVC - hyperapi::Numeric value(hsv); - inserter.add(value); - }); - } -}; - -static funcMapType initializeNumericCreatorMap() { - funcMapType numeric_creators; - NumericCreatorInserter<38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, - 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, - 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, - 0>::insert(numeric_creators); - return numeric_creators; -}; +// but the values are only known at runtime. This solution is adopted from +// https://stackoverflow.com/questions/78888913/creating-cartesian-product-from-integer-range-template-argument/78889229?noredirect=1#comment139097273_78889229 +template constexpr auto to_integral_variant(std::size_t n) { + return [&](std::index_sequence) { + using ResType = std::variant...>; + ResType all[] = {ResType{std::integral_constant{}}...}; + return all[n]; + }(std::make_index_sequence()); +} class DecimalInsertHelper : public InsertHelper { public: @@ -371,11 +347,7 @@ class DecimalInsertHelper : public InsertHelper { struct ArrowError *error, int64_t column_position, int32_t precision, int32_t scale) : InsertHelper(inserter, chunk, schema, error, column_position), - precision_(precision), scale_(scale) { - // Technically would be faster to only initialize this once per the module - // and use as a global - numeric_function_mapper_ = initializeNumericCreatorMap(); - } + precision_(precision), scale_(scale) {} void InsertValueAtIndex(size_t idx) override { if (ArrowArrayViewIsNull(array_view_.get(), idx)) { @@ -402,15 +374,31 @@ class DecimalInsertHelper : public InsertHelper { // nanoarrow does not provide const auto str_with_decimal = str.insert(str.size() - scale_, 1, '.'); - const auto insertFunc = - numeric_function_mapper_.at(std::make_pair(precision_, scale_)); - insertFunc(inserter_, str_with_decimal); + constexpr auto MaxPrecision = 39; // of-by-one error in solution? + if (precision_ >= MaxPrecision) { + throw nb::value_error("Numeric precision may not exceed 38!"); + } + if (scale_ >= MaxPrecision) { + throw nb::value_error("Numeric scale may not exceed 38!"); + } + + std::visit( + [&](auto P, auto S) { + if constexpr (S() <= P()) { + const auto value = hyperapi::Numeric{str_with_decimal}; + inserter_.add(value); + return; + } else { + throw "unreachable"; + } + }, + to_integral_variant(precision_), + to_integral_variant(scale_)); ArrowBufferReset(&buffer); } private: - funcMapType numeric_function_mapper_; int32_t precision_; int32_t scale_; }; diff --git a/tests/test_decimal.py b/tests/test_decimal.py index 755cc4b2..b8a2630d 100644 --- a/tests/test_decimal.py +++ b/tests/test_decimal.py @@ -9,6 +9,7 @@ def test_decimal_roundtrip(tmp_hyper, compat): ("no_fractional", pa.decimal128(38, 0)), ("mixed_decimal", pa.decimal128(38, 10)), ("only_fractional", pa.decimal128(38, 38)), + ("five_two", pa.decimal128(5, 2)), ] ) @@ -35,6 +36,13 @@ def test_decimal_roundtrip(tmp_hyper, compat): None, ] ), + pa.array( + [ + "123.45", + "-543.21", + None, + ] + ), ], schema=schema, )