Skip to content

Commit

Permalink
Write WKT as geography (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Jan 28, 2024
1 parent c05baff commit 4409f75
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 9 deletions.
54 changes: 45 additions & 9 deletions pantab/src/pantab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,27 +565,63 @@ void write_to_hyper(

struct ArrowError error;
std::vector<hyperapi::TableDefinition::Column> hyper_columns;
std::vector<hyperapi::Inserter::ColumnMapping> column_mappings;
// subtly different from hyper_columns when a geo column is supplied as text
std::vector<hyperapi::TableDefinition::Column> inserter_defs;
for (int64_t i = 0; i < schema.n_children; i++) {
const auto col_name = std::string{schema.children[i]->name};
if (json_columns.find(col_name) != json_columns.end()) {
const auto hypertype = hyperapi::SqlType::json();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};

hyper_columns.emplace_back(std::move(column));
hyper_columns.emplace_back(column);
inserter_defs.emplace_back(std::move(column));
const hyperapi::Inserter::ColumnMapping mapping{col_name};
column_mappings.emplace_back(mapping);
} else if (geo_columns.find(col_name) != geo_columns.end()) {
const auto hypertype = hyperapi::SqlType::geography();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};

hyper_columns.emplace_back(std::move(column));
// if binary just write as is; for text we provide conversion
const auto detected_type =
hyperTypeFromArrowSchema(schema.children[i], &error);
if (detected_type == hyperapi::SqlType::text()) {
const auto hypertype = hyperapi::SqlType::geography();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};

hyper_columns.emplace_back(std::move(column));
const auto insertertype = hyperapi::SqlType::text();
const auto as_text_name = col_name + "_as_text";
const hyperapi::TableDefinition::Column inserter_column{
as_text_name, insertertype, hyperapi::Nullability::Nullable};
inserter_defs.emplace_back(std::move(inserter_column));

const auto escaped = hyperapi::escapeName(as_text_name);
const hyperapi::Inserter::ColumnMapping mapping{
col_name, "CAST(" + escaped + " AS GEOGRAPHY)"};
column_mappings.emplace_back(mapping);
} else if (detected_type == hyperapi::SqlType::bytes()) {
const auto hypertype = hyperapi::SqlType::geography();
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};

hyper_columns.emplace_back(column);
inserter_defs.emplace_back(std::move(column));
const hyperapi::Inserter::ColumnMapping mapping{col_name};
column_mappings.emplace_back(mapping);
} else {
throw std::runtime_error(
"Unexpected code path hit - contact a developer");
}
} else {
const auto hypertype =
hyperTypeFromArrowSchema(schema.children[i], &error);
const hyperapi::TableDefinition::Column column{
col_name, hypertype, hyperapi::Nullability::Nullable};

hyper_columns.emplace_back(std::move(column));
hyper_columns.emplace_back(column);
inserter_defs.emplace_back(std::move(column));
const hyperapi::Inserter::ColumnMapping mapping{col_name};
column_mappings.emplace_back(mapping);
}
}

Expand All @@ -599,8 +635,8 @@ void write_to_hyper(
} else {
catalog.createTable(tableDef);
}
const auto inserter =
std::make_shared<hyperapi::Inserter>(connection, tableDef);
const auto inserter = std::make_shared<hyperapi::Inserter>(
connection, tableDef, column_mappings, inserter_defs);

struct ArrowArray chunk;
int errcode;
Expand Down
31 changes: 31 additions & 0 deletions pantab/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,34 @@ def test_geo_and_json_columns_writes_proper_type(tmp_hyper, frame):
geo_col = table_def.get_column_by_name("geography")
assert json_col.type == tab_api.SqlType.json()
assert geo_col.type == tab_api.SqlType.geography()


def test_can_write_wkt_as_geo(tmp_hyper):
    """WKT strings in a declared geo column are stored as Hyper geography values."""
    # Expected stored bytes for the two WKT points, in row order.
    expected_first = (
        b"\x07\xaa\x02\xe0%n\xd9\x01\x01\n\x00\xce\xab\xe8\xfa=\xff\x96\xf0\x8a\x9f\x01"
    )
    expected_second = (
        b"\x07\xaa\x02\x0c&n\x82\x01\x01\n\x00\xb0\xe2\xd4\xcc>\xd4\xbc\x97\x88\x0f"
    )

    wkt_frame = pd.DataFrame(
        {
            "geography": [
                "point(-122.338083 47.647528)",
                "point(11.584329 48.139257)",
            ]
        }
    )
    pantab.frame_to_hyper(
        wkt_frame, tmp_hyper, table="test", geo_columns=["geography"]
    )

    telemetry = tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
    with tab_api.HyperProcess(telemetry) as hyper, tab_api.Connection(
        hyper.endpoint, tmp_hyper, tab_api.CreateMode.CREATE_IF_NOT_EXISTS
    ) as connection:
        # The column must be typed as geography even though text was supplied.
        definition = connection.catalog.get_table_definition(
            tab_api.TableName("test")
        )
        column = definition.get_column_by_name("geography")
        assert column.type == tab_api.SqlType.geography()
        rows = connection.execute_list_query("select * from test")

    assert rows[0][0] == expected_first
    assert rows[1][0] == expected_second

0 comments on commit 4409f75

Please sign in to comment.