Skip to content

Commit

Permalink
Use visitor pattern to implement runtime to compile time Numeric mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Aug 20, 2024
1 parent 79f2454 commit f1b0667
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 94 deletions.
82 changes: 35 additions & 47 deletions src/pantab/reader.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "reader.hpp"

#include <algorithm>
#include <map>
#include <variant>
#include <vector>

#include <hyperapi/hyperapi.hpp>
Expand Down Expand Up @@ -249,52 +248,21 @@ class IntervalReadHelper : public ReadHelper {
};

// The Tableau Hyper API requires Numeric to be templated at compile time
// but the values are only known at runtime. To work around this limitation
// we generate a map of functions using precision and scale values in the
// range of 0..38
// Lookup table type: keyed by a (precision, scale) pair, each entry holds a
// callable that takes a hyperapi::Value and returns a std::string.
using funcMapType =
std::map<std::pair<int, int>,
std::function<std::string(const hyperapi::Value &value)>>;

// Recursive case: registers the readers for (P, P) and (P, S) in func_map,
// then recurses over the remaining values in Rest with the same P.
// NOTE: P is never varied by the recursion, so every key registered through
// this template has P as its first component; (P, P) is re-registered at
// each recursion step.
template <int P, int S, int... Rest> struct NumericReader {
static void read(funcMapType &func_map) {
NumericReader<P, P>::read(func_map);
NumericReader<P, S>::read(func_map);
NumericReader<P, Rest...>::read(func_map);
}
};

// Base case: registers a single reader under the key (Precision, Scale).
// The registered callable fetches the value as
// hyperapi::Numeric<Precision, Scale>, converts it with toString(), and
// returns that string with every '.' character removed.
template <int Precision, int Scale> struct NumericReader<Precision, Scale> {
static void read(funcMapType &func_map) {
// std::map::emplace leaves an existing entry untouched, so registering the
// same key more than once is harmless.
func_map.emplace(
std::make_pair(Precision, Scale), [](const hyperapi::Value &value) {
const auto decimal_value =
value.get<hyperapi::Numeric<Precision, Scale>>();
auto decimal_string = decimal_value.toString();
// C++20 std::erase would really simplify this!
decimal_string.erase(
std::remove(decimal_string.begin(), decimal_string.end(), '.'),
decimal_string.end());
return decimal_string;
});
}
};

// Builds and returns the reader lookup table by instantiating NumericReader
// with 38 as the first template argument followed by 37 down to 0.
// NOTE: given how NumericReader recurses, the resulting keys all have 38 as
// their first component (paired with 38 down to 0).
static funcMapType initializeNumericReaderMap() {
funcMapType numeric_creators;
NumericReader<38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23,
22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6,
5, 4, 3, 2, 1, 0>::read(numeric_creators);
return numeric_creators;
};
// but the values are only known at runtime. This solution is adopted from
// https://stackoverflow.com/questions/78888913/creating-cartesian-product-from-integer-range-template-argument/78889229?noredirect=1#comment139097273_78889229
// Implementation detail of to_integral_variant: expands the index pack Is...
// into a table holding one variant alternative per index and returns the
// entry selected by `n`. It is a separate function template (rather than a
// templated lambda) so the pack can be deduced from std::index_sequence
// without relying on C++20 lambda syntax.
//
// Throws std::bad_variant_access when `n` does not name an alternative;
// indexing the table unchecked would be an out-of-bounds read.
template <std::size_t... Is>
static constexpr auto
integral_variant_from_sequence(std::size_t n, std::index_sequence<Is...>) {
  using ResType = std::variant<std::integral_constant<std::size_t, Is>...>;
  if (n >= sizeof...(Is)) {
    throw std::bad_variant_access{};
  }
  const ResType all[] = {ResType{std::integral_constant<std::size_t, Is>{}}...};
  return all[n];
}

// Converts the runtime index `n` into a
// std::variant<std::integral_constant<std::size_t, 0>, ...,
//              std::integral_constant<std::size_t, N - 1>>
// whose active alternative is the constant equal to `n`. Visiting the result
// with std::visit hands the value to the visitor as a compile-time constant.
//
// N is the number of alternatives, so the valid range for `n` is [0, N).
// Throws std::bad_variant_access when `n` is outside that range (this also
// catches negative values that wrapped around on conversion to std::size_t).
//
// Internal linkage: this helper is local to this translation unit.
template <std::size_t N>
static constexpr auto to_integral_variant(std::size_t n) {
  static_assert(N > 0, "to_integral_variant requires at least one alternative");
  return integral_variant_from_sequence(n, std::make_index_sequence<N>());
}

