Skip to content

Commit

Permalink
Numeric data support
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Aug 17, 2024
1 parent d760e91 commit 3aaeec5
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/pantab/reader.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "reader.hpp"
#include "nanoarrow/nanoarrow_types.h"

#include <hyperapi/SqlType.hpp>
#include <vector>

#include <hyperapi/hyperapi.hpp>
Expand Down Expand Up @@ -246,6 +248,42 @@ class IntervalReadHelper : public ReadHelper {
}
};

// Read helper that converts Hyper NUMERIC values into entries of an Arrow
// decimal128 array by round-tripping through the value's string form.
class DecimalReadHelper : public ReadHelper {
  using ReadHelper::ReadHelper;

  // Appends `value` to array_, either as a null slot or as a 128-bit decimal.
  // Throws std::runtime_error when any nanoarrow call reports a failure.
  auto Read(const hyperapi::Value &value) -> void override {
    // TODO: Tableau wants these at compile time but we only have at runtime
    // how do we best solve that?
    constexpr int16_t precision = 38;
    constexpr int16_t scale = 0;

    if (value.isNull()) {
      const auto null_status = ArrowArrayAppendNull(array_, 1);
      if (null_status != 0) {
        throw std::runtime_error("ArrowAppendNull failed");
      }
      return;
    }

    // Render the Hyper numeric as text; nanoarrow parses the digits back.
    const auto digits =
        value.get<hyperapi::Numeric<precision, scale>>().toString();
    const struct ArrowStringView digits_view {
      digits.data(), static_cast<int64_t>(digits.size())
    };

    struct ArrowDecimal arrow_decimal;
    ArrowDecimalInit(&arrow_decimal, 128, precision, scale);

    const auto parse_status = ArrowDecimalSetDigits(&arrow_decimal, digits_view);
    if (parse_status != 0) {
      throw std::runtime_error(
          "Unable to convert tableau numeric to arrow decimal");
    }

    const auto append_status = ArrowArrayAppendDecimal(array_, &arrow_decimal);
    if (append_status != 0) {
      throw std::runtime_error("Failed to append decimal value");
    }
  }
};

static auto MakeReadHelper(const ArrowSchemaView *schema_view,
struct ArrowArray *array)
-> std::unique_ptr<ReadHelper> {
Expand Down Expand Up @@ -281,6 +319,8 @@ static auto MakeReadHelper(const ArrowSchemaView *schema_view,
return std::unique_ptr<ReadHelper>(new IntervalReadHelper(array));
case NANOARROW_TYPE_TIME64:
return std::unique_ptr<ReadHelper>(new TimeReadHelper(array));
case NANOARROW_TYPE_DECIMAL128:
return std::unique_ptr<ReadHelper>(new DecimalReadHelper(array));
default:
throw nb::type_error("unknownn arrow type provided");
}
Expand All @@ -307,6 +347,7 @@ static auto GetArrowTypeFromHyper(const hyperapi::SqlType &sqltype)
case hyperapi::TypeTag::
Interval : return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO;
case hyperapi::TypeTag::Time : return NANOARROW_TYPE_TIME64;
case hyperapi::TypeTag::Numeric : return NANOARROW_TYPE_DECIMAL128;
default : throw nb::type_error(
("Reader not implemented for type: " + sqltype.toString()).c_str());
}
Expand Down Expand Up @@ -336,6 +377,17 @@ static auto SetSchemaTypeFromHyperType(struct ArrowSchema *schema,
throw std::runtime_error("ArrowSchemaSetDateTime failed for Time type");
}
break;
case hyperapi::TypeTag::Numeric: {
// TODO: Tableau wants these at compile time but we only have at runtime
// how do we best solve that?
constexpr int16_t precision = 38;
constexpr int16_t scale = 0;
if (ArrowSchemaSetTypeDecimal(schema, NANOARROW_TYPE_DECIMAL128, precision,
scale)) {
throw std::runtime_error("ArrowSchemaSetTypeDecimal failed");
}
break;
}
default:
const enum ArrowType arrow_type = GetArrowTypeFromHyper(sqltype);
if (ArrowSchemaSetType(schema, arrow_type)) {
Expand Down
50 changes: 50 additions & 0 deletions src/pantab/writer.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "writer.hpp"
#include "nanoarrow/nanoarrow.h"
#include "nanoarrow/nanoarrow_types.h"

#include <chrono>
#include <hyperapi/string_view.hpp>
#include <set>

#include <hyperapi/hyperapi.hpp>
Expand Down Expand Up @@ -49,6 +52,12 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema,
return hyperapi::SqlType::interval();
case NANOARROW_TYPE_TIME64:
return hyperapi::SqlType::time();
case NANOARROW_TYPE_DECIMAL128: {
// TODO: don't hardcode precision and scale
constexpr int16_t precision = 38;
constexpr int16_t scale = 0;
return hyperapi::SqlType::numeric(precision, scale);
}
default:
throw std::invalid_argument(std::string("Unsupported Arrow type: ") +
ArrowTypeString(schema_view.type));
Expand Down Expand Up @@ -317,6 +326,44 @@ class IntervalInsertHelper : public InsertHelper {
}
};

