Skip to content

Commit

Permalink
feat(python): improve write_database, accounting for latest adbc
Browse files Browse the repository at this point in the history
…fixes/updates
  • Loading branch information
alexander-beedie committed Nov 27, 2023
1 parent b50d833 commit a1a93b6
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 68 deletions.
87 changes: 61 additions & 26 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3434,71 +3434,106 @@ def write_database(
Parameters
----------
table_name
Name of the table to create or append to in the target SQL database.
If your table name contains special characters, it should be quoted.
Schema-qualified name of the table to create or append to in the target
SQL database. If your table name contains special characters, it should
be quoted.
connection
Connection URI string, for example:
* "postgresql://user:pass@server:port/database"
* "sqlite:////path/to/database.db"
if_exists : {'append', 'replace', 'fail'}
The insert mode.
'replace' will create a new database table, overwriting an existing one.
'append' will append to an existing table.
'fail' will fail if table already exists.
The insert mode:
* 'replace' will create a new database table, overwriting an existing one.
* 'append' will append to an existing table.
* 'fail' will fail if table already exists.
engine : {'sqlalchemy', 'adbc'}
Select the engine used for writing the data.
"""
from polars.io.database import _open_adbc_connection

def unpack_table_name(name: str) -> tuple[str | None, str]:
    """
    Unpack optionally qualified table name into schema/table pair.

    Parameters
    ----------
    name
        Table identifier, optionally prefixed with a schema and a "." separator
        (for example ``tbl``, ``main.tbl`` or ``main."my-table"``). A component
        wrapped in double quotes may itself contain a "."; the surrounding
        quotes are not part of the returned values.

    Returns
    -------
    tuple
        A ``(schema, table)`` pair; ``schema`` is None for an unqualified name.

    Raises
    ------
    ValueError
        If the name is empty, has an empty table component, or has more than
        two "."-separated components.
    """
    from csv import reader as delimited_read

    # the csv reader splits on "." while respecting double-quoted components;
    # note that it yields an *empty* row when given an empty string, so that
    # case has to be rejected explicitly rather than indexed into
    table_ident = next(delimited_read([name], delimiter="."))
    if not table_ident or len(table_ident) > 2 or not table_ident[-1]:
        raise ValueError(f"`table_name` appears to be invalid: {name!r}")

    # zero or one leading component (the schema), followed by the table name
    *qualifier, tbl = table_ident
    schema = qualifier[0] if qualifier else None
    return schema, tbl

if engine == "adbc":
import adbc_driver_manager

adbc_version = parse_version(
getattr(adbc_driver_manager, "__version__", "0.0")
)
if if_exists == "fail":
raise NotImplementedError(
"`if_exists = 'fail'` not supported for ADBC engine"
)
elif if_exists == "replace":
# if the table exists, 'create' will raise an error,
# resulting in behaviour equivalent to 'fail'
mode = "create"
elif if_exists == "replace":
if adbc_version < (0, 7):
adbc_str_version = ".".join(str(v) for v in adbc_version)
raise ModuleNotFoundError(
f"`if_exists = 'replace'` requires ADBC version >= 0.7, found {adbc_str_version}"
)
mode = "replace"
elif if_exists == "append":
mode = "append"
else:
raise ValueError(
f"unexpected value for `if_exists`: {if_exists!r}"
f"\n\nChoose one of {{'fail', 'replace', 'append'}}"
)

with _open_adbc_connection(connection) as conn, conn.cursor() as cursor:
cursor.adbc_ingest(table_name, self.to_arrow(), mode)
if adbc_version >= (0, 7):
db_schema, unpacked_table_name = unpack_table_name(table_name)
if "sqlite" in conn.adbc_get_info()["driver_name"].lower():
if if_exists == "replace":
# note: adbc doesn't (yet) support 'replace' for sqlite
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
mode = "create"
catalog, db_schema = db_schema, None
else:
catalog = None

cursor.adbc_ingest(
unpacked_table_name,
data=self.to_arrow(),
mode=mode,
catalog_name=catalog,
db_schema_name=db_schema,
)
else:
cursor.adbc_ingest(table_name, self.to_arrow(), mode)
conn.commit()

elif engine == "sqlalchemy":
if parse_version(pd.__version__) < parse_version("1.5"):
raise ModuleNotFoundError(
f"writing with engine 'sqlalchemy' requires pandas 1.5.x or higher, found pandas {pd.__version__!r}"
f"writing with engine 'sqlalchemy' requires pandas 1.5.x or higher, found {pd.__version__!r}"
)

try:
from sqlalchemy import create_engine
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"sqlalchemy not found"
"\n\nInstall Polars with: pip install polars[sqlalchemy]"
) from exc
from csv import reader as delimited_read

# the table name may also include the db schema; ensure that we identify
# both components and pass them through unquoted (sqlalchemy will quote)
table_ident = next(delimited_read([table_name], delimiter="."))
if len(table_ident) > 2:
raise ValueError(f"`table_name` appears to be invalid: {table_name!r}")
elif len(table_ident) > 1:
db_schema = table_ident[0]
table_name = table_ident[1]
else:
table_name = table_ident[0]
db_schema = None

# ensure conversion to pandas uses the pyarrow extension array option
# so that we can make use of the sql/db export without copying data
engine_sa = create_engine(connection)
db_schema, table_name = unpack_table_name(table_name)

self.to_pandas(use_pyarrow_extension_array=True).to_sql(
name=table_name,
schema=db_schema,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ disable_error_code = [
[[tool.mypy.overrides]]
module = [
"IPython.*",
"adbc_driver_postgresql.*",
"adbc_driver_manager.*",
"adbc_driver_sqlite.*",
"arrow_odbc",
"backports",
Expand Down
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ backports.zoneinfo; python_version < '3.9'
tzdata; platform_system == 'Windows'
# Database
SQLAlchemy
adbc_driver_manager; python_version >= '3.9' and platform_system != 'Windows'
adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
# TODO: Remove version constraint for connectorx when Python 3.12 is supported:
# https://github.com/sfu-db/connector-x/issues/527
Expand Down
74 changes: 33 additions & 41 deletions py-polars/tests/unit/io/test_database_write.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import sys
from contextlib import suppress
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import pytest
from adbc_driver_manager import InternalError

import polars as pl
from polars.testing import assert_frame_equal
Expand All @@ -15,16 +15,8 @@
from polars.type_aliases import DbWriteEngine


def adbc_sqlite_driver_version(*args: Any, **kwargs: Any) -> str:
    """
    Return the version string of the ``adbc_driver_sqlite`` package.

    Any positional or keyword arguments are accepted and ignored. The string
    "n/a" is returned when the package cannot be imported, or when it does
    not expose a ``__version__`` attribute.
    """
    try:
        # not available on 3.8/windows
        import adbc_driver_sqlite

        version = getattr(adbc_driver_sqlite, "__version__", "n/a")
        return version
    except ModuleNotFoundError:
        return "n/a"


@pytest.mark.skipif(
sys.version_info > (3, 11),
sys.version_info >= (3, 12),
reason="connectorx cannot be installed on Python 3.12 yet.",
)
@pytest.mark.skipif(
Expand All @@ -43,20 +35,20 @@ def test_write_database_create(engine: DbWriteEngine, tmp_path: Path) -> None:
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name = "test_create"

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
connection=test_db_uri,
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
result = pl.read_database_uri(f"SELECT * FROM {table_name}", test_db_uri)
assert_frame_equal(result, df)


@pytest.mark.skipif(
sys.version_info > (3, 11),
sys.version_info >= (3, 12),
reason="connectorx cannot be installed on Python 3.12 yet.",
)
@pytest.mark.skipif(
Expand All @@ -65,7 +57,7 @@ def test_write_database_create(engine: DbWriteEngine, tmp_path: Path) -> None:
)
@pytest.mark.write_disk()
@pytest.mark.parametrize("engine", ["adbc", "sqlalchemy"])
def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:
def test_write_database_append_replace(engine: DbWriteEngine, tmp_path: Path) -> None:
df = pl.DataFrame(
{
"key": ["xx", "yy", "zz"],
Expand All @@ -76,31 +68,40 @@ def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name = "test_append"

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
connection=test_db_uri,
engine=engine,
)

ExpectedError = NotImplementedError if engine == "adbc" else ValueError
ExpectedError = InternalError if engine == "adbc" else ValueError
with pytest.raises(ExpectedError):
df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
connection=test_db_uri,
if_exists="fail",
engine=engine,
)

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
connection=test_db_uri,
if_exists="replace",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", test_db_uri)
assert_frame_equal(result, df)

df.write_database(
table_name=table_name,
connection=test_db_uri,
if_exists="append",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
result = pl.read_database_uri(f"SELECT * FROM {table_name}", test_db_uri)
assert_frame_equal(result, pl.concat([df, df]))


Expand All @@ -112,16 +113,11 @@ def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:
@pytest.mark.parametrize(
"engine",
[
pytest.param(
"adbc",
marks=pytest.mark.xfail( # see: https://github.com/apache/arrow-adbc/issues/1000
reason="ADBC SQLite driver has a bug with quoted/qualified table names",
),
),
"adbc",
pytest.param(
"sqlalchemy",
marks=pytest.mark.skipif(
sys.version_info > (3, 11),
sys.version_info >= (3, 12),
reason="connectorx cannot be installed on Python 3.12 yet.",
),
),
Expand All @@ -134,17 +130,23 @@ def test_write_database_create_quoted_tablename(

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"

# table name requires quoting, and is qualified with the implicit 'main' schema
table_name = 'main."test-append"'

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
connection=test_db_uri,
engine=engine,
)
df.write_database(
table_name=table_name,
connection=test_db_uri,
if_exists="replace",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
result = pl.read_database_uri(f"SELECT * FROM {table_name}", test_db_uri)
assert_frame_equal(result, df)


Expand All @@ -159,16 +161,6 @@ def test_write_database_errors() -> None:
connection="sqlite:///:memory:", table_name="w.x.y.z", engine="sqlalchemy"
)

with pytest.raises(
NotImplementedError, match="`if_exists = 'fail'` not supported for ADBC engine"
):
df.write_database(
connection="sqlite:///:memory:",
table_name="test_errs",
if_exists="fail",
engine="adbc",
)

with pytest.raises(ValueError, match="'do_something' is not valid for if_exists"):
df.write_database(
connection="sqlite:///:memory:",
Expand Down

0 comments on commit a1a93b6

Please sign in to comment.