class DecimalReadHelper : public ReadHelper {
public:
explicit DecimalReadHelper(struct ArrowArray *array, int32_t precision,
int32_t scale)
: ReadHelper(array), precision_(precision), scale_(scale) {
numeric_read_mapper_ = initializeNumericReaderMap();
}
: ReadHelper(array), precision_(precision), scale_(scale) {}

auto Read(const hyperapi::Value &value) -> void override {
if (value.isNull()) {
Expand All @@ -308,9 +276,30 @@ class DecimalReadHelper : public ReadHelper {
struct ArrowDecimal decimal;
ArrowDecimalInit(&decimal, bitwidth, precision_, scale_);

const auto readFunc =
numeric_read_mapper_.at(std::make_pair(precision_, scale_));
const auto decimal_string = readFunc(value);
constexpr auto MaxPrecision = 39; // off-by-one error in solution?
if (precision_ >= MaxPrecision) {
throw nb::value_error("Numeric precision may not exceed 38!");
}
if (scale_ >= MaxPrecision) {
throw nb::value_error("Numeric scale may not exceed 38!");
}

const auto decimal_string = std::visit(
[&value](auto P, auto S) -> std::string {
if constexpr (S() <= P()) {
const auto decimal_value = value.get<hyperapi::Numeric<P(), S()>>();
auto value_string = decimal_value.toString();
// C++20 std::erase would really simplify this!
value_string.erase(
std::remove(value_string.begin(), value_string.end(), '.'),
value_string.end());
return value_string;
}
throw "unreachable";
},
to_integral_variant<MaxPrecision>(precision_),
to_integral_variant<MaxPrecision>(scale_));

const struct ArrowStringView sv {
decimal_string.data(), static_cast<int64_t>(decimal_string.size())
};
Expand All @@ -326,7 +315,6 @@ class DecimalReadHelper : public ReadHelper {
}

private:
funcMapType numeric_read_mapper_;
int32_t precision_;
int32_t scale_;
};
Expand Down
82 changes: 35 additions & 47 deletions src/pantab/writer.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
#include "writer.hpp"

#include "nanoarrow/nanoarrow.h"
#include "nanoarrow/nanoarrow_types.h"
#include <hyperapi/hyperapi.hpp>
#include <nanoarrow/nanoarrow.hpp>

#include <chrono>
#include <set>
#include <utility>
#include <variant>

static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema,
ArrowError *error)
Expand Down Expand Up @@ -53,7 +52,10 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema,
case NANOARROW_TYPE_TIME64:
return hyperapi::SqlType::time();
case NANOARROW_TYPE_DECIMAL128: {
// TODO: don't hardcode precision and scale
// TODO: here we have hardcoded the precision and scale
// because the Tableau SqlType constructor requires it...
// but it doesn't appear like these are actually used?
// We still always get the values from the SchemaView at runtime
constexpr int16_t precision = 38;
constexpr int16_t scale = 0;
return hyperapi::SqlType::numeric(precision, scale);
Expand Down Expand Up @@ -327,41 +329,15 @@ class IntervalInsertHelper : public InsertHelper {
};

// The Tableau Hyper API requires Numeric to be templated at compile time
// but the values are only known at runtime. To work around this limitation
// we generate a map of functions using precision and scale values in the
// range of 0..38
// Lookup table type: keyed by a (precision, scale) pair, each entry holds a
// callable that takes a hyperapi::Inserter and a string and returns nothing.
using funcMapType =
std::map<std::pair<int, int>,
std::function<void(hyperapi::Inserter &, const std::string &)>>;

// Recursive case: registers the inserters for (P, P) and (P, S) in func_map,
// then recurses over the remaining values in Rest with the same P.
// NOTE: P is never varied by the recursion, so every key registered through
// this template has P as its first component; (P, P) is re-registered at
// each recursion step.
template <int P, int S, int... Rest> struct NumericCreatorInserter {
static void insert(funcMapType &func_map) {
NumericCreatorInserter<P, P>::insert(func_map);
NumericCreatorInserter<P, S>::insert(func_map);
NumericCreatorInserter<P, Rest...>::insert(func_map);
}
};

// Base case: registers a single inserter under the key (Precision, Scale).
// The registered callable constructs a hyperapi::Numeric<Precision, Scale>
// from the given string and adds it to the inserter.
template <int Precision, int Scale>
struct NumericCreatorInserter<Precision, Scale> {
static void insert(funcMapType &func_map) {
// std::map::emplace leaves an existing entry untouched, so registering the
// same key more than once is harmless.
func_map.emplace(std::make_pair(Precision, Scale),
[](hyperapi::Inserter &inserter, const std::string &str) {
hyperapi::string_view hsv(str); // for MSVC
hyperapi::Numeric<Precision, Scale> value(hsv);
inserter.add(value);
});
}
};

// Builds and returns the inserter lookup table by instantiating
// NumericCreatorInserter with 38 as the first template argument followed by
// 37 down to 0.
// NOTE: given how NumericCreatorInserter recurses, the resulting keys all
// have 38 as their first component (paired with 38 down to 0).
static funcMapType initializeNumericCreatorMap() {
funcMapType numeric_creators;
NumericCreatorInserter<38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25,
24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11,
10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
0>::insert(numeric_creators);
return numeric_creators;
};
// but the values are only known at runtime. This solution is adopted from
// https://stackoverflow.com/questions/78888913/creating-cartesian-product-from-integer-range-template-argument/78889229?noredirect=1#comment139097273_78889229
// Implementation detail of to_integral_variant: expands the index pack Is...
// into a table holding one variant alternative per index and returns the
// entry selected by `n`. It is a separate function template (rather than a
// templated lambda) so the pack can be deduced from std::index_sequence
// without relying on C++20 lambda syntax.
//
// Throws std::bad_variant_access when `n` does not name an alternative;
// indexing the table unchecked would be an out-of-bounds read.
template <std::size_t... Is>
static constexpr auto
integral_variant_from_sequence(std::size_t n, std::index_sequence<Is...>) {
  using ResType = std::variant<std::integral_constant<std::size_t, Is>...>;
  if (n >= sizeof...(Is)) {
    throw std::bad_variant_access{};
  }
  const ResType all[] = {ResType{std::integral_constant<std::size_t, Is>{}}...};
  return all[n];
}

// Converts the runtime index `n` into a
// std::variant<std::integral_constant<std::size_t, 0>, ...,
//              std::integral_constant<std::size_t, N - 1>>
// whose active alternative is the constant equal to `n`. Visiting the result
// with std::visit hands the value to the visitor as a compile-time constant.
//
// N is the number of alternatives, so the valid range for `n` is [0, N).
// Throws std::bad_variant_access when `n` is outside that range (this also
// catches negative values that wrapped around on conversion to std::size_t).
//
// Internal linkage: this helper is local to this translation unit.
template <std::size_t N>
static constexpr auto to_integral_variant(std::size_t n) {
  static_assert(N > 0, "to_integral_variant requires at least one alternative");
  return integral_variant_from_sequence(n, std::make_index_sequence<N>());
}

class DecimalInsertHelper : public InsertHelper {
public:
Expand All @@ -371,11 +347,7 @@ class DecimalInsertHelper : public InsertHelper {
struct ArrowError *error, int64_t column_position,
int32_t precision, int32_t scale)
: InsertHelper(inserter, chunk, schema, error, column_position),
precision_(precision), scale_(scale) {
// Technically would be faster to only initialize this once per the module
// and use as a global
numeric_function_mapper_ = initializeNumericCreatorMap();
}
precision_(precision), scale_(scale) {}

void InsertValueAtIndex(size_t idx) override {
if (ArrowArrayViewIsNull(array_view_.get(), idx)) {
Expand All @@ -402,15 +374,31 @@ class DecimalInsertHelper : public InsertHelper {
// nanoarrow does not provide
const auto str_with_decimal = str.insert(str.size() - scale_, 1, '.');

const auto insertFunc =
numeric_function_mapper_.at(std::make_pair(precision_, scale_));
insertFunc(inserter_, str_with_decimal);
constexpr auto MaxPrecision = 39; // off-by-one error in solution?
if (precision_ >= MaxPrecision) {
throw nb::value_error("Numeric precision may not exceed 38!");
}
if (scale_ >= MaxPrecision) {
throw nb::value_error("Numeric scale may not exceed 38!");
}

std::visit(
[&](auto P, auto S) {
if constexpr (S() <= P()) {
const auto value = hyperapi::Numeric<P(), S()>{str_with_decimal};
inserter_.add(value);
return;
} else {
throw "unreachable";
}
},
to_integral_variant<MaxPrecision>(precision_),
to_integral_variant<MaxPrecision>(scale_));

ArrowBufferReset(&buffer);
}

private:
funcMapType numeric_function_mapper_;
int32_t precision_;
int32_t scale_;
};
Expand Down
8 changes: 8 additions & 0 deletions tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_decimal_roundtrip(tmp_hyper, compat):
("no_fractional", pa.decimal128(38, 0)),
("mixed_decimal", pa.decimal128(38, 10)),
("only_fractional", pa.decimal128(38, 38)),
("five_two", pa.decimal128(5, 2)),
]
)

Expand All @@ -35,6 +36,13 @@ def test_decimal_roundtrip(tmp_hyper, compat):
None,
]
),
pa.array(
[
"123.45",
"-543.21",
None,
]
),
],
schema=schema,
)
Expand Down

0 comments on commit f1b0667

Please sign in to comment.