class DecimalInsertHelper : public InsertHelper {
public:
using InsertHelper::InsertHelper;

void InsertValueAtIndex(size_t idx) override {
if (ArrowArrayViewIsNull(array_view_.get(), idx)) {
// MSVC on cibuildwheel doesn't like this templated optional
// inserter_->add(std::optional<timestamp_t>{std::nullopt});
hyperapi::internal::ValueInserter{inserter_}.addNull();
return;
}

// TODO: Tableau wants these at compile time but we only have at runtime
// how do we best solve that?
constexpr int16_t precision = 38;
constexpr int16_t scale = 0;

struct ArrowDecimal decimal;
ArrowDecimalInit(&decimal, 128, precision, scale);
ArrowArrayViewGetDecimalUnsafe(array_view_.get(), idx, &decimal);

struct ArrowBuffer buffer;
ArrowBufferInit(&buffer);
if (ArrowDecimalAppendDigitsToBuffer(&decimal, &buffer)) {
throw std::runtime_error("could not create buffer from decmial value");
}

std::string_view sv{reinterpret_cast<char *>(buffer.data),
static_cast<size_t>(buffer.size_bytes)};

// TODO: we shouldn't hardcode this
hyperapi::Numeric<precision, scale> value{sv};
inserter_.add(value);

ArrowBufferReset(&buffer);
}
};

static auto MakeInsertHelper(hyperapi::Inserter &inserter,
struct ArrowArray *chunk,
struct ArrowSchema *schema,
Expand Down Expand Up @@ -435,6 +482,9 @@ static auto MakeInsertHelper(hyperapi::Inserter &inserter,
}
throw std::runtime_error(
"This code block should not be hit - contact a developer");
case NANOARROW_TYPE_DECIMAL128:
return std::make_unique<DecimalInsertHelper>(inserter, chunk, schema, error,
column_position);
default:
throw std::invalid_argument(
std::string("MakeInsertHelper: Unsupported Arrow type: ") +
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def basic_arrow_table():
("interval", pa.month_day_nano_interval()),
("time64us", pa.time64("us")),
("geography", pa.large_binary()),
("decimal", pa.decimal128(38, 0)),
]
)
tbl = pa.Table.from_arrays(
Expand Down Expand Up @@ -104,6 +105,7 @@ def basic_arrow_table():
None,
]
),
pa.array(["1234567890123456789", "-99876543210987654321", None]),
],
schema=schema,
)
Expand Down Expand Up @@ -280,6 +282,11 @@ def basic_dataframe():
)
df["geography"] = df["geography"].astype("large_binary[pyarrow]")

df["decimal"] = pd.Series(
["1234567890123456789", "-99876543210987654321", None],
dtype=pd.ArrowDtype(pa.decimal128(38, 0)),
)

return df


Expand Down Expand Up @@ -346,6 +353,7 @@ def roundtripped_pyarrow():
("interval", pa.month_day_nano_interval()),
("time64us", pa.time64("us")),
("geography", pa.large_binary()),
("decimal", pa.decimal128(38, 0)),
]
)
tbl = basic_arrow_table()
Expand Down

0 comments on commit 3aaeec5

Please sign in to comment.