Allow overriding column nullability
WillAyd committed Jan 29, 2024
1 parent 1c820fd commit 2f6760a
Showing 4 changed files with 52 additions and 13 deletions.
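For context, the new keyword added by this commit can be passed straight through frame_to_hyper. Below is a minimal usage sketch based on the signature added in src/pantab/_writer.py; the DataFrame contents, output path, and table name are illustrative, and a plain string table name is assumed to be accepted as elsewhere in pantab.

import pandas as pd
import pantab

# Columns named in not_null_columns are created as NOT NULL in the Hyper
# table; every other column keeps the previous always-nullable default.
df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
pantab.frame_to_hyper(
    df,
    "example.hyper",          # illustrative output path
    table="users",            # illustrative table name
    not_null_columns={"id"},  # new in this commit
)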
7 changes: 7 additions & 0 deletions src/pantab/_writer.py
@@ -54,6 +54,7 @@ def frame_to_hyper(
*,
table: pantab_types.TableNameType,
table_mode: Literal["a", "w"] = "w",
not_null_columns: Optional[set[str]] = None,
json_columns: Optional[set[str]] = None,
geo_columns: Optional[set[str]] = None,
) -> None:
@@ -62,6 +63,7 @@ def frame_to_hyper(
{table: df},
database,
table_mode=table_mode,
not_null_columns=not_null_columns,
json_columns=json_columns,
geo_columns=geo_columns,
)
@@ -72,11 +74,15 @@ def frames_to_hyper(
database: Union[str, pathlib.Path],
*,
table_mode: Literal["a", "w"] = "w",
not_null_columns: Optional[set[str]] = None,
json_columns: Optional[set[str]] = None,
geo_columns: Optional[set[str]] = None,
) -> None:
"""See api.rst for documentation."""
_validate_table_mode(table_mode)

if not_null_columns is None:
not_null_columns = set()
if json_columns is None:
json_columns = set()
if geo_columns is None:
@@ -103,6 +109,7 @@ def convert_to_table_name(table: pantab_types.TableNameType):
data,
path=str(tmp_db),
table_mode=table_mode,
not_null_columns=not_null_columns,
json_columns=json_columns,
geo_columns=geo_columns,
)
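frames_to_hyper gains the same keyword. A sketch under the same assumptions (illustrative frames and file name); note that a single set is threaded through, so it is matched by column name against every table written in the call.

import pandas as pd
import pantab

users = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
orders = pd.DataFrame({"id": [10, 20], "total": [1.5, 2.5]})

# The one not_null_columns set applies to every table, so "id" becomes
# NOT NULL in both tables here.
pantab.frames_to_hyper(
    {"users": users, "orders": orders},
    "example.hyper",
    not_null_columns={"id"},
)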
18 changes: 11 additions & 7 deletions src/pantab/libpantab.cpp
@@ -533,6 +533,7 @@ using SchemaAndTableName = std::tuple<std::string, std::string>;
void write_to_hyper(
const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules,
const std::string &path, const std::string &table_mode,
const std::set<std::string> &not_null_columns,
const std::set<std::string> &json_columns,
const std::set<std::string> &geo_columns) {
const std::unordered_map<std::string, std::string> params = {
@@ -570,10 +571,13 @@ void write_to_hyper(
std::vector<hyperapi::TableDefinition::Column> inserter_defs;
for (int64_t i = 0; i < schema.n_children; i++) {
const auto col_name = std::string{schema.children[i]->name};
const auto nullability = not_null_columns.find(col_name) != not_null_columns.end()
? hyperapi::Nullability::NotNullable : hyperapi::Nullability::Nullable;

if (json_columns.find(col_name) != json_columns.end()) {
const auto hypertype = hyperapi::SqlType::json();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};
col_name, hypertype, nullability};

hyper_columns.emplace_back(column);
inserter_defs.emplace_back(std::move(column));
@@ -586,13 +590,13 @@
if (detected_type == hyperapi::SqlType::text()) {
const auto hypertype = hyperapi::SqlType::geography();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};
col_name, hypertype, nullability};

hyper_columns.emplace_back(std::move(column));
const auto insertertype = hyperapi::SqlType::text();
const auto as_text_name = col_name + "_as_text";
const hyperapi::TableDefinition::Column inserter_column{
as_text_name, insertertype, hyperapi::Nullability::Nullable};
as_text_name, insertertype, nullability};
inserter_defs.emplace_back(std::move(inserter_column));

const auto escaped = hyperapi::escapeName(as_text_name);
@@ -602,7 +606,7 @@
} else if (detected_type == hyperapi::SqlType::bytes()) {
const auto hypertype = hyperapi::SqlType::geography();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};
col_name, hypertype, nullability};

hyper_columns.emplace_back(column);
inserter_defs.emplace_back(std::move(column));
@@ -616,7 +620,7 @@
const auto hypertype =
hyperTypeFromArrowSchema(schema.children[i], &error);
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};
col_name, hypertype, nullability};

hyper_columns.emplace_back(column);
inserter_defs.emplace_back(std::move(column));
@@ -1088,8 +1092,8 @@ auto read_from_hyper_query(const std::string &path, const std::string &query)

NB_MODULE(libpantab, m) { // NOLINT
m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"), nb::arg("json_columns"),
nb::arg("geo_columns"))
nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"),
nb::arg("json_columns"), nb::arg("geo_columns"))
.def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
nb::arg("query"));
PyDateTime_IMPORT;
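The C++ change resolves the nullability once per column and reuses it in the JSON, geography, and default type-mapping branches. A rough Python analogue of that per-column decision, for illustration only (the string values stand in for hyperapi::Nullability):

def resolve_nullability(col_name: str, not_null_columns: set[str]) -> str:
    # Mirrors the C++ ternary: columns named in not_null_columns become
    # NOT NULL; everything else keeps the previous always-nullable default.
    return "NOT_NULLABLE" if col_name in not_null_columns else "NULLABLE"

assert resolve_nullability("int32", {"int32"}) == "NOT_NULLABLE"
assert resolve_nullability("float64", {"int32"}) == "NULLABLE"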
1 change: 1 addition & 0 deletions src/pantab/libpantab.pyi
@@ -4,6 +4,7 @@ def write_to_hyper(
dict_of_capsules: dict[tuple[str, str], Any],
path: str,
table_mode: Literal["w", "a"],
not_null_columns: set[str],
json_columns: set[str],
geo_columns: set[str],
) -> None: ...
39 changes: 33 additions & 6 deletions tests/test_writer.py
@@ -46,10 +46,31 @@ def test_append_mode_raises_ncolumns_mismatch(frame, tmp_hyper, table_name, comp
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode="a")


@pytest.mark.skip("Hyper API calls abort() when this condition is not met")
def test_writing_to_non_nullable_column_without_nulls(frame, tmp_hyper, compat):
# With arrow as our backend we define everything as nullable, but we should
# still be able to append to non-nullable columns
def test_writer_creates_not_null_columns(tmp_hyper):
table_name = tab_api.TableName("test")
df = pd.DataFrame({"int32": [1, 2, 3]}, dtype="int32")
pantab.frame_to_hyper(
df,
tmp_hyper,
table=table_name,
table_mode="a",
not_null_columns={"int32"},
)

with tab_api.HyperProcess(
tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
) as hyper:
with tab_api.Connection(
hyper.endpoint, tmp_hyper, tab_api.CreateMode.CREATE_IF_NOT_EXISTS
) as connection:
table_def = connection.catalog.get_table_definition(table_name)
col = table_def.get_column_by_name("int32")
assert col.nullability == tab_api.Nullability.NOT_NULLABLE


def test_writing_to_non_nullable_column_without_nulls(tmp_hyper):
# With arrow as our backend we define everything as nullable, so it is up
# to the users to override this if they want
column_name = "int32"
table_name = tab_api.TableName("public", "table")
table = tab_api.TableDefinition(
@@ -77,8 +98,14 @@ def test_writing_to_non_nullable_column_without_nulls(frame, tmp_hyper, compat):
inserter.add_rows([[1], [2]])
inserter.execute()

frame = compat.select_columns(frame, [column_name])
pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode="a")
df = pd.DataFrame({"int32": [1, 2, 3]}, dtype="int32")
pantab.frame_to_hyper(
df,
tmp_hyper,
table=table_name,
table_mode="a",
not_null_columns={"int32"},
)


def test_string_type_to_existing_varchar(frame, tmp_hyper, compat):
