diff --git a/src/pantab/_writer.py b/src/pantab/_writer.py index 3c341576..9e57d0bb 100644 --- a/src/pantab/_writer.py +++ b/src/pantab/_writer.py @@ -54,6 +54,7 @@ def frame_to_hyper( *, table: pantab_types.TableNameType, table_mode: Literal["a", "w"] = "w", + not_null_columns: Optional[set[str]] = None, json_columns: Optional[set[str]] = None, geo_columns: Optional[set[str]] = None, ) -> None: @@ -62,6 +63,7 @@ def frame_to_hyper( {table: df}, database, table_mode=table_mode, + not_null_columns=not_null_columns, json_columns=json_columns, geo_columns=geo_columns, ) @@ -72,11 +74,15 @@ def frames_to_hyper( database: Union[str, pathlib.Path], *, table_mode: Literal["a", "w"] = "w", + not_null_columns: Optional[set[str]] = None, json_columns: Optional[set[str]] = None, geo_columns: Optional[set[str]] = None, ) -> None: """See api.rst for documentation.""" _validate_table_mode(table_mode) + + if not_null_columns is None: + not_null_columns = set() if json_columns is None: json_columns = set() if geo_columns is None: @@ -103,6 +109,7 @@ def convert_to_table_name(table: pantab_types.TableNameType): data, path=str(tmp_db), table_mode=table_mode, + not_null_columns=not_null_columns, json_columns=json_columns, geo_columns=geo_columns, ) diff --git a/src/pantab/libpantab.cpp b/src/pantab/libpantab.cpp index 00ba9ed5..b694f01b 100644 --- a/src/pantab/libpantab.cpp +++ b/src/pantab/libpantab.cpp @@ -533,6 +533,7 @@ using SchemaAndTableName = std::tuple; void write_to_hyper( const std::map &dict_of_capsules, const std::string &path, const std::string &table_mode, + const std::set ¬_null_columns, const std::set &json_columns, const std::set &geo_columns) { const std::unordered_map params = { @@ -570,10 +571,13 @@ void write_to_hyper( std::vector inserter_defs; for (int64_t i = 0; i < schema.n_children; i++) { const auto col_name = std::string{schema.children[i]->name}; + const auto nullability = not_null_columns.find(col_name) != not_null_columns.end() + ? hyperapi::Nullability::NotNullable : hyperapi::Nullability::Nullable; + if (json_columns.find(col_name) != json_columns.end()) { const auto hypertype = hyperapi::SqlType::json(); const hyperapi::TableDefinition::Column column{ - col_name, hypertype, hyperapi::Nullability::Nullable}; + col_name, hypertype, nullability}; hyper_columns.emplace_back(column); inserter_defs.emplace_back(std::move(column)); @@ -586,13 +590,13 @@ void write_to_hyper( if (detected_type == hyperapi::SqlType::text()) { const auto hypertype = hyperapi::SqlType::geography(); const hyperapi::TableDefinition::Column column{ - col_name, hypertype, hyperapi::Nullability::Nullable}; + col_name, hypertype, nullability}; hyper_columns.emplace_back(std::move(column)); const auto insertertype = hyperapi::SqlType::text(); const auto as_text_name = col_name + "_as_text"; const hyperapi::TableDefinition::Column inserter_column{ - as_text_name, insertertype, hyperapi::Nullability::Nullable}; + as_text_name, insertertype, nullability}; inserter_defs.emplace_back(std::move(inserter_column)); const auto escaped = hyperapi::escapeName(as_text_name); @@ -602,7 +606,7 @@ void write_to_hyper( } else if (detected_type == hyperapi::SqlType::bytes()) { const auto hypertype = hyperapi::SqlType::geography(); const hyperapi::TableDefinition::Column column{ - col_name, hypertype, hyperapi::Nullability::Nullable}; + col_name, hypertype, nullability}; hyper_columns.emplace_back(column); inserter_defs.emplace_back(std::move(column)); @@ -616,7 +620,7 @@ void write_to_hyper( const auto hypertype = hyperTypeFromArrowSchema(schema.children[i], &error); const hyperapi::TableDefinition::Column column{ - col_name, hypertype, hyperapi::Nullability::Nullable}; + col_name, hypertype, nullability}; hyper_columns.emplace_back(column); inserter_defs.emplace_back(std::move(column)); @@ -1088,8 +1092,8 @@ auto read_from_hyper_query(const std::string &path, const std::string &query) NB_MODULE(libpantab, m) { // NOLINT m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), - nb::arg("path"), nb::arg("table_mode"), nb::arg("json_columns"), - nb::arg("geo_columns")) + nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"), + nb::arg("json_columns"), nb::arg("geo_columns")) .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"), nb::arg("query")); PyDateTime_IMPORT; diff --git a/src/pantab/libpantab.pyi b/src/pantab/libpantab.pyi index c3758e1a..5a5a430e 100644 --- a/src/pantab/libpantab.pyi +++ b/src/pantab/libpantab.pyi @@ -4,6 +4,7 @@ def write_to_hyper( dict_of_capsules: dict[tuple[str, str], Any], path: str, table_mode: Literal["w", "a"], + not_null_columns: set[str], json_columns: set[str], geo_columns: set[str], ) -> None: ... diff --git a/tests/test_writer.py b/tests/test_writer.py index 3c4a88c4..1bfd59aa 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -46,10 +46,31 @@ def test_append_mode_raises_ncolumns_mismatch(frame, tmp_hyper, table_name, comp pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode="a") -@pytest.mark.skip("Hyper API calls abort() when this condition is not met") -def test_writing_to_non_nullable_column_without_nulls(frame, tmp_hyper, compat): - # With arrow as our backend we define everything as nullable, but we should - # still be able to append to non-nullable columns +def test_writer_creates_not_null_columns(tmp_hyper): + table_name = tab_api.TableName("test") + df = pd.DataFrame({"int32": [1, 2, 3]}, dtype="int32") + pantab.frame_to_hyper( + df, + tmp_hyper, + table=table_name, + table_mode="a", + not_null_columns={"int32"}, + ) + + with tab_api.HyperProcess( + tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU + ) as hyper: + with tab_api.Connection( + hyper.endpoint, tmp_hyper, tab_api.CreateMode.CREATE_IF_NOT_EXISTS + ) as connection: + table_def = connection.catalog.get_table_definition(table_name) + col = table_def.get_column_by_name("int32") + assert col.nullability == tab_api.Nullability.NOT_NULLABLE + + +def test_writing_to_non_nullable_column_without_nulls(tmp_hyper): + # With arrow as our backend we define everything as nullable, so it is up + # to the users to override this if they want column_name = "int32" table_name = tab_api.TableName("public", "table") table = tab_api.TableDefinition( @@ -77,8 +98,14 @@ def test_writing_to_non_nullable_column_without_nulls(frame, tmp_hyper, compat): inserter.add_rows([[1], [2]]) inserter.execute() - frame = compat.select_columns(frame, [column_name]) - pantab.frame_to_hyper(frame, tmp_hyper, table=table_name, table_mode="a") + df = pd.DataFrame({"int32": [1, 2, 3]}, dtype="int32") + pantab.frame_to_hyper( + df, + tmp_hyper, + table=table_name, + table_mode="a", + not_null_columns={"int32"}, + ) def test_string_type_to_existing_varchar(frame, tmp_hyper, compat):