From a1cf141fbdc3f713c149f24814bdfadf8baada03 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Wed, 25 Sep 2024 20:07:16 -0400 Subject: [PATCH] Fix issue with leading decimal places --- src/pantab/writer.cpp | 18 +++++++++++++++--- tests/test_writer.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index 3c1004b3..3222ba69 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -360,9 +360,21 @@ class DecimalInsertHelper : public InsertHelper { std::string str{reinterpret_cast(buffer.data), static_cast(buffer.size_bytes)}; // The Hyper API wants the string to include the decimal place, which - // nanoarrow does not provide - if (scale_ > 0) - str = str.insert(str.size() - scale_, 1, '.'); + // nanoarrow does not provide. + if (scale_ > 0) { + // nanoarrow strips leading zeros + const auto insert_pos = static_cast(str.size()) - scale_; + if (insert_pos < 0) { + std::string newstr{}; + newstr.reserve(str.size() - insert_pos + 1); + newstr.append(1, '.'); + newstr.append(-insert_pos, '0'); + newstr.append(str, str.size()); + str = std::move(newstr); + } else { + str = str.insert(str.size() - scale_, 1, '.'); + } + } constexpr auto PrecisionLimit = 39; // of-by-one error in solution? if (precision_ >= PrecisionLimit) { diff --git a/tests/test_writer.py b/tests/test_writer.py index d4cdd7dc..59519327 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -1,4 +1,5 @@ import datetime +import decimal import re import narwhals as nw @@ -395,3 +396,19 @@ def test_writer_invalid_process_params_raises(tmp_hyper): msg = r"No internal setting named 'not_a_real_parameter'" with pytest.raises(RuntimeError, match=msg): pt.frame_to_hyper(frame, tmp_hyper, table="test", process_params=params) + + +@pytest.mark.parametrize( + "value,precision,scale", + [ + ("0.00", 3, 2), + ("0E-10", 3, 2), + ("100", 3, 0), + ("1.00", 3, 2), + (".001", 3, 3), + ], +) +def test_write_decimal_values(tmp_hyper, value, precision, scale): + arr = pa.array([decimal.Decimal(value)], type=pa.decimal128(precision, scale)) + tbl = pa.Table.from_arrays([arr], names=["col"]) + pt.frame_to_hyper(tbl, tmp_hyper, table="test")