INTERVAL support (#237)
WillAyd authored Jan 28, 2024
1 parent 05d1398 commit 5f02bfc
Showing 5 changed files with 115 additions and 11 deletions.
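In short: a pyarrow month_day_nano_interval column can now be written to and read back from a Hyper INTERVAL column. A minimal sketch of that round trip, mirroring the test fixtures added below (the file and table names here are illustrative, not part of this commit):

import pyarrow as pa

import pantab

# One month, 15 days, -30000 nanoseconds, plus a null; the same values the
# conftest fixture below uses
tbl = pa.table(
    {
        "interval": pa.array(
            [pa.scalar((1, 15, -30000), type=pa.month_day_nano_interval()), None]
        )
    }
)

# The column is written as Hyper's INTERVAL type and read back as month_day_nano
pantab.frame_to_hyper(tbl, "example.hyper", table="intervals")
result = pantab.frame_from_hyper(
    "example.hyper", table="intervals", return_type="pyarrow"
)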
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -40,7 +40,7 @@ find_package(tableauhyperapi-cxx CONFIG REQUIRED)

FetchContent_Declare(nanoarrow-project
GIT_REPOSITORY https://github.com/apache/arrow-nanoarrow.git
-GIT_TAG b3c952a3e21c2b47df85dbede3444f852614a3e2
+GIT_TAG dab87aaea4c2c05d24b745d58e50726bd0553452
)
FetchContent_MakeAvailable(nanoarrow-project)

8 changes: 4 additions & 4 deletions pantab/_reader.py
@@ -1,7 +1,7 @@
import pathlib
import shutil
import tempfile
-from typing import Union
+from typing import Literal, Union

import pyarrow as pa
import tableauhyperapi as tab_api
@@ -15,7 +15,7 @@ def frame_from_hyper_query(
source: Union[str, pathlib.Path],
query: str,
*,
-return_type="pandas",
+return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
):
"""See api.rst for documentation."""
# Call native library to read tuples from result set
@@ -41,7 +41,7 @@ def frame_from_hyper(
source: Union[str, pathlib.Path],
*,
table: TableType,
-return_type="pandas",
+return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
):
"""See api.rst for documentation"""
if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
@@ -53,7 +53,7 @@ def frame_from_hyper(

def frames_from_hyper(
source: Union[str, pathlib.Path],
-return_type="pandas",
+return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
):
"""See api.rst for documentation."""
result = {}
73 changes: 71 additions & 2 deletions pantab/src/pantab.cpp
@@ -12,6 +12,9 @@
#include <nanobind/stl/tuple.h>

#include "datetime.h"
#include "nanoarrow/array_inline.h"
#include "nanoarrow/nanoarrow.h"
#include "nanoarrow/nanoarrow_types.h"
#include "numpy_datetime.h"

namespace nb = nanobind;
@@ -54,6 +57,8 @@ static auto hyperTypeFromArrowSchema(struct ArrowSchema *schema,
} else {
return hyperapi::SqlType::timestamp();
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
return hyperapi::SqlType::interval();
case NANOARROW_TYPE_TIME64:
return hyperapi::SqlType::time();
default:
@@ -297,6 +302,33 @@ class TimestampInsertHelper : public InsertHelper {
}
};

class IntervalInsertHelper : public InsertHelper {
public:
using InsertHelper::InsertHelper;

void insertValueAtIndex(size_t idx) override {
if (ArrowArrayViewIsNull(&array_view_, idx)) {
// MSVC on cibuildwheel doesn't like this templated optional
// inserter_->add(std::optional<timestamp_t>{std::nullopt});
hyperapi::internal::ValueInserter{*inserter_}.addNull();
return;
}

struct ArrowInterval arrow_interval;
ArrowIntervalInit(&arrow_interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO);
ArrowArrayViewGetIntervalUnsafe(&array_view_, idx, &arrow_interval);
const auto usec = static_cast<int32_t>(arrow_interval.ns / 1000);

// Hyper has no template specialization to insert an interval; instead we
// must use their internal representation
hyperapi::Interval interval(0, arrow_interval.months, arrow_interval.days,
0, 0, 0, usec);
// hyperapi::Interval interval{0, arrow_interval.months,
// arrow_interval.days, 0, 0, 0, usec};
inserter_->add(interval);
}
};
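As a reference for the write path just defined, here is the same mapping sketched in Python with the fixture's values. Arrow carries the sub-day portion in nanoseconds while the hyperapi::Interval constructor used above takes microseconds, so sub-microsecond precision is truncated:

# Arrow month_day_nano components: (months, days, nanoseconds)
months, days, ns = 1, 15, -30000
# The C++ helper casts ns / 1000, truncating toward zero
usec = int(ns / 1000)  # -30
# hyperapi::Interval components: (years, months, days, hours, minutes, seconds, microseconds)
hyper_interval = (0, months, days, 0, 0, 0, usec)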

static auto makeInsertHelper(std::shared_ptr<hyperapi::Inserter> inserter,
struct ArrowArray *chunk,
struct ArrowSchema *schema,
@@ -390,6 +422,9 @@ static auto makeInsertHelper(std::shared_ptr<hyperapi::Inserter> inserter,
}
throw std::runtime_error(
"This code block should not be hit - contact a developer");
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
return std::unique_ptr<InsertHelper>(new IntervalInsertHelper(
inserter, chunk, schema, error, column_position));
case NANOARROW_TYPE_TIME64:
switch (schema_view.time_unit) {
// must be a smarter way to do this!
@@ -410,7 +445,8 @@ static auto makeInsertHelper(std::shared_ptr<hyperapi::Inserter> inserter,
new TimeInsertHelper<NANOARROW_TIME_UNIT_NANO>(
inserter, chunk, schema, error, column_position));
}
-break;
+throw std::runtime_error(
+"This code block should not be hit - contact a developer");
default:
throw std::invalid_argument("makeInsertHelper: Unsupported Arrow type: " +
std::to_string(schema_view.type));
@@ -729,11 +765,40 @@ class TimeReadHelper : public ReadHelper {
}
return;
}

const auto time = value.get<hyperapi::Time>();
const auto raw_value = time.getRaw();
if (ArrowArrayAppendInt(array_, raw_value)) {
throw std::runtime_error("ArrowAppendInt failed");
};
}
}
};

class IntervalReadHelper : public ReadHelper {
using ReadHelper::ReadHelper;

auto Read(const hyperapi::Value &value) -> void override {
if (value.isNull()) {
if (ArrowArrayAppendNull(array_, 1)) {
throw std::runtime_error("ArrowAppendNull failed");
}
return;
}

struct ArrowInterval arrow_interval;
ArrowIntervalInit(&arrow_interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO);
const auto interval_value = value.get<hyperapi::Interval>();
arrow_interval.months =
interval_value.getYears() * 12 + interval_value.getMonths();
arrow_interval.days = interval_value.getDays();
arrow_interval.ns = interval_value.getHours() * 3'600'000'000'000LL +
interval_value.getMinutes() * 60'000'000'000LL +
interval_value.getSeconds() * 1'000'000'000LL +
interval_value.getMicroseconds() * 1'000LL;

if (ArrowArrayAppendInterval(array_, &arrow_interval)) {
throw std::runtime_error("Failed to append interval value");
}
}
};
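The read path above does the reverse folding; a small Python check with made-up component values (Hyper's years fold into Arrow's month count, and the clock fields fold into one nanosecond total):

# Illustrative hyperapi::Interval components
years, months, days = 1, 2, 15
hours, minutes, seconds, microseconds = 3, 4, 5, 6

arrow_months = years * 12 + months  # 14
arrow_days = days                   # 15
arrow_ns = (
    hours * 3_600_000_000_000
    + minutes * 60_000_000_000
    + seconds * 1_000_000_000
    + microseconds * 1_000
)  # 11_045_000_006_000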

@@ -765,6 +830,8 @@ static auto makeReadHelper(const ArrowSchemaView *schema_view,
} else {
return std::unique_ptr<ReadHelper>(new DatetimeReadHelper<false>(array));
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
return std::unique_ptr<ReadHelper>(new IntervalReadHelper(array));
case NANOARROW_TYPE_TIME64:
return std::unique_ptr<ReadHelper>(new TimeReadHelper(array));
default:
@@ -789,6 +856,8 @@ static auto arrowTypeFromHyper(const hyperapi::SqlType &sqltype)
case hyperapi::TypeTag::Date : return NANOARROW_TYPE_DATE32;
case hyperapi::TypeTag::Timestamp :
case hyperapi::TypeTag::TimestampTZ : return NANOARROW_TYPE_TIMESTAMP;
case hyperapi::TypeTag::Interval : return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO;
case hyperapi::TypeTag::Time : return NANOARROW_TYPE_TIME64;
default : throw nb::type_error(
("Reader not implemented for type: " + sqltype.toString()).c_str());
24 changes: 23 additions & 1 deletion pantab/tests/conftest.py
@@ -37,6 +37,7 @@ def basic_arrow_table():
("float64_limits", pa.float64()),
("non-ascii", pa.utf8()),
("binary", pa.binary()),
("interval", pa.month_day_nano_interval()),
("time64us", pa.time64("us")),
]
)
@@ -82,6 +83,13 @@ def basic_arrow_table():
["\xef\xff\xdc\xde\xee", "\xfa\xfb\xdd\xaf\xaa", None],
),
pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None]),
pa.array(
[
pa.scalar((1, 15, -30000), type=pa.month_day_nano_interval()),
pa.scalar((-1, -15, 30000), type=pa.month_day_nano_interval()),
None,
]
),
pa.array([234, 42, None]),
],
schema=schema,
@@ -227,6 +235,15 @@ def basic_dataframe():
# See pandas GH issue #56994
df["binary"] = pa.array([b"\xde\xad\xbe\xef", b"\xff\xee", None], type=pa.binary())
df["binary"] = df["binary"].astype("binary[pyarrow]")
# pandas interval support is broken in pyarrow < 16
# df["interval"] = pa.array(
# [
# pa.scalar((1, 15, -30000), type=pa.month_day_nano_interval()),
# pa.scalar((-1, -15, 30000), type=pa.month_day_nano_interval()),
# None,
# ]
# )
# df["interval"] = df["interval"].astype("month_day_nano_interval[pyarrow]")
df["time64us"] = pd.DataFrame(
{"col": pa.array([234, 42, None], type=pa.time64("us"))}
)
@@ -237,6 +254,9 @@

def basic_polars_frame():
tbl = basic_arrow_table()

# polars does not support month_day_nano_interval yet
tbl = tbl.drop_columns(["interval"])
df = pl.from_arrow(tbl)
return df

@@ -274,6 +294,7 @@ def roundtripped_pyarrow():
("float64_limits", pa.float64()),
("non-ascii", pa.large_string()),
("binary", pa.large_binary()),
("interval", pa.month_day_nano_interval()),
("time64us", pa.time64("us")),
]
)
@@ -310,6 +331,7 @@ def roundtripped_pandas():
"non-ascii": "large_string[pyarrow]",
"string": "large_string[pyarrow]",
"binary": "large_binary[pyarrow]",
# "interval": "month_day_nano_interval[pyarrow]",
"time64us": "time64[us][pyarrow]",
}
)
@@ -406,7 +428,7 @@ def empty_like(frame):
@staticmethod
def drop_columns(frame, columns):
if isinstance(frame, pd.DataFrame):
-return frame.drop(columns=columns)
+return frame.drop(columns=columns, errors="ignore")
elif isinstance(frame, pa.Table):
return frame.drop_columns(columns)
elif isinstance(frame, pl.DataFrame):
19 changes: 16 additions & 3 deletions pantab/tests/test_roundtrip.py
@@ -1,14 +1,19 @@
import pyarrow as pa
from tableauhyperapi import TableName

import pantab


def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
return_type, expected = roundtripped
if not (isinstance(frame, pa.Table) and return_type == "pyarrow"):
frame = compat.drop_columns(frame, ["interval"])
expected = compat.drop_columns(expected, ["interval"])

# Write twice; depending on mode this should either overwrite or duplicate entries
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode=table_mode)

-return_type, expected = roundtripped
result = pantab.frame_from_hyper(
tmp_hyper, table=table_name, return_type=return_type
)
@@ -22,6 +27,11 @@ def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
def test_multiple_tables(
frame, roundtripped, tmp_hyper, table_name, table_mode, compat
):
return_type, expected = roundtripped
if not (isinstance(frame, pa.Table) and return_type == "pyarrow"):
frame = compat.drop_columns(frame, ["interval"])
expected = compat.drop_columns(expected, ["interval"])

# Write twice; depending on mode this should either overwrite or duplicate entries
pantab.frames_to_hyper(
{table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
@@ -30,7 +40,6 @@ def test_multiple_tables(
{table_name: frame, "table2": frame}, tmp_hyper, table_mode=table_mode
)

-return_type, expected = roundtripped
result = pantab.frames_from_hyper(tmp_hyper, return_type=return_type)

if table_mode == "a":
@@ -48,13 +57,17 @@ def test_multiple_tables(
def test_empty_roundtrip(
frame, roundtripped, tmp_hyper, table_name, table_mode, compat
):
return_type, expected = roundtripped
if not (isinstance(frame, pa.Table) and return_type == "pyarrow"):
frame = compat.drop_columns(frame, ["interval"])
expected = compat.drop_columns(expected, ["interval"])

# object case is by definition vague, so lets punt that for now
frame = compat.drop_columns(frame, ["object"])
empty = compat.empty_like(frame)
pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)
pantab.frame_to_hyper(empty, tmp_hyper, table=table_name, table_mode=table_mode)

-return_type, expected = roundtripped
result = pantab.frame_from_hyper(
tmp_hyper, table=table_name, return_type=return_type
